dns_resolver/
forwarding.rs

1use async_recursion::async_recursion;
2use std::net::SocketAddr;
3use std::time::Duration;
4use tokio::time::timeout;
5use tracing::Instrument;
6
7use dns_types::protocol::types::*;
8
9use crate::context::Context;
10use crate::local::{resolve_local, LocalResolutionResult};
11use crate::util::nameserver::*;
12use crate::util::types::*;
13
/// Forwarding-resolver-specific state, carried inside the shared resolution
/// `Context` (see the `ForwardingContext` alias).
pub struct ForwardingContextInner {
    /// Address of the upstream nameserver that questions which cannot be
    /// answered from local data are forwarded to.
    pub forward_address: SocketAddr,
}
17
18pub type ForwardingContext<'a> = Context<'a, ForwardingContextInner>;
19
20/// Forwarding DNS resolution.
21///
22/// Attempts to resolve a query locally and, if it cannot, calls out
23/// to another nameserver and returns its response.  As this other
24/// nameserver can spoof any records it wants, very little validation
25/// is done of its responses.
26///
27/// This has a 60s timeout.
28///
29/// # Errors
30///
31/// See `ResolutionError`.
32pub async fn resolve_forwarding<'a>(
33    context: &mut ForwardingContext<'a>,
34    question: &Question,
35) -> Result<ResolvedRecord, ResolutionError> {
36    if let Ok(res) = timeout(
37        Duration::from_secs(60),
38        resolve_forwarding_notimeout(context, question),
39    )
40    .await
41    {
42        res
43    } else {
44        tracing::debug!("timed out");
45        Err(ResolutionError::Timeout)
46    }
47}
48
49/// Timeout-less version of `resolve_forwarding`.
50#[async_recursion]
51async fn resolve_forwarding_notimeout<'a>(
52    context: &mut ForwardingContext<'a>,
53    question: &Question,
54) -> Result<ResolvedRecord, ResolutionError> {
55    if context.at_recursion_limit() {
56        tracing::debug!("hit recursion limit");
57        return Err(ResolutionError::RecursionLimit);
58    }
59    if context.is_duplicate_question(question) {
60        tracing::debug!("hit duplicate question");
61        return Err(ResolutionError::DuplicateQuestion {
62            question: question.clone(),
63        });
64    }
65
66    let mut combined_rrs = Vec::new();
67
68    // this is almost the same as in the recursive resolver, but:
69    //
70    // - delegations are ignored (we just forward to the upstream nameserver)
71    // - CNAMEs are resolved by calling the forwarding resolver recursively
72    match resolve_local(context, question) {
73        Ok(LocalResolutionResult::Done { resolved }) => return Ok(resolved),
74        Ok(LocalResolutionResult::Partial { rrs }) => combined_rrs = rrs,
75        Ok(LocalResolutionResult::Delegation { .. }) => (),
76        Ok(LocalResolutionResult::CNAME {
77            mut rrs,
78            cname_question,
79            ..
80        }) => {
81            context.push_question(question);
82            let answer = match resolve_forwarding_notimeout(context, &cname_question)
83                .instrument(tracing::error_span!("resolve_forwarding", %cname_question))
84                .await
85            {
86                Ok(resolved) => {
87                    let soa_rr = resolved.soa_rr().cloned();
88                    let mut r_rrs = resolved.rrs();
89                    let mut combined_rrs = Vec::with_capacity(rrs.len() + r_rrs.len());
90                    combined_rrs.append(&mut rrs);
91                    combined_rrs.append(&mut r_rrs);
92                    Ok(ResolvedRecord::NonAuthoritative {
93                        rrs: combined_rrs,
94                        soa_rr,
95                    })
96                }
97                Err(_) => Err(ResolutionError::DeadEnd {
98                    question: cname_question,
99                }),
100            };
101            context.pop_question();
102            return answer;
103        }
104        Err(_) => (),
105    }
106
107    if let Some(response) = query_nameserver(context.r.forward_address, question.clone(), true)
108        .instrument(tracing::error_span!("query_nameserver"))
109        .await
110    {
111        context.metrics().nameserver_hit();
112        tracing::trace!("nameserver HIT");
113        // Propagate SOA RR for NXDOMAIN / NODATA responses
114        let soa_rr = get_nxdomain_nodata_soa(question, &response, 0).cloned();
115        let rrs = response.answers;
116        context.cache.insert_all(&rrs);
117        prioritising_merge(&mut combined_rrs, rrs);
118        Ok(ResolvedRecord::NonAuthoritative {
119            rrs: combined_rrs,
120            soa_rr,
121        })
122    } else {
123        context.metrics().nameserver_miss();
124        tracing::trace!("nameserver MISS");
125        Err(ResolutionError::DeadEnd {
126            question: question.clone(),
127        })
128    }
129}