dns_resolver/util/nameserver.rs

1use rand::Rng;
2use std::cmp::Ordering;
3use std::net::SocketAddr;
4use std::time::Duration;
5use tokio::net::{TcpStream, UdpSocket};
6use tokio::time::timeout;
7
8use dns_types::protocol::types::*;
9
10use crate::util::net::{read_tcp_bytes, send_tcp_bytes, send_udp_bytes};
11
12/// Send a message to a remote nameserver, preferring UDP if the request is
13/// small enough.  If the request is too large, or if the UDP response is
14/// truncated, tries again using TCP.
15///
16/// If an error occurs while sending the message or receiving the response, or
17/// the response does not match the request, `None` is returned.
18///
19/// This has a 5s timeout for each request, so 10s in total.
20pub async fn query_nameserver(
21    address: SocketAddr,
22    question: Question,
23    recursion_desired: bool,
24) -> Option<Message> {
25    let mut request = Message::from_question(rand::rng().random(), question);
26    request.header.recursion_desired = recursion_desired;
27
28    match request.to_octets() {
29        Ok(mut serialised_request) => {
30            tracing::trace!(message = ?request, ?address, "forwarding query to nameserver");
31
32            if let Some(response) = query_nameserver_udp(address, &mut serialised_request).await {
33                if response_matches_request(&request, &response) {
34                    return Some(response);
35                }
36            }
37
38            if let Some(response) = query_nameserver_tcp(address, &mut serialised_request).await {
39                if response_matches_request(&request, &response) {
40                    return Some(response);
41                }
42            }
43
44            None
45        }
46        Err(error) => {
47            tracing::warn!(message = ?request, ?error, "could not serialise message");
48            None
49        }
50    }
51}
52
53/// Send a message to a remote nameserver over UDP, returning the
54/// response.  If the message would be truncated, or an error occurs
55/// while sending it, `None` is returned.  Otherwise the deserialised
56/// response message is: but this response is NOT validated -
57/// consumers MUST validate the response before using it!
58///
59/// This has a 5s timeout.
60async fn query_nameserver_udp(
61    address: SocketAddr,
62    serialised_request: &mut [u8],
63) -> Option<Message> {
64    timeout(
65        Duration::from_secs(5),
66        query_nameserver_udp_notimeout(address, serialised_request),
67    )
68    .await
69    .unwrap_or_default()
70}
71
72/// Timeout-less version of `query_nameserver_udp`.
73async fn query_nameserver_udp_notimeout(
74    address: SocketAddr,
75    serialised_request: &mut [u8],
76) -> Option<Message> {
77    if serialised_request.len() > 512 {
78        return None;
79    }
80
81    let mut buf = vec![0u8; 512];
82    let sock = UdpSocket::bind("0.0.0.0:0").await.ok()?;
83    sock.connect(address).await.ok()?;
84    send_udp_bytes(&sock, serialised_request).await.ok()?;
85    sock.recv(&mut buf).await.ok()?;
86
87    Message::from_octets(&buf).ok()
88}
89
90/// Send a message to a remote nameserver over TCP, returning the
91/// response.  This has the same return value caveats as
92/// `query_nameserver_udp`.
93///
94/// This has a 5s timeout.
95async fn query_nameserver_tcp(
96    address: SocketAddr,
97    serialised_request: &mut [u8],
98) -> Option<Message> {
99    timeout(
100        Duration::from_secs(5),
101        query_nameserver_tcp_notimeout(address, serialised_request),
102    )
103    .await
104    .unwrap_or_default()
105}
106
107/// Timeout-less version of `query_nameserver_tcp`.
108async fn query_nameserver_tcp_notimeout(
109    address: SocketAddr,
110    serialised_request: &mut [u8],
111) -> Option<Message> {
112    let mut stream = TcpStream::connect(address).await.ok()?;
113    send_tcp_bytes(&mut stream, serialised_request).await.ok()?;
114    let bytes = read_tcp_bytes(&mut stream).await.ok()?;
115
116    Message::from_octets(bytes.as_ref()).ok()
117}
118
119/// Very basic validation that a nameserver response matches a
120/// message:
121///
122/// - Check the ID, opcode, and questions match the question.
123///
124/// - Check it is a response.
125///
126/// - Check the response code is either `NoError` or `NameError`.
127///
128/// - Check it is not truncated.
129fn response_matches_request(request: &Message, response: &Message) -> bool {
130    if request.header.id != response.header.id {
131        return false;
132    }
133    if !response.header.is_response {
134        return false;
135    }
136    if request.header.opcode != response.header.opcode {
137        return false;
138    }
139    if response.header.is_truncated {
140        return false;
141    }
142    if !(response.header.rcode == Rcode::NoError || response.header.rcode == Rcode::NameError) {
143        return false;
144    }
145    if request.questions != response.questions {
146        return false;
147    }
148
149    true
150}
151
152/// Check if this is an NXDOMAIN or NODATA response and return the SOA if so.
153///
154/// Also sanity checks that the SOA record could be authoritative for the query
155/// domain: the domain has to be a subdomain of the SOA, and the SOA has to have
156/// at least the current match count.
157pub fn get_nxdomain_nodata_soa<'a>(
158    question: &Question,
159    response: &'a Message,
160    current_match_count: usize,
161) -> Option<&'a ResourceRecord> {
162    if !response.answers.is_empty() {
163        return None;
164    }
165    if !(response.header.rcode == Rcode::NameError || response.header.rcode == Rcode::NoError) {
166        return None;
167    }
168
169    let mut soa_rr = None;
170    for rr in &response.authority {
171        if rr.rtype_with_data.rtype() == RecordType::SOA {
172            // multiple SOAs: abort, abort!
173            if soa_rr.is_some() {
174                return None;
175            }
176
177            soa_rr = Some(rr);
178        }
179    }
180
181    if let Some(rr) = soa_rr {
182        if !question.name.is_subdomain_of(&rr.name) {
183            return None;
184        }
185
186        if rr.name.labels.len().cmp(&current_match_count) == Ordering::Less {
187            return None;
188        }
189
190        return Some(rr);
191    }
192
193    None
194}
195
#[cfg(test)]
mod tests {
    use super::test_util::*;
    use super::*;

