dns_resolver/util/
nameserver.rs1use rand::Rng;
2use std::cmp::Ordering;
3use std::net::SocketAddr;
4use std::time::Duration;
5use tokio::net::{TcpStream, UdpSocket};
6use tokio::time::timeout;
7
8use dns_types::protocol::types::*;
9
10use crate::util::net::{read_tcp_bytes, send_tcp_bytes, send_udp_bytes};
11
12pub async fn query_nameserver(
21 address: SocketAddr,
22 question: Question,
23 recursion_desired: bool,
24) -> Option<Message> {
25 let mut request = Message::from_question(rand::rng().random(), question);
26 request.header.recursion_desired = recursion_desired;
27
28 match request.to_octets() {
29 Ok(mut serialised_request) => {
30 tracing::trace!(message = ?request, ?address, "forwarding query to nameserver");
31
32 if let Some(response) = query_nameserver_udp(address, &mut serialised_request).await {
33 if response_matches_request(&request, &response) {
34 return Some(response);
35 }
36 }
37
38 if let Some(response) = query_nameserver_tcp(address, &mut serialised_request).await {
39 if response_matches_request(&request, &response) {
40 return Some(response);
41 }
42 }
43
44 None
45 }
46 Err(error) => {
47 tracing::warn!(message = ?request, ?error, "could not serialise message");
48 None
49 }
50 }
51}
52
53async fn query_nameserver_udp(
61 address: SocketAddr,
62 serialised_request: &mut [u8],
63) -> Option<Message> {
64 timeout(
65 Duration::from_secs(5),
66 query_nameserver_udp_notimeout(address, serialised_request),
67 )
68 .await
69 .unwrap_or_default()
70}
71
72async fn query_nameserver_udp_notimeout(
74 address: SocketAddr,
75 serialised_request: &mut [u8],
76) -> Option<Message> {
77 if serialised_request.len() > 512 {
78 return None;
79 }
80
81 let mut buf = vec![0u8; 512];
82 let sock = UdpSocket::bind("0.0.0.0:0").await.ok()?;
83 sock.connect(address).await.ok()?;
84 send_udp_bytes(&sock, serialised_request).await.ok()?;
85 sock.recv(&mut buf).await.ok()?;
86
87 Message::from_octets(&buf).ok()
88}
89
90async fn query_nameserver_tcp(
96 address: SocketAddr,
97 serialised_request: &mut [u8],
98) -> Option<Message> {
99 timeout(
100 Duration::from_secs(5),
101 query_nameserver_tcp_notimeout(address, serialised_request),
102 )
103 .await
104 .unwrap_or_default()
105}
106
107async fn query_nameserver_tcp_notimeout(
109 address: SocketAddr,
110 serialised_request: &mut [u8],
111) -> Option<Message> {
112 let mut stream = TcpStream::connect(address).await.ok()?;
113 send_tcp_bytes(&mut stream, serialised_request).await.ok()?;
114 let bytes = read_tcp_bytes(&mut stream).await.ok()?;
115
116 Message::from_octets(bytes.as_ref()).ok()
117}
118
119fn response_matches_request(request: &Message, response: &Message) -> bool {
130 if request.header.id != response.header.id {
131 return false;
132 }
133 if !response.header.is_response {
134 return false;
135 }
136 if request.header.opcode != response.header.opcode {
137 return false;
138 }
139 if response.header.is_truncated {
140 return false;
141 }
142 if !(response.header.rcode == Rcode::NoError || response.header.rcode == Rcode::NameError) {
143 return false;
144 }
145 if request.questions != response.questions {
146 return false;
147 }
148
149 true
150}
151
152pub fn get_nxdomain_nodata_soa<'a>(
158 question: &Question,
159 response: &'a Message,
160 current_match_count: usize,
161) -> Option<&'a ResourceRecord> {
162 if !response.answers.is_empty() {
163 return None;
164 }
165 if !(response.header.rcode == Rcode::NameError || response.header.rcode == Rcode::NoError) {
166 return None;
167 }
168
169 let mut soa_rr = None;
170 for rr in &response.authority {
171 if rr.rtype_with_data.rtype() == RecordType::SOA {
172 if soa_rr.is_some() {
174 return None;
175 }
176
177 soa_rr = Some(rr);
178 }
179 }
180
181 if let Some(rr) = soa_rr {
182 if !question.name.is_subdomain_of(&rr.name) {
183 return None;
184 }
185
186 if rr.name.labels.len().cmp(¤t_match_count) == Ordering::Less {
187 return None;
188 }
189
190 return Some(rr);
191 }
192
193 None
194}
195
#[cfg(test)]
mod tests {
    use super::test_util::*;
    use super::*;

