// dns_resolver/util/nameserver.rs
use rand::Rng;
use std::cmp::Ordering;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::{TcpStream, UdpSocket};
use tokio::time::timeout;
use dns_types::protocol::types::*;
use crate::util::net::{read_tcp_bytes, send_tcp_bytes, send_udp_bytes};
/// Send `question` to the nameserver at `address` and return its reply.
///
/// The request is given a random message ID and the supplied
/// recursion-desired flag. It is first sent over UDP. If that yields no
/// response, or a response rejected by `response_matches_request` (for
/// example one with the TC bit set), the same serialised bytes are sent
/// again over TCP.
///
/// Returns `None` if the request cannot be serialised, or if neither
/// transport produces an acceptable response.
pub async fn query_nameserver(
    address: SocketAddr,
    question: Question,
    recursion_desired: bool,
) -> Option<Message> {
    let mut request = Message::from_question(rand::thread_rng().gen(), question);
    request.header.recursion_desired = recursion_desired;

    let mut serialised_request = match request.to_octets() {
        Ok(octets) => octets,
        Err(error) => {
            tracing::warn!(message = ?request, ?error, "could not serialise message");
            return None;
        }
    };

    tracing::trace!(message = ?request, ?address, "forwarding query to nameserver");

    let over_udp = query_nameserver_udp(address, &mut serialised_request).await;
    if let Some(response) = over_udp.filter(|r| response_matches_request(&request, r)) {
        return Some(response);
    }

    // UDP failed or gave an unusable answer: retry the identical request over TCP.
    query_nameserver_tcp(address, &mut serialised_request)
        .await
        .filter(|r| response_matches_request(&request, r))
}
/// Run [`query_nameserver_udp_notimeout`] with a 5 second deadline.
///
/// Returns `None` if the deadline passes or if the underlying query fails.
async fn query_nameserver_udp(
    address: SocketAddr,
    serialised_request: &mut [u8],
) -> Option<Message> {
    let deadline = Duration::from_secs(5);
    match timeout(deadline, query_nameserver_udp_notimeout(address, serialised_request)).await {
        Ok(response) => response,
        Err(_elapsed) => None,
    }
}
/// Send `serialised_request` to `address` as a single UDP datagram and
/// parse the reply.
///
/// Requests longer than 512 bytes are not sent; the function returns `None`
/// so the caller can fall back to another transport. The reply is read into
/// a 512-byte buffer. Any socket error or parse failure yields `None`.
/// No timeout is applied here.
async fn query_nameserver_udp_notimeout(
    address: SocketAddr,
    serialised_request: &mut [u8],
) -> Option<Message> {
    if serialised_request.len() > 512 {
        return None;
    }

    // Bind a wildcard address of the same family as the nameserver. An
    // IPv4-only socket cannot `connect` to an IPv6 peer, so binding
    // "0.0.0.0:0" unconditionally would make every UDP query to an IPv6
    // nameserver fail.
    let bind_address = if address.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let sock = UdpSocket::bind(bind_address).await.ok()?;
    sock.connect(address).await.ok()?;
    send_udp_bytes(&sock, serialised_request).await.ok()?;

    let mut buf = vec![0u8; 512];
    // `recv` reports how many bytes the datagram contained. Parse only those
    // bytes: the remainder of `buf` is zero padding, not part of the message.
    let received = sock.recv(&mut buf).await.ok()?;
    Message::from_octets(&buf[..received]).ok()
}
/// Run [`query_nameserver_tcp_notimeout`] with a 5 second deadline.
///
/// Returns `None` if the deadline passes or if the underlying query fails.
async fn query_nameserver_tcp(
    address: SocketAddr,
    serialised_request: &mut [u8],
) -> Option<Message> {
    let deadline = Duration::from_secs(5);
    match timeout(deadline, query_nameserver_tcp_notimeout(address, serialised_request)).await {
        Ok(response) => response,
        Err(_elapsed) => None,
    }
}
/// Send `serialised_request` to `address` over a new TCP connection and
/// parse the reply returned by `read_tcp_bytes`.
///
/// Any connection, I/O or parse failure yields `None`. No timeout is applied
/// here.
async fn query_nameserver_tcp_notimeout(
    address: SocketAddr,
    serialised_request: &mut [u8],
) -> Option<Message> {
    let mut conn = match TcpStream::connect(address).await {
        Ok(conn) => conn,
        Err(_) => return None,
    };

    if send_tcp_bytes(&mut conn, serialised_request).await.is_err() {
        return None;
    }

    match read_tcp_bytes(&mut conn).await {
        Ok(reply) => Message::from_octets(reply.as_ref()).ok(),
        Err(_) => None,
    }
}
/// Decide whether `response` is an acceptable reply to `request`.
///
/// The response must satisfy all of the following:
/// - it echoes the request's message ID, opcode and question section;
/// - it has the QR (is-response) bit set;
/// - it is not truncated;
/// - its rcode is either `NoError` or `NameError`.
///
/// The AA, RD and RA flags are not inspected.
fn response_matches_request(request: &Message, response: &Message) -> bool {
    let req = &request.header;
    let resp = &response.header;

    let same_transaction = req.id == resp.id && req.opcode == resp.opcode;
    let usable_reply = resp.is_response && !resp.is_truncated;
    let accepted_rcode = resp.rcode == Rcode::NoError || resp.rcode == Rcode::NameError;

    same_transaction && usable_reply && accepted_rcode && request.questions == response.questions
}
/// Extract the SOA record that marks `response` as a negative (NXDOMAIN or
/// NODATA) answer to `question`.
///
/// Returns `Some(soa)` only when all of the following hold:
/// - the response has no answer records;
/// - the rcode is `NameError` or `NoError`;
/// - the authority section contains exactly one SOA record;
/// - `question.name` is a subdomain of that SOA's owner name;
/// - the SOA's owner name has at least `current_match_count` labels.
///
/// Otherwise returns `None`.
///
/// NOTE(review): `current_match_count` presumably is the label count of the
/// deepest zone the caller has already matched, so that an SOA belonging to
/// a zone higher up the tree is rejected. Confirm this against the callers.
pub fn get_nxdomain_nodata_soa<'a>(
    question: &Question,
    response: &'a Message,
    current_match_count: usize,
) -> Option<&'a ResourceRecord> {
    if !response.answers.is_empty() {
        return None;
    }
    if !(response.header.rcode == Rcode::NameError || response.header.rcode == Rcode::NoError) {
        return None;
    }

    // Exactly one SOA is required: a second one makes the response ambiguous.
    let mut soa_rr = None;
    for rr in &response.authority {
        if rr.rtype_with_data.rtype() == RecordType::SOA {
            if soa_rr.is_some() {
                return None;
            }
            soa_rr = Some(rr);
        }
    }
    let rr = soa_rr?;

    if !question.name.is_subdomain_of(&rr.name) {
        return None;
    }
    if rr.name.labels.len().cmp(&current_match_count) == Ordering::Less {
        return None;
    }
    Some(rr)
}
#[cfg(test)]
mod tests {
    use super::test_util::*;
    use super::*;

