dns_resolver/
forwarding.rs1use async_recursion::async_recursion;
2use std::net::SocketAddr;
3use std::time::Duration;
4use tokio::time::timeout;
5use tracing::Instrument;
6
7use dns_types::protocol::types::*;
8
9use crate::context::Context;
10use crate::local::{resolve_local, LocalResolutionResult};
11use crate::util::nameserver::*;
12use crate::util::types::*;
13
/// Forwarding-specific state carried in a [`ForwardingContext`].
///
/// Derives `Debug` and `Clone` so the context can be logged and duplicated;
/// its only field is a `SocketAddr`, which already implements both.
#[derive(Debug, Clone)]
pub struct ForwardingContextInner {
    /// Address of the nameserver that questions not answerable locally are
    /// sent to (passed to `query_nameserver` by the forwarding resolver).
    pub forward_address: SocketAddr,
}
17
18pub type ForwardingContext<'a> = Context<'a, ForwardingContextInner>;
19
20pub async fn resolve_forwarding<'a>(
33 context: &mut ForwardingContext<'a>,
34 question: &Question,
35) -> Result<ResolvedRecord, ResolutionError> {
36 if let Ok(res) = timeout(
37 Duration::from_secs(60),
38 resolve_forwarding_notimeout(context, question),
39 )
40 .await
41 {
42 res
43 } else {
44 tracing::debug!("timed out");
45 Err(ResolutionError::Timeout)
46 }
47}
48
49#[async_recursion]
51async fn resolve_forwarding_notimeout<'a>(
52 context: &mut ForwardingContext<'a>,
53 question: &Question,
54) -> Result<ResolvedRecord, ResolutionError> {
55 if context.at_recursion_limit() {
56 tracing::debug!("hit recursion limit");
57 return Err(ResolutionError::RecursionLimit);
58 }
59 if context.is_duplicate_question(question) {
60 tracing::debug!("hit duplicate question");
61 return Err(ResolutionError::DuplicateQuestion {
62 question: question.clone(),
63 });
64 }
65
66 let mut combined_rrs = Vec::new();
67
68 match resolve_local(context, question) {
73 Ok(LocalResolutionResult::Done { resolved }) => return Ok(resolved),
74 Ok(LocalResolutionResult::Partial { rrs }) => combined_rrs = rrs,
75 Ok(LocalResolutionResult::Delegation { .. }) => (),
76 Ok(LocalResolutionResult::CNAME {
77 mut rrs,
78 cname_question,
79 ..
80 }) => {
81 context.push_question(question);
82 let answer = match resolve_forwarding_notimeout(context, &cname_question)
83 .instrument(tracing::error_span!("resolve_forwarding", %cname_question))
84 .await
85 {
86 Ok(resolved) => {
87 let soa_rr = resolved.soa_rr().cloned();
88 let mut r_rrs = resolved.rrs();
89 let mut combined_rrs = Vec::with_capacity(rrs.len() + r_rrs.len());
90 combined_rrs.append(&mut rrs);
91 combined_rrs.append(&mut r_rrs);
92 Ok(ResolvedRecord::NonAuthoritative {
93 rrs: combined_rrs,
94 soa_rr,
95 })
96 }
97 Err(_) => Err(ResolutionError::DeadEnd {
98 question: cname_question,
99 }),
100 };
101 context.pop_question();
102 return answer;
103 }
104 Err(_) => (),
105 }
106
107 if let Some(response) = query_nameserver(context.r.forward_address, question.clone(), true)
108 .instrument(tracing::error_span!("query_nameserver"))
109 .await
110 {
111 context.metrics().nameserver_hit();
112 tracing::trace!("nameserver HIT");
113 let soa_rr = get_nxdomain_nodata_soa(question, &response, 0).cloned();
115 let rrs = response.answers;
116 context.cache.insert_all(&rrs);
117 prioritising_merge(&mut combined_rrs, rrs);
118 Ok(ResolvedRecord::NonAuthoritative {
119 rrs: combined_rrs,
120 soa_rr,
121 })
122 } else {
123 context.metrics().nameserver_miss();
124 tracing::trace!("nameserver MISS");
125 Err(ResolutionError::DeadEnd {
126 question: question.clone(),
127 })
128 }
129}