dns_resolver/
recursive.rs

1use async_recursion::async_recursion;
2use std::cmp::Ordering;
3use std::collections::{HashMap, HashSet};
4use std::net::IpAddr;
5use std::time::Duration;
6use tokio::time::timeout;
7use tracing::Instrument;
8
9use dns_types::protocol::types::*;
10
11use crate::context::Context;
12use crate::local::{resolve_local, LocalResolutionResult};
13use crate::util::nameserver::*;
14use crate::util::types::*;
15
16pub struct RecursiveContextInner {
17    pub protocol_mode: ProtocolMode,
18    pub upstream_dns_port: u16,
19}
20
21pub type RecursiveContext<'a> = Context<'a, RecursiveContextInner>;
22
23/// Recursive DNS resolution.
24///
25/// This corresponds to the standard resolver algorithm.  If
26/// information is not held locally, it will call out to remote
27/// nameservers, starting with the given root hints.  Since it may
28/// make network requests, this function is async.
29///
30/// This has a 60s timeout.
31///
32/// See section 5.3.3 of RFC 1034.
33///
34/// # Errors
35///
36/// See `ResolutionError`.
37pub async fn resolve_recursive<'a>(
38    context: &mut RecursiveContext<'a>,
39    question: &Question,
40) -> Result<ResolvedRecord, ResolutionError> {
41    if let Ok(res) = timeout(
42        Duration::from_secs(60),
43        resolve_recursive_notimeout(context, question),
44    )
45    .await
46    {
47        res
48    } else {
49        tracing::debug!("timed out");
50        Err(ResolutionError::Timeout)
51    }
52}
53
54/// Timeout-less version of `resolve_recursive`.
55#[async_recursion]
56async fn resolve_recursive_notimeout<'a>(
57    context: &mut RecursiveContext<'a>,
58    question: &Question,
59) -> Result<ResolvedRecord, ResolutionError> {
60    if context.at_recursion_limit() {
61        tracing::debug!("hit recursion limit");
62        return Err(ResolutionError::RecursionLimit);
63    }
64    if context.is_duplicate_question(question) {
65        tracing::debug!("hit duplicate question");
66        return Err(ResolutionError::DuplicateQuestion {
67            question: question.clone(),
68        });
69    }
70
71    let mut candidates = None;
72    let mut combined_rrs = Vec::new();
73
74    match resolve_local(context, question) {
75        Ok(LocalResolutionResult::Done { resolved }) => return Ok(resolved),
76        Ok(LocalResolutionResult::Partial { rrs }) => combined_rrs = rrs,
77        Ok(LocalResolutionResult::Delegation { delegation, .. }) => candidates = Some(delegation),
78        Ok(LocalResolutionResult::CNAME {
79            rrs,
80            cname_question,
81            ..
82        }) => {
83            context.push_question(question);
84            let answer = resolve_combined_recursive(context, rrs, cname_question).await;
85            context.pop_question();
86            return answer;
87        }
88        Err(_) => (),
89    }
90
91    context.push_question(question);
92
93    if candidates.is_none() {
94        candidates = candidate_nameservers(context, &question.name);
95    }
96
97    if let Some(candidates) = candidates {
98        let mut match_count = candidates.match_count();
99        let mut candidate_hostnames = candidates.hostnames;
100        let mut next_candidate_hostnames = Vec::with_capacity(candidate_hostnames.len());
101        let mut resolve_candidates_locally = true;
102
103        while let Some(candidate) = candidate_hostnames.pop() {
104            tracing::trace!(?candidate, "got candidate nameserver");
105            if let Some(ip) =
106                resolve_hostname_to_ip(context, resolve_candidates_locally, candidate.clone()).await
107            {
108                if let Some(nameserver_response) = query_nameserver(
109                    (ip, context.r.upstream_dns_port).into(),
110                    question.clone(),
111                    false,
112                )
113                .instrument(tracing::error_span!("query_nameserver", address = %ip, %match_count))
114                .await
115                .and_then(|res| validate_nameserver_response(question, &res, match_count))
116                {
117                    if resolve_candidates_locally {
118                        tracing::trace!(?candidate, "resolved fast candidate");
119                    } else {
120                        tracing::trace!(?candidate, "resolved slow candidate");
121                    }
122                    context.metrics().nameserver_hit();
123                    match resolve_with_nameserver_response(
124                        context,
125                        combined_rrs.clone(),
126                        nameserver_response,
127                        question,
128                    )
129                    .await
130                    {
131                        Ok(result) => {
132                            context.pop_question();
133                            return result;
134                        }
135                        Err(delegation) => {
136                            match_count = delegation.match_count();
137                            candidate_hostnames = delegation.hostnames;
138                            next_candidate_hostnames =
139                                Vec::with_capacity(candidate_hostnames.len());
140                            resolve_candidates_locally = true;
141                        }
142                    }
143                } else {
144                    context.metrics().nameserver_miss();
145                    // TODO: should distinguish between timeouts and other
146                    // failures here, and try the next nameserver after a
147                    // timeout.
148                    context.pop_question();
149                    return Err(ResolutionError::DeadEnd {
150                        question: question.clone(),
151                    });
152                }
153            } else if resolve_candidates_locally {
154                tracing::trace!(?candidate, "skipping slow candidate");
155                next_candidate_hostnames.push(candidate.clone());
156                // try slow candidates if out of fast ones
157                if candidate_hostnames.is_empty() {
158                    tracing::trace!("restarting with slow candidates");
159                    candidate_hostnames = next_candidate_hostnames;
160                    next_candidate_hostnames = Vec::new();
161                    resolve_candidates_locally = false;
162                }
163            } else {
164                // failed to resolve the candidate recursively, just drop it.
165                tracing::trace!(?candidate, "dropping unresolvable candidate");
166            }
167        }
168    }
169
170    tracing::trace!("out of candidates");
171    context.pop_question();
172    Err(ResolutionError::DeadEnd {
173        question: question.clone(),
174    })
175}
176
177/// Helper function for answering a question given a response from an upstream
178/// nameserver: this will only do further querying if the response is a CNAME.
179#[async_recursion]
180async fn resolve_with_nameserver_response<'a>(
181    context: &mut RecursiveContext<'a>,
182    mut combined_rrs: Vec<ResourceRecord>,
183    nameserver_response: NameserverResponse,
184    question: &Question,
185) -> Result<Result<ResolvedRecord, ResolutionError>, Nameservers> {
186    match nameserver_response {
187        NameserverResponse::Answer { rrs, soa_rr, .. } => {
188            tracing::trace!("got recursive answer");
189            context.cache.insert_all(&rrs);
190            prioritising_merge(&mut combined_rrs, rrs);
191            Ok(Ok(ResolvedRecord::NonAuthoritative {
192                rrs: combined_rrs,
193                soa_rr,
194            }))
195        }
196        NameserverResponse::Delegation {
197            rrs, delegation, ..
198        } => {
199            context.cache.insert_all(&rrs);
200            if question.qtype == QueryType::Record(RecordType::A) {
201                if let Some(rr) = get_record(&rrs, &question.name, RecordType::A) {
202                    tracing::trace!("got recursive delegation - using glue A record");
203                    prioritising_merge(&mut combined_rrs, vec![rr.clone()]);
204                    return Ok(Ok(ResolvedRecord::NonAuthoritative {
205                        rrs: combined_rrs,
206                        soa_rr: None,
207                    }));
208                }
209            } else if question.qtype == QueryType::Record(RecordType::AAAA) {
210                if let Some(rr) = get_record(&rrs, &question.name, RecordType::AAAA) {
211                    tracing::trace!("got recursive delegation - using glue AAAA record");
212                    prioritising_merge(&mut combined_rrs, vec![rr.clone()]);
213                    return Ok(Ok(ResolvedRecord::NonAuthoritative {
214                        rrs: combined_rrs,
215                        soa_rr: None,
216                    }));
217                }
218            }
219            tracing::trace!("got recursive delegation - using as candidate");
220            Err(delegation)
221        }
222        NameserverResponse::CNAME { rrs, cname, .. } => {
223            tracing::trace!("got recursive CNAME");
224            context.cache.insert_all(&rrs);
225            prioritising_merge(&mut combined_rrs, rrs);
226            let cname_question = Question {
227                name: cname,
228                qclass: question.qclass,
229                qtype: question.qtype,
230            };
231            let cname_answer =
232                resolve_combined_recursive(context, combined_rrs, cname_question).await;
233            Ok(cname_answer)
234        }
235    }
236}
237
238/// Helper function for resolving CNAMEs: resolve, and add some existing RRs to
239/// the ANSWER section of the result.
240async fn resolve_combined_recursive<'a>(
241    context: &mut RecursiveContext<'a>,
242    mut rrs: Vec<ResourceRecord>,
243    question: Question,
244) -> Result<ResolvedRecord, ResolutionError> {
245    match resolve_recursive_notimeout(context, &question)
246        .instrument(tracing::error_span!("resolve_combined_recursive", %question))
247        .await
248    {
249        Ok(resolved) => {
250            let soa_rr = resolved.soa_rr().cloned();
251            rrs.append(&mut resolved.rrs());
252            Ok(ResolvedRecord::NonAuthoritative { rrs, soa_rr })
253        }
254        Err(_) => Err(ResolutionError::DeadEnd { question }),
255    }
256}
257
258/// Resolve a hostname into an IP address, optionally only doing local
259/// resolution.
260async fn resolve_hostname_to_ip<'a>(
261    context: &mut RecursiveContext<'a>,
262    resolve_locally: bool,
263    hostname: DomainName,
264) -> Option<IpAddr> {
265    let rtypes = match context.r.protocol_mode {
266        ProtocolMode::OnlyV4 => vec![RecordType::A],
267        ProtocolMode::PreferV4 => vec![RecordType::A, RecordType::AAAA],
268        ProtocolMode::PreferV6 => vec![RecordType::AAAA, RecordType::A],
269        ProtocolMode::OnlyV6 => vec![RecordType::AAAA],
270    };
271
272    let mut question = Question {
273        name: hostname,
274        qclass: QueryClass::Record(RecordClass::IN),
275        // immediately replaced in the loop
276        qtype: QueryType::AXFR,
277    };
278    for rtype in rtypes {
279        question.qtype = QueryType::Record(rtype);
280        if resolve_locally {
281            if let Ok(LocalResolutionResult::Done { resolved }) = resolve_local(context, &question)
282            {
283                let address = get_ip(&resolved.rrs(), &question.name, rtype);
284                if address.is_some() {
285                    return address;
286                }
287            }
288        } else if let Ok(result) = resolve_recursive_notimeout(context, &question).await {
289            let address = get_ip(&result.rrs(), &question.name, rtype);
290            if address.is_some() {
291                return address;
292            }
293        }
294    }
295
296    None
297}
298
299/// Get the best nameservers by non-recursively looking them up for
300/// the domain and all its superdomains, in order.  If no nameservers
301/// are found, the root hints are returned.
302///
303/// This corresponds to step 2 of the standard resolver algorithm.
304fn candidate_nameservers(
305    context: &mut RecursiveContext<'_>,
306    question: &DomainName,
307) -> Option<Nameservers> {
308    for i in 0..question.labels.len() {
309        let labels = &question.labels[i..];
310        if let Some(name) = DomainName::from_labels(labels.into()) {
311            let ns_q = Question {
312                name: name.clone(),
313                qtype: QueryType::Record(RecordType::NS),
314                qclass: QueryClass::Record(RecordClass::IN),
315            };
316
317            let mut hostnames = Vec::new();
318
319            if let Ok(LocalResolutionResult::Done { resolved }) = resolve_local(context, &ns_q) {
320                for ns_rr in resolved.rrs() {
321                    if let RecordTypeWithData::NS { nsdname } = &ns_rr.rtype_with_data {
322                        hostnames.push(nsdname.clone());
323                    }
324                }
325            }
326
327            if !hostnames.is_empty() {
328                return Some(Nameservers {
329                    hostnames,
330                    name: ns_q.name,
331                });
332            }
333        }
334    }
335
336    None
337}
338
339/// Validate a nameserver response against the question by only keeping valid
340/// RRs:
341///
342/// - RRs matching the query domain (or the name it ends up being
343///   after following `CNAME`s) and type (or `CNAME`)
344///
345/// - `NS` RRs for a superdomain of the query domain (if it matches
346///   better than our current nameservers).
347///
348/// - `A` RRs corresponding to a selected `NS` RR
349///
350/// Then, decide whether:
351///
352/// - This is an answer: it has a possibly-empty sequence of CNAME RRs
353///   and a record of the right type at the final name.
354///
355/// - This is a cname to follow: it has a non-empty sequence of CNAME
356///   RRs but no final record of the right type.
357///
358/// - This is a delegation to other nameservers: there's at least one
359///   NS RR.
360///
361/// This makes the simplifying assumption that the question message
362/// has a single question in it, because that is how this function is
363/// used by this module.  If that assumption does not hold, a valid
364/// answer may be reported as invalid.
365fn validate_nameserver_response(
366    question: &Question,
367    response: &Message,
368    current_match_count: usize,
369) -> Option<NameserverResponse> {
370    if let Some((final_name, cname_map)) =
371        follow_cnames(&response.answers, &question.name, question.qtype)
372    {
373        // get RRs matching the query name or the names it `CNAME`s to
374
375        let mut rrs_for_query = Vec::<ResourceRecord>::with_capacity(response.answers.len());
376        let mut seen_final_record = false;
377        let mut all_unknown = true;
378        for an in &response.answers {
379            if an.is_unknown() {
380                continue;
381            }
382
383            let rtype = an.rtype_with_data.rtype();
384            all_unknown = false;
385
386            if rtype.matches(question.qtype) && an.name == final_name {
387                rrs_for_query.push(an.clone());
388                seen_final_record = true;
389            } else if rtype == RecordType::CNAME && cname_map.contains_key(&an.name) {
390                rrs_for_query.push(an.clone());
391            }
392        }
393
394        if all_unknown {
395            None
396        } else if rrs_for_query.is_empty() {
397            tracing::warn!("expected RRs");
398            None
399        } else {
400            // what sort of answer is this?
401            if seen_final_record {
402                Some(NameserverResponse::Answer {
403                    rrs: rrs_for_query,
404                    soa_rr: None,
405                })
406            } else {
407                Some(NameserverResponse::CNAME {
408                    rrs: rrs_for_query,
409                    cname: final_name,
410                })
411            }
412        }
413    } else {
414        // get NS RRs and their associated A RRs.
415        //
416        // NOTE: `NS` RRs may be in the ANSWER *or* AUTHORITY sections.
417
418        let (match_name, ns_names) = {
419            let ns_from_answers =
420                get_better_ns_names(&response.answers, &question.name, current_match_count);
421            let ns_from_authority =
422                get_better_ns_names(&response.authority, &question.name, current_match_count);
423            match (ns_from_answers, ns_from_authority) {
424                (Some((mn1, nss1)), Some((mn2, nss2))) => {
425                    match mn1.labels.len().cmp(&mn2.labels.len()) {
426                        Ordering::Greater => (mn1, nss1),
427                        Ordering::Equal => (mn1, nss1.union(&nss2).cloned().collect()),
428                        Ordering::Less => (mn2, nss2),
429                    }
430                }
431                (Some((mn, nss)), None) => (mn, nss),
432                (None, Some((mn, nss))) => (mn, nss),
433                (None, None) => {
434                    // No records and no delegation - check if this is an
435                    // NXDOMAIN / NODATA response and if so propagate the SOA RR
436                    return get_nxdomain_nodata_soa(question, response, current_match_count).map(
437                        |soa_rr| NameserverResponse::Answer {
438                            rrs: Vec::new(),
439                            soa_rr: Some(soa_rr).cloned(),
440                        },
441                    );
442                }
443            }
444        };
445
446        // you never know, the upstream nameserver may have been kind enough to
447        // give an A record along with each NS record, if we're lucky.
448        let mut nameserver_rrs = Vec::<ResourceRecord>::with_capacity(ns_names.len() * 2);
449        for rr in &response.answers {
450            match &rr.rtype_with_data {
451                RecordTypeWithData::NS { nsdname } if ns_names.contains(nsdname) => {
452                    nameserver_rrs.push(rr.clone());
453                }
454                RecordTypeWithData::A { .. } if ns_names.contains(&rr.name) => {
455                    nameserver_rrs.push(rr.clone());
456                }
457                RecordTypeWithData::AAAA { .. } if ns_names.contains(&rr.name) => {
458                    nameserver_rrs.push(rr.clone());
459                }
460                _ => (),
461            }
462        }
463        for rr in &response.authority {
464            match &rr.rtype_with_data {
465                RecordTypeWithData::NS { nsdname } if ns_names.contains(nsdname) => {
466                    nameserver_rrs.push(rr.clone());
467                }
468                _ => (),
469            }
470        }
471        for rr in &response.additional {
472            match &rr.rtype_with_data {
473                RecordTypeWithData::A { .. } if ns_names.contains(&rr.name) => {
474                    nameserver_rrs.push(rr.clone());
475                }
476                RecordTypeWithData::AAAA { .. } if ns_names.contains(&rr.name) => {
477                    nameserver_rrs.push(rr.clone());
478                }
479                _ => (),
480            }
481        }
482
483        // this is a delegation
484        Some(NameserverResponse::Delegation {
485            rrs: nameserver_rrs,
486            delegation: Nameservers {
487                hostnames: ns_names.into_iter().collect(),
488                name: match_name,
489            },
490        })
491    }
492}
493
494/// Given a set of RRs and a domain name we're looking for, follow
495/// `CNAME`s in the response and return the final name (which is the
496/// name that will have the non-`CNAME` records associated with it).
497///
498/// Returns `None` if CNAMEs form a loop, or there is no RR which
499/// matches the target name (a CNAME or one with the right type).
500fn follow_cnames(
501    rrs: &[ResourceRecord],
502    target: &DomainName,
503    qtype: QueryType,
504) -> Option<(DomainName, HashMap<DomainName, DomainName>)> {
505    let mut got_match = false;
506    let mut cname_map = HashMap::<DomainName, DomainName>::new();
507    for rr in rrs {
508        if &rr.name == target && rr.rtype_with_data.matches(qtype) {
509            got_match = true;
510        }
511        if let RecordTypeWithData::CNAME { cname } = &rr.rtype_with_data {
512            cname_map.insert(rr.name.clone(), cname.clone());
513        }
514    }
515
516    let mut seen = HashSet::new();
517    let mut final_name = target.clone();
518    while let Some(target) = cname_map.get(&final_name) {
519        if seen.contains(target) {
520            return None;
521        }
522        seen.insert(target.clone());
523        final_name = target.clone();
524    }
525
526    if got_match || !seen.is_empty() {
527        Some((final_name, cname_map))
528    } else {
529        None
530    }
531}
532
533/// Given a set of RRs and a domain name we're looking for, look for
534/// better matching NS RRs (by comparing the current match count).
535/// Returns the new matching superdomain and the nameserver hostnames.
536fn get_better_ns_names(
537    rrs: &[ResourceRecord],
538    target: &DomainName,
539    current_match_count: usize,
540) -> Option<(DomainName, HashSet<DomainName>)> {
541    let mut ns_names = HashSet::new();
542    let mut match_count = current_match_count;
543    let mut match_name = None;
544
545    for rr in rrs {
546        if let RecordTypeWithData::NS { nsdname } = &rr.rtype_with_data {
547            if target.is_subdomain_of(&rr.name) {
548                match rr.name.labels.len().cmp(&match_count) {
549                    Ordering::Greater => {
550                        match_count = rr.name.labels.len();
551                        match_name = Some(rr.name.clone());
552
553                        ns_names.clear();
554                        ns_names.insert(nsdname.clone());
555                    }
556                    Ordering::Equal => {
557                        ns_names.insert(nsdname.clone());
558                    }
559                    Ordering::Less => (),
560                }
561            }
562        }
563    }
564
565    match_name.map(|mn| (mn, ns_names))
566}
567
568/// Given a set of RRs and a domain name we're looking for, follow any
569/// `CNAME`s in the response and get the address from the final `A` / `AAAA`
570/// record.
571fn get_ip(rrs: &[ResourceRecord], target: &DomainName, rtype: RecordType) -> Option<IpAddr> {
572    if let Some((final_name, _)) = follow_cnames(rrs, target, QueryType::Wildcard) {
573        if let Some(rr) = get_record(rrs, &final_name, rtype) {
574            match rr.rtype_with_data {
575                RecordTypeWithData::A { address } => Some(IpAddr::V4(address)),
576                RecordTypeWithData::AAAA { address } => Some(IpAddr::V6(address)),
577                _ => None,
578            }
579        } else {
580            None
581        }
582    } else {
583        None
584    }
585}
586
587/// Given a set of RRs and a domain we're looking for, return the record we're
588/// looking for (if any).
589///
590/// Unlike `get_ip` this does not follow `CNAME`s.
591fn get_record<'a>(
592    rrs: &'a [ResourceRecord],
593    target: &DomainName,
594    rtype: RecordType,
595) -> Option<&'a ResourceRecord> {
596    rrs.iter()
597        .find(|&rr| rr.rtype_with_data.rtype() == rtype && rr.name == *target)
598}
599
600/// A response from a remote nameserver
601#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
602pub enum NameserverResponse {
603    Answer {
604        rrs: Vec<ResourceRecord>,
605        soa_rr: Option<ResourceRecord>,
606    },
607    CNAME {
608        rrs: Vec<ResourceRecord>,
609        cname: DomainName,
610    },
611    Delegation {
612        rrs: Vec<ResourceRecord>,
613        delegation: Nameservers,
614    },
615}
616
617#[cfg(test)]
618mod tests {
619    use std::net::{Ipv4Addr, Ipv6Addr};
620
621    use dns_types::protocol::types::test_util::*;
622    use dns_types::zones::types::*;
623
624    use super::*;
625    use crate::cache::SharedCache;
626    use crate::util::nameserver::test_util::*;
627
628    #[test]
629    fn candidate_nameservers_gets_all_matches() {
630        let qdomain = domain("com.");
631        assert_eq!(
632            Some(Nameservers {
633                hostnames: vec![domain("ns1.example.com."), domain("ns2.example.com.")],
634                name: qdomain.clone(),
635            }),
636            candidate_nameservers(
637                &mut Context::new(
638                    RecursiveContextInner {
639                        protocol_mode: ProtocolMode::PreferV4,
640                        upstream_dns_port: 53,
641                    },
642                    &Zones::new(),
643                    &cache_with_nameservers(&["com."]),
644                    10,
645                ),
646                &qdomain
647            )
648        );
649    }
650
651    #[test]
652    fn candidate_nameservers_returns_longest_match() {
653        assert_eq!(
654            Some(Nameservers {
655                hostnames: vec![domain("ns1.example.com."), domain("ns2.example.com.")],
656                name: domain("example.com."),
657            }),
658            candidate_nameservers(
659                &mut Context::new(
660                    RecursiveContextInner {
661                        protocol_mode: ProtocolMode::PreferV4,
662                        upstream_dns_port: 53,
663                    },
664                    &Zones::new(),
665                    &cache_with_nameservers(&["example.com.", "com."]),
666                    10,
667                ),
668                &domain("www.example.com.")
669            )
670        );
671    }
672
673    #[test]
674    fn candidate_nameservers_returns_none_on_failure() {
675        assert_eq!(
676            None,
677            candidate_nameservers(
678                &mut Context::new(
679                    RecursiveContextInner {
680                        protocol_mode: ProtocolMode::PreferV4,
681                        upstream_dns_port: 53,
682                    },
683                    &Zones::new(),
684                    &cache_with_nameservers(&["com."]),
685                    10,
686                ),
687                &domain("net.")
688            )
689        );
690    }
691
692    #[test]
693    fn validate_nameserver_response_returns_answer() {
694        let (request, response) = nameserver_response(
695            "www.example.com.",
696            &[a_record("www.example.com.", Ipv4Addr::new(127, 0, 0, 1))],
697            &[],
698            &[],
699        );
700
701        assert_eq!(
702            Some(NameserverResponse::Answer {
703                rrs: vec![a_record("www.example.com.", Ipv4Addr::new(127, 0, 0, 1))],
704                soa_rr: None,
705            }),
706            validate_nameserver_response(&request.questions[0], &response, 0)
707        );
708    }
709
710    #[test]
711    fn validate_nameserver_response_drops_unknown_rrs() {
712        let request = Message::from_question(
713            1234,
714            Question {
715                name: domain("www.example.com."),
716                qtype: QueryType::Wildcard,
717                qclass: QueryClass::Record(RecordClass::IN),
718            },
719        );
720
721        let mut response = request.make_response();
722        response.answers = [
723            unknown_record("www.example.com.", &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
724            a_record("www.example.com.", Ipv4Addr::new(1, 1, 1, 1)),
725        ]
726        .into();
727
728        assert_eq!(
729            Some(NameserverResponse::Answer {
730                rrs: vec![a_record("www.example.com.", Ipv4Addr::new(1, 1, 1, 1))],
731                soa_rr: None,
732            }),
733            validate_nameserver_response(&request.questions[0], &response, 0)
734        );
735    }
736
737    #[test]
738    fn validate_nameserver_response_returns_none_if_all_rrs_unknown() {
739        let request = Message::from_question(
740            1234,
741            Question {
742                name: domain("www.example.com."),
743                qtype: QueryType::Wildcard,
744                qclass: QueryClass::Record(RecordClass::IN),
745            },
746        );
747
748        let mut response = request.make_response();
749        response.answers = [unknown_record(
750            "www.example.com.",
751            &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
752        )]
753        .into();
754
755        assert_eq!(
756            None,
757            validate_nameserver_response(&request.questions[0], &response, 0)
758        );
759    }
760
761    #[test]
762    fn validate_nameserver_response_follows_cnames() {
763        let (request, response) = nameserver_response(
764            "www.example.com.",
765            &[
766                cname_record("www.example.com.", "cname-target.example.com."),
767                a_record("cname-target.example.com.", Ipv4Addr::new(127, 0, 0, 1)),
768            ],
769            &[],
770            &[],
771        );
772
773        assert_eq!(
774            Some(NameserverResponse::Answer {
775                rrs: vec![
776                    cname_record("www.example.com.", "cname-target.example.com."),
777                    a_record("cname-target.example.com.", Ipv4Addr::new(127, 0, 0, 1))
778                ],
779                soa_rr: None,
780            }),
781            validate_nameserver_response(&request.questions[0], &response, 0)
782        );
783    }
784
785    #[test]
786    fn validate_nameserver_response_returns_partial_answer() {
787        let (request, response) = nameserver_response(
788            "www.example.com.",
789            &[cname_record(
790                "www.example.com.",
791                "cname-target.example.com.",
792            )],
793            &[],
794            &[],
795        );
796
797        assert_eq!(
798            Some(NameserverResponse::CNAME {
799                rrs: vec![cname_record(
800                    "www.example.com.",
801                    "cname-target.example.com."
802                )],
803                cname: domain("cname-target.example.com."),
804            }),
805            validate_nameserver_response(&request.questions[0], &response, 0)
806        );
807    }
808
809    #[test]
810    fn validate_nameserver_response_gets_ns_from_answers_and_authority_but_not_additional() {
811        let (request, response) = nameserver_response(
812            "www.example.com.",
813            &[ns_record("example.com.", "ns-an.example.net.")],
814            &[ns_record("example.com.", "ns-ns.example.net.")],
815            &[ns_record("example.com.", "ns-ar.example.net.")],
816        );
817
818        match validate_nameserver_response(&request.questions[0], &response, 0) {
819            Some(NameserverResponse::Delegation {
820                rrs: mut actual_rrs,
821                delegation: mut actual_delegation,
822            }) => {
823                let mut expected_rrs = vec![
824                    ns_record("example.com.", "ns-an.example.net."),
825                    ns_record("example.com.", "ns-ns.example.net."),
826                ];
827
828                expected_rrs.sort();
829                actual_rrs.sort();
830
831                assert_eq!(expected_rrs, actual_rrs);
832
833                let mut expected_delegation = Nameservers {
834                    hostnames: vec![domain("ns-an.example.net."), domain("ns-ns.example.net.")],
835                    name: domain("example.com."),
836                };
837
838                expected_delegation.hostnames.sort();
839                actual_delegation.hostnames.sort();
840
841                assert_eq!(expected_delegation, actual_delegation);
842            }
843            actual => panic!("Expected delegation, got {actual:?}"),
844        }
845    }
846
847    #[test]
848    fn validate_nameserver_response_only_returns_better_ns() {
849        let (request, response) = nameserver_response(
850            "long.subdomain.example.com.",
851            &[ns_record("example.com.", "ns.example.net.")],
852            &[],
853            &[],
854        );
855
856        assert_eq!(
857            None,
858            validate_nameserver_response(
859                &request.questions[0],
860                &response,
861                domain("subdomain.example.com.").labels.len()
862            )
863        );
864    }
865
866    #[test]
867    fn validate_nameserver_response_prefers_best_ns() {
868        let (request, response1) = nameserver_response(
869            "long.subdomain.example.com.",
870            &[ns_record(
871                "subdomain.example.com.",
872                "ns-better.example.net.",
873            )],
874            &[ns_record("example.com.", "ns-worse.example.net.")],
875            &[],
876        );
877        let (_, response2) = nameserver_response(
878            "long.subdomain.example.com.",
879            &[ns_record("example.com.", "ns-worse.example.net.")],
880            &[ns_record(
881                "subdomain.example.com.",
882                "ns-better.example.net.",
883            )],
884            &[],
885        );
886
887        assert_eq!(
888            Some(NameserverResponse::Delegation {
889                rrs: vec![ns_record(
890                    "subdomain.example.com.",
891                    "ns-better.example.net."
892                )],
893                delegation: Nameservers {
894                    hostnames: vec![domain("ns-better.example.net.")],
895                    name: domain("subdomain.example.com."),
896                },
897            }),
898            validate_nameserver_response(&request.questions[0], &response1, 0)
899        );
900
901        assert_eq!(
902            Some(NameserverResponse::Delegation {
903                rrs: vec![ns_record(
904                    "subdomain.example.com.",
905                    "ns-better.example.net."
906                )],
907                delegation: Nameservers {
908                    hostnames: vec![domain("ns-better.example.net.")],
909                    name: domain("subdomain.example.com."),
910                },
911            }),
912            validate_nameserver_response(&request.questions[0], &response2, 0)
913        );
914    }
915
916    #[test]
917    fn validate_nameserver_response_gets_ns_a_from_answers_and_additional_but_not_authority() {
918        let (request, response) = nameserver_response(
919            "www.example.com.",
920            &[
921                ns_record("example.com.", "ns-an.example.net."),
922                a_record("ns-an.example.net.", Ipv4Addr::new(1, 1, 1, 1)),
923                a_record("ns-ns.example.net.", Ipv4Addr::new(1, 1, 1, 1)),
924            ],
925            &[
926                ns_record("example.com.", "ns-ns.example.net."),
927                a_record("ns-an.example.net.", Ipv4Addr::new(2, 2, 2, 2)),
928                a_record("ns-ns.example.net.", Ipv4Addr::new(2, 2, 2, 2)),
929            ],
930            &[
931                a_record("ns-an.example.net.", Ipv4Addr::new(3, 3, 3, 3)),
932                a_record("ns-ns.example.net.", Ipv4Addr::new(3, 3, 3, 3)),
933            ],
934        );
935
936        match validate_nameserver_response(&request.questions[0], &response, 0) {
937            Some(NameserverResponse::Delegation {
938                rrs: mut actual_rrs,
939                delegation: _,
940            }) => {
941                let mut expected_rrs = vec![
942                    ns_record("example.com.", "ns-an.example.net."),
943                    ns_record("example.com.", "ns-ns.example.net."),
944                    a_record("ns-an.example.net.", Ipv4Addr::new(1, 1, 1, 1)),
945                    a_record("ns-ns.example.net.", Ipv4Addr::new(1, 1, 1, 1)),
946                    a_record("ns-an.example.net.", Ipv4Addr::new(3, 3, 3, 3)),
947                    a_record("ns-ns.example.net.", Ipv4Addr::new(3, 3, 3, 3)),
948                ];
949
950                expected_rrs.sort();
951                actual_rrs.sort();
952
953                assert_eq!(expected_rrs, actual_rrs);
954            }
955            actual => panic!("Expected delegation, got {actual:?}"),
956        }
957    }
958
959    #[test]
960    fn validate_nameserver_response_propagates_nodata() {
961        let soa_record = ResourceRecord {
962            name: domain("com."),
963            rtype_with_data: RecordTypeWithData::SOA {
964                mname: domain("mname."),
965                rname: domain("rname."),
966                serial: 0,
967                refresh: 0,
968                retry: 0,
969                expire: 0,
970                minimum: 0,
971            },
972            rclass: RecordClass::IN,
973            ttl: 300,
974        };
975
976        let (request, response) =
977            nameserver_response("www.example.com.", &[], &[soa_record.clone()], &[]);
978
979        assert_eq!(
980            validate_nameserver_response(&request.questions[0], &response, 0),
981            Some(NameserverResponse::Answer {
982                rrs: Vec::new(),
983                soa_rr: Some(soa_record)
984            }),
985        );
986    }
987
988    #[test]
989    fn validate_nameserver_response_rejects_nodata_if_soa_too_generic() {
990        let soa_record = ResourceRecord {
991            name: domain("com."),
992            rtype_with_data: RecordTypeWithData::SOA {
993                mname: domain("mname."),
994                rname: domain("rname."),
995                serial: 0,
996                refresh: 0,
997                retry: 0,
998                expire: 0,
999                minimum: 0,
1000            },
1001            rclass: RecordClass::IN,
1002            ttl: 300,
1003        };
1004
1005        let (request, response) = nameserver_response("www.example.com.", &[], &[soa_record], &[]);
1006
1007        // pretend we're querying the nameserver for example.com
1008        let current_match_count = domain("example.com.").labels.len();
1009
1010        assert_eq!(
1011            validate_nameserver_response(&request.questions[0], &response, current_match_count),
1012            None,
1013        );
1014    }
1015
1016    #[test]
1017    fn validate_nameserver_response_rejects_nodata_if_soa_too_specific() {
1018        let soa_record = ResourceRecord {
1019            name: domain("foo.example.com."),
1020            rtype_with_data: RecordTypeWithData::SOA {
1021                mname: domain("mname."),
1022                rname: domain("rname."),
1023                serial: 0,
1024                refresh: 0,
1025                retry: 0,
1026                expire: 0,
1027                minimum: 0,
1028            },
1029            rclass: RecordClass::IN,
1030            ttl: 300,
1031        };
1032
1033        let (request, response) = nameserver_response("www.example.com.", &[], &[soa_record], &[]);
1034
1035        assert_eq!(
1036            validate_nameserver_response(&request.questions[0], &response, 0),
1037            None,
1038        );
1039    }
1040
1041    #[test]
1042    fn follow_cnames_empty() {
1043        assert_eq!(
1044            None,
1045            follow_cnames(&[], &domain("www.example.com."), QueryType::Wildcard)
1046        );
1047    }
1048
1049    #[test]
1050    fn follow_cnames_no_name_match() {
1051        assert_eq!(
1052            None,
1053            follow_cnames(
1054                &[a_record("www.example.net.", Ipv4Addr::new(1, 1, 1, 1))],
1055                &domain("www.example.com."),
1056                QueryType::Wildcard
1057            )
1058        );
1059    }
1060
1061    #[test]
1062    fn follow_cnames_no_type_match() {
1063        assert_eq!(
1064            None,
1065            follow_cnames(
1066                &[a_record("www.example.net.", Ipv4Addr::new(1, 1, 1, 1))],
1067                &domain("www.example.com."),
1068                QueryType::Record(RecordType::NS)
1069            )
1070        );
1071    }
1072
1073    #[test]
1074    fn follow_cnames_no_cname() {
1075        let rr_a = a_record("www.example.com.", Ipv4Addr::new(127, 0, 0, 1));
1076        assert_eq!(
1077            Some((domain("www.example.com."), HashMap::new())),
1078            follow_cnames(&[rr_a], &domain("www.example.com."), QueryType::Wildcard)
1079        );
1080    }
1081
1082    #[test]
1083    fn follow_cnames_chain() {
1084        let rr_cname1 = cname_record("www.example.com.", "www2.example.com.");
1085        let rr_cname2 = cname_record("www2.example.com.", "www3.example.com.");
1086        let rr_a = a_record("www3.example.com.", Ipv4Addr::new(127, 0, 0, 1));
1087
1088        let mut expected_map = HashMap::new();
1089        expected_map.insert(domain("www.example.com."), domain("www2.example.com."));
1090        expected_map.insert(domain("www2.example.com."), domain("www3.example.com."));
1091
1092        // order of records does not matter, so pick the "worst"
1093        // order: the records are in the opposite order to what we'd
1094        // expect
1095        assert_eq!(
1096            Some((domain("www3.example.com."), expected_map)),
1097            follow_cnames(
1098                &[rr_a, rr_cname2, rr_cname1],
1099                &domain("www.example.com."),
1100                QueryType::Wildcard
1101            )
1102        );
1103    }
1104
1105    #[test]
1106    fn follow_cnames_loop() {
1107        let rr_cname1 = cname_record("www.example.com.", "bad.example.com.");
1108        let rr_cname2 = cname_record("bad.example.com.", "www.example.com.");
1109
1110        assert_eq!(
1111            None,
1112            follow_cnames(
1113                &[rr_cname1, rr_cname2],
1114                &domain("www.example.com."),
1115                QueryType::Wildcard
1116            )
1117        );
1118    }
1119
1120    #[test]
1121    fn get_better_ns_names_no_match() {
1122        let rr_ns = ns_record("example.", "ns1.icann.org.");
1123        assert_eq!(
1124            None,
1125            get_better_ns_names(&[rr_ns], &domain("www.example.com."), 0)
1126        );
1127    }
1128
1129    #[test]
1130    fn get_better_ns_names_no_better() {
1131        let rr_ns = ns_record("com.", "ns1.icann.org.");
1132        assert_eq!(
1133            None,
1134            get_better_ns_names(&[rr_ns], &domain("www.example.com."), 2)
1135        );
1136    }
1137
1138    #[test]
1139    fn get_better_ns_names_better() {
1140        let rr_ns = ns_record("example.com.", "ns2.icann.org.");
1141        assert_eq!(
1142            Some((
1143                domain("example.com."),
1144                [domain("ns2.icann.org.")].into_iter().collect()
1145            )),
1146            get_better_ns_names(&[rr_ns], &domain("www.example.com."), 0)
1147        );
1148    }
1149
1150    #[test]
1151    fn get_better_ns_names_better_better() {
1152        let rr_ns1 = ns_record("example.com.", "ns2.icann.org.");
1153        let rr_ns2 = ns_record("www.example.com.", "ns3.icann.org.");
1154        assert_eq!(
1155            Some((
1156                domain("www.example.com."),
1157                [domain("ns3.icann.org.")].into_iter().collect()
1158            )),
1159            get_better_ns_names(&[rr_ns1, rr_ns2], &domain("www.example.com."), 0)
1160        );
1161    }
1162
1163    #[test]
1164    fn get_ip_domain_mismatch() {
1165        let a_rr = a_record("www.example.net.", Ipv4Addr::new(127, 0, 0, 1));
1166        assert_eq!(
1167            None,
1168            get_ip(&[a_rr], &domain("www.example.com."), RecordType::A)
1169        );
1170    }
1171
1172    #[test]
1173    fn get_ip_type_mismatch() {
1174        let aaaa_rr = aaaa_record("www.example.com.", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
1175        assert_eq!(
1176            None,
1177            get_ip(&[aaaa_rr], &domain("www.example.com."), RecordType::A,)
1178        );
1179    }
1180
1181    #[test]
1182    fn get_ip_domain_and_type_match() {
1183        let a_rr = a_record("www.example.com.", Ipv4Addr::new(127, 0, 0, 1));
1184        let aaaa_rr = aaaa_record("www.example.com.", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
1185        let rrs = [a_rr, aaaa_rr];
1186        assert_eq!(
1187            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
1188            get_ip(&rrs, &domain("www.example.com."), RecordType::A)
1189        );
1190        assert_eq!(
1191            Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
1192            get_ip(&rrs, &domain("www.example.com."), RecordType::AAAA)
1193        );
1194    }
1195
1196    #[test]
1197    fn get_ip_cname_match() {
1198        let cname_rr = cname_record("www.example.com.", "www.example.net.");
1199        let a_rr = a_record("www.example.net.", Ipv4Addr::new(127, 0, 0, 1));
1200        assert_eq!(
1201            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
1202            get_ip(
1203                &[cname_rr, a_rr],
1204                &domain("www.example.com."),
1205                RecordType::A,
1206            )
1207        );
1208    }
1209
1210    fn cache_with_nameservers(names: &[&str]) -> SharedCache {
1211        let cache = SharedCache::new();
1212
1213        for name in names {
1214            cache.insert(&ns_record(name, "ns1.example.com."));
1215            cache.insert(&ns_record(name, "ns2.example.com."));
1216        }
1217
1218        cache
1219    }
1220}