    /// Build a matching request/response pair, apply `mutate` to the
    /// response, and report whether `response_matches_request` accepts it.
    fn accepted_after(mutate: impl FnOnce(&mut Message)) -> bool {
        let (request, mut response) = matching_nameserver_response();
        mutate(&mut response);
        response_matches_request(&request, &response)
    }

    #[test]
    fn response_matches_request_accepts() {
        assert!(accepted_after(|_| {}));
    }

    #[test]
    fn response_matches_request_checks_id() {
        assert!(!accepted_after(|response| response.header.id += 1));
    }

    #[test]
    fn response_matches_request_checks_qr() {
        assert!(!accepted_after(|response| response.header.is_response = false));
    }

    #[test]
    fn response_matches_request_checks_opcode() {
        assert!(!accepted_after(|response| response.header.opcode = Opcode::Status));
    }

    #[test]
    fn response_matches_request_does_not_check_aa() {
        assert!(accepted_after(|response| {
            response.header.is_authoritative = !response.header.is_authoritative;
        }));
    }

    #[test]
    fn response_matches_request_checks_tc() {
        assert!(!accepted_after(|response| response.header.is_truncated = true));
    }

    #[test]
    fn response_matches_request_does_not_check_rd() {
        assert!(accepted_after(|response| {
            response.header.recursion_desired = !response.header.recursion_desired;
        }));
    }

    #[test]
    fn response_matches_request_does_not_check_ra() {
        assert!(accepted_after(|response| {
            response.header.recursion_available = !response.header.recursion_available;
        }));
    }

    #[test]
    fn response_matches_request_checks_rcode() {
        assert!(!accepted_after(|response| response.header.rcode = Rcode::ServerFailure));
    }

    // NXDOMAIN responses must still be accepted: they are meaningful answers.
    #[test]
    fn response_matches_request_accepts_nameerror() {
        assert!(accepted_after(|response| response.header.rcode = Rcode::NameError));
    }

    // A response whose question section differs from the request's is not
    // an answer to our query.
    #[test]
    fn response_matches_request_checks_questions() {
        assert!(!accepted_after(|response| response.questions.clear()));
    }
}
272
#[cfg(test)]
pub mod test_util {
    use dns_types::protocol::types::test_util::*;
    use std::net::Ipv4Addr;

    use super::*;

    /// A request for `www.example.com.` (IN A) paired with a response that
    /// carries a single A record pointing that name at 1.1.1.1.
    pub fn matching_nameserver_response() -> (Message, Message) {
        let name = "www.example.com.";
        let answer = a_record(name, Ipv4Addr::new(1, 1, 1, 1));
        nameserver_response(name, &[answer], &[], &[])
    }

    /// Build an IN A request for `name` with ID 1234, plus a response to it
    /// (via `make_response`) whose answer, authority, and additional
    /// sections are filled from the given slices.
    pub fn nameserver_response(
        name: &str,
        answers: &[ResourceRecord],
        authority: &[ResourceRecord],
        additional: &[ResourceRecord],
    ) -> (Message, Message) {
        let question = Question {
            name: domain(name),
            qtype: QueryType::Record(RecordType::A),
            qclass: QueryClass::Record(RecordClass::IN),
        };
        let request = Message::from_question(1234, question);

        let mut response = request.make_response();
        response.answers = answers.into();
        response.authority = authority.into();
        response.additional = additional.into();

        (request, response)
    }
}