    #[test]
    fn response_matches_request_accepts() {
        let (request, response) = matching_nameserver_response();

        assert!(response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_checks_id() {
        let (request, mut response) = matching_nameserver_response();
        response.header.id += 1;

        assert!(!response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_checks_qr() {
        let (request, mut response) = matching_nameserver_response();
        response.header.is_response = false;

        assert!(!response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_checks_opcode() {
        let (request, mut response) = matching_nameserver_response();
        response.header.opcode = Opcode::Status;

        assert!(!response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_does_not_check_aa() {
        let (request, mut response) = matching_nameserver_response();
        response.header.is_authoritative = !response.header.is_authoritative;

        assert!(response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_checks_tc() {
        let (request, mut response) = matching_nameserver_response();
        response.header.is_truncated = true;

        assert!(!response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_does_not_check_rd() {
        let (request, mut response) = matching_nameserver_response();
        response.header.recursion_desired = !response.header.recursion_desired;

        assert!(response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_does_not_check_ra() {
        let (request, mut response) = matching_nameserver_response();
        response.header.recursion_available = !response.header.recursion_available;

        assert!(response_matches_request(&request, &response));
    }

    #[test]
    fn response_matches_request_checks_rcode() {
        let (request, mut response) = matching_nameserver_response();
        response.header.rcode = Rcode::ServerFailure;

        assert!(!response_matches_request(&request, &response));
    }

    // NXDOMAIN is an acceptable rcode alongside NOERROR.
    #[test]
    fn response_matches_request_accepts_nxdomain() {
        let (request, mut response) = matching_nameserver_response();
        response.header.rcode = Rcode::NameError;

        assert!(response_matches_request(&request, &response));
    }

    // The response must echo the request's question section exactly.
    #[test]
    fn response_matches_request_checks_questions() {
        let (request, mut response) = matching_nameserver_response();
        response.questions[0].qtype = QueryType::Record(RecordType::NS);

        assert!(!response_matches_request(&request, &response));
    }
}
272
#[cfg(test)]
pub mod test_util {
    use dns_types::protocol::types::test_util::*;
    use std::net::Ipv4Addr;

    use super::*;

    /// A request for `www.example.com. IN A` paired with a response that
    /// answers it with an A record for `1.1.1.1`, with empty authority and
    /// additional sections.
    pub fn matching_nameserver_response() -> (Message, Message) {
        let answer = a_record("www.example.com.", Ipv4Addr::new(1, 1, 1, 1));
        nameserver_response("www.example.com.", &[answer], &[], &[])
    }

    /// Builds a request for `name IN A` with message ID 1234, and a response
    /// derived from it via `make_response` whose answer, authority and
    /// additional sections are set to the given records.
    pub fn nameserver_response(
        name: &str,
        answers: &[ResourceRecord],
        authority: &[ResourceRecord],
        additional: &[ResourceRecord],
    ) -> (Message, Message) {
        let question = Question {
            name: domain(name),
            qtype: QueryType::Record(RecordType::A),
            qclass: QueryClass::Record(RecordClass::IN),
        };
        let request = Message::from_question(1234, question);

        let mut reply = request.make_response();
        reply.answers = answers.into();
        reply.authority = authority.into();
        reply.additional = additional.into();

        (request, reply)
    }
}