    /// Build the canonical matching request/response pair, apply `tweak` to
    /// the response only, and report whether `response_matches_request`
    /// still accepts the pair.
    fn accepted_after(tweak: impl FnOnce(&mut Message)) -> bool {
        let (request, mut response) = matching_nameserver_response();
        tweak(&mut response);
        response_matches_request(&request, &response)
    }

    #[test]
    fn response_matches_request_accepts() {
        assert!(accepted_after(|_| {}));
    }

    #[test]
    fn response_matches_request_checks_id() {
        assert!(!accepted_after(|response| response.header.id += 1));
    }

    #[test]
    fn response_matches_request_checks_qr() {
        assert!(!accepted_after(|response| response.header.is_response = false));
    }

    #[test]
    fn response_matches_request_checks_opcode() {
        assert!(!accepted_after(|response| response.header.opcode = Opcode::Status));
    }

    // The AA, RD and RA flags are not part of the match, so flipping any of
    // them must leave the response accepted.
    #[test]
    fn response_matches_request_does_not_check_aa() {
        assert!(accepted_after(|response| {
            response.header.is_authoritative = !response.header.is_authoritative;
        }));
    }

    #[test]
    fn response_matches_request_checks_tc() {
        assert!(!accepted_after(|response| response.header.is_truncated = true));
    }

    #[test]
    fn response_matches_request_does_not_check_rd() {
        assert!(accepted_after(|response| {
            response.header.recursion_desired = !response.header.recursion_desired;
        }));
    }

    #[test]
    fn response_matches_request_does_not_check_ra() {
        assert!(accepted_after(|response| {
            response.header.recursion_available = !response.header.recursion_available;
        }));
    }

    #[test]
    fn response_matches_request_checks_rcode() {
        assert!(!accepted_after(|response| response.header.rcode = Rcode::ServerFailure));
    }
}
#[cfg(test)]
pub mod test_util {
    use dns_types::protocol::types::test_util::*;
    use std::net::Ipv4Addr;

    use super::*;

    /// A request for `www.example.com.` (A, IN) paired with a response whose
    /// only answer is an A record pointing that name at 1.1.1.1.
    pub fn matching_nameserver_response() -> (Message, Message) {
        let answer = a_record("www.example.com.", Ipv4Addr::new(1, 1, 1, 1));
        nameserver_response("www.example.com.", &[answer], &[], &[])
    }

    /// Build a request with ID 1234 asking for the A/IN records of `name`,
    /// plus a response derived from it via `make_response`.
    ///
    /// The response's answer, authority and additional sections are
    /// replaced by the given records.
    pub fn nameserver_response(
        name: &str,
        answers: &[ResourceRecord],
        authority: &[ResourceRecord],
        additional: &[ResourceRecord],
    ) -> (Message, Message) {
        let question = Question {
            name: domain(name),
            qtype: QueryType::Record(RecordType::A),
            qclass: QueryClass::Record(RecordClass::IN),
        };
        let request = Message::from_question(1234, question);

        let mut response = request.make_response();
        response.answers = answers.into();
        response.authority = authority.into();
        response.additional = additional.into();

        (request, response)
    }
}