dns_types/protocol/
deserialise.rs

//! Deserialisation of DNS messages from the network.  See the `types`
//! module for details of the format.
4
5use bytes::Bytes;
6use std::net::{Ipv4Addr, Ipv6Addr};
7
8use crate::protocol::types::*;
9
10impl Message {
11    /// # Errors
12    ///
13    /// If the message cannot be parsed.
14    pub fn from_octets(octets: &[u8]) -> Result<Self, Error> {
15        Self::deserialise(&mut ConsumableBuffer::new(octets))
16    }
17
18    /// # Errors
19    ///
20    /// If the message cannot be parsed.
21    fn deserialise(buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
22        let header = Header::deserialise(buffer)?;
23        let qdcount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
24        let ancount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
25        let nscount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
26        let arcount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
27
28        let mut questions = Vec::with_capacity(qdcount.into());
29        let mut answers = Vec::with_capacity(ancount.into());
30        let mut authority = Vec::with_capacity(nscount.into());
31        let mut additional = Vec::with_capacity(arcount.into());
32
33        for _ in 0..qdcount {
34            questions.push(Question::deserialise(header.id, buffer)?);
35        }
36        for _ in 0..ancount {
37            answers.push(ResourceRecord::deserialise(header.id, buffer)?);
38        }
39        for _ in 0..nscount {
40            authority.push(ResourceRecord::deserialise(header.id, buffer)?);
41        }
42        for _ in 0..arcount {
43            additional.push(ResourceRecord::deserialise(header.id, buffer)?);
44        }
45
46        Ok(Self {
47            header,
48            questions,
49            answers,
50            authority,
51            additional,
52        })
53    }
54}
55
56impl Header {
57    /// # Errors
58    ///
59    /// If the header is too short.
60    fn deserialise(buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
61        let id = buffer.next_u16().ok_or(Error::CompletelyBusted)?;
62        let flags1 = buffer.next_u8().ok_or(Error::HeaderTooShort(id))?;
63        let flags2 = buffer.next_u8().ok_or(Error::HeaderTooShort(id))?;
64
65        Ok(Self {
66            id,
67            is_response: flags1 & HEADER_MASK_QR != 0,
68            opcode: Opcode::from((flags1 & HEADER_MASK_OPCODE) >> HEADER_OFFSET_OPCODE),
69            is_authoritative: flags1 & HEADER_MASK_AA != 0,
70            is_truncated: flags1 & HEADER_MASK_TC != 0,
71            recursion_desired: flags1 & HEADER_MASK_RD != 0,
72            recursion_available: flags2 & HEADER_MASK_RA != 0,
73            rcode: Rcode::from((flags2 & HEADER_MASK_RCODE) >> HEADER_OFFSET_RCODE),
74        })
75    }
76}
77
78impl Question {
79    /// # Errors
80    ///
81    /// If the question cannot be parsed.
82    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
83        let name = DomainName::deserialise(id, buffer)?;
84        let qtype = QueryType::deserialise(id, buffer)?;
85        let qclass = QueryClass::deserialise(id, buffer)?;
86
87        Ok(Self {
88            name,
89            qtype,
90            qclass,
91        })
92    }
93}
94
95impl ResourceRecord {
96    /// # Errors
97    ///
98    /// If the record cannot be parsed.
99    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
100        let name = DomainName::deserialise(id, buffer)?;
101        let rtype = RecordType::deserialise(id, buffer)?;
102        let rclass = RecordClass::deserialise(id, buffer)?;
103        let ttl = buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?;
104        let rdlength = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
105
106        let rdata_start = buffer.position;
107
108        let mut raw_rdata = || {
109            if let Some(octets) = buffer.take(rdlength as usize) {
110                Ok(Bytes::copy_from_slice(octets))
111            } else {
112                Err(Error::ResourceRecordTooShort(id))
113            }
114        };
115
116        // for records which include domain names, deserialise them to
117        // expand pointers.
118        let rtype_with_data = match rtype {
119            RecordType::A => RecordTypeWithData::A {
120                address: Ipv4Addr::from(
121                    buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
122                ),
123            },
124            RecordType::NS => RecordTypeWithData::NS {
125                nsdname: DomainName::deserialise(id, buffer)?,
126            },
127            RecordType::MD => RecordTypeWithData::MD {
128                madname: DomainName::deserialise(id, buffer)?,
129            },
130            RecordType::MF => RecordTypeWithData::MF {
131                madname: DomainName::deserialise(id, buffer)?,
132            },
133            RecordType::CNAME => RecordTypeWithData::CNAME {
134                cname: DomainName::deserialise(id, buffer)?,
135            },
136            RecordType::SOA => RecordTypeWithData::SOA {
137                mname: DomainName::deserialise(id, buffer)?,
138                rname: DomainName::deserialise(id, buffer)?,
139                serial: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
140                refresh: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
141                retry: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
142                expire: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
143                minimum: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
144            },
145            RecordType::MB => RecordTypeWithData::MB {
146                madname: DomainName::deserialise(id, buffer)?,
147            },
148            RecordType::MG => RecordTypeWithData::MG {
149                mdmname: DomainName::deserialise(id, buffer)?,
150            },
151            RecordType::MR => RecordTypeWithData::MR {
152                newname: DomainName::deserialise(id, buffer)?,
153            },
154            RecordType::NULL => RecordTypeWithData::NULL {
155                octets: raw_rdata()?,
156            },
157            RecordType::WKS => RecordTypeWithData::WKS {
158                octets: raw_rdata()?,
159            },
160            RecordType::PTR => RecordTypeWithData::PTR {
161                ptrdname: DomainName::deserialise(id, buffer)?,
162            },
163            RecordType::HINFO => RecordTypeWithData::HINFO {
164                octets: raw_rdata()?,
165            },
166            RecordType::MINFO => RecordTypeWithData::MINFO {
167                rmailbx: DomainName::deserialise(id, buffer)?,
168                emailbx: DomainName::deserialise(id, buffer)?,
169            },
170            RecordType::MX => RecordTypeWithData::MX {
171                preference: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
172                exchange: DomainName::deserialise(id, buffer)?,
173            },
174            RecordType::TXT => RecordTypeWithData::TXT {
175                octets: raw_rdata()?,
176            },
177            RecordType::AAAA => RecordTypeWithData::AAAA {
178                address: Ipv6Addr::new(
179                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
180                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
181                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
182                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
183                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
184                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
185                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
186                    buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
187                ),
188            },
189            RecordType::SRV => RecordTypeWithData::SRV {
190                priority: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
191                weight: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
192                port: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
193                target: DomainName::deserialise(id, buffer)?,
194            },
195            RecordType::Unknown(tag) => RecordTypeWithData::Unknown {
196                tag,
197                octets: raw_rdata()?,
198            },
199        };
200
201        let rdata_stop = buffer.position;
202
203        if rdata_stop == rdata_start + (rdlength as usize) {
204            Ok(Self {
205                name,
206                rtype_with_data,
207                rclass,
208                ttl,
209            })
210        } else {
211            Err(Error::ResourceRecordInvalid(id))
212        }
213    }
214}
215
216impl DomainName {
217    /// # Errors
218    ///
219    /// If the domain cannot be parsed.
220    #[allow(clippy::missing_panics_doc)]
221    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
222        let mut len = 0;
223        let mut labels = Vec::<Label>::with_capacity(5);
224        let start = buffer.position;
225
226        'outer: loop {
227            let size = buffer.next_u8().ok_or(Error::DomainTooShort(id))?;
228
229            if usize::from(size) <= LABEL_MAX_LEN {
230                len += 1;
231
232                if size == 0 {
233                    labels.push(Label::new());
234                    break 'outer;
235                }
236
237                if let Some(os) = buffer.take(size as usize) {
238                    // safe because of the bounds check above
239                    let label = Label::try_from(os).unwrap();
240                    len += label.len() as usize;
241                    labels.push(label);
242                } else {
243                    return Err(Error::DomainTooShort(id));
244                }
245
246                if len > DOMAINNAME_MAX_LEN {
247                    break 'outer;
248                }
249            } else if size >= 192 {
250                // this requires re-parsing the pointed-to domain -
251                // not great but works for now.
252                let hi = size & 0b0011_1111;
253                let lo = buffer.next_u8().ok_or(Error::DomainTooShort(id))?;
254                let ptr = u16::from_be_bytes([hi, lo]).into();
255
256                // pointer must be to an earlier record (not merely a
257                // different one: an earlier one: RFC 1035 section
258                // 4.1.4)
259                if ptr >= start {
260                    return Err(Error::DomainPointerInvalid(id));
261                }
262
263                let mut other = DomainName::deserialise(id, &mut buffer.at_offset(ptr))?;
264                len += other.len;
265                labels.append(&mut other.labels);
266                break 'outer;
267            } else {
268                return Err(Error::DomainLabelInvalid(id));
269            }
270        }
271
272        if len <= DOMAINNAME_MAX_LEN {
273            Ok(DomainName { labels, len })
274        } else {
275            Err(Error::DomainTooLong(id))
276        }
277    }
278}
279
280impl QueryType {
281    /// # Errors
282    ///
283    /// If the query type is too short.
284    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
285        let value = buffer.next_u16().ok_or(Error::QuestionTooShort(id))?;
286        Ok(Self::from(value))
287    }
288}
289
290impl QueryClass {
291    /// # Errors
292    ///
293    /// If the query class is too short.
294    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
295        let value = buffer.next_u16().ok_or(Error::QuestionTooShort(id))?;
296        Ok(Self::from(value))
297    }
298}
299
300impl RecordType {
301    /// # Errors
302    ///
303    /// If the record type is too short.
304    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
305        let value = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
306        Ok(Self::from(value))
307    }
308}
309
310impl RecordClass {
311    /// # Errors
312    ///
313    /// If the record class is too short.
314    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
315        let value = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
316        Ok(Self::from(value))
317    }
318}
319
/// Errors encountered when parsing a datagram.  In all the errors
/// which have a `u16` parameter, that is the ID from the header - so
/// that an error response can be sent.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Error {
    /// The datagram is not even 2 octets long, so it doesn't even
    /// contain a valid ID.  An error cannot even be sent back to the
    /// client in this case as, without an ID, it cannot be linked
    /// with the correct query.
    CompletelyBusted,

    /// The header is missing one or more required fields.
    HeaderTooShort(u16),

    /// A question ends with an incomplete field.
    QuestionTooShort(u16),

    /// A resource record ends with an incomplete field.
    ResourceRecordTooShort(u16),

    /// A resource record's parsed RDATA does not occupy exactly the
    /// number of octets given by its RDLENGTH field.
    ResourceRecordInvalid(u16),

    /// A domain is incomplete.
    DomainTooShort(u16),

    /// A domain is over 255 octets in size.
    DomainTooLong(u16),

    /// A domain compression pointer does not point strictly before the
    /// start of the domain name (or name segment) that contains it.
    DomainPointerInvalid(u16),

    /// A domain label is invalid, e.g. its length octet is over 63 but
    /// does not mark a compression pointer.
    DomainLabelInvalid(u16),
}
355
356impl std::fmt::Display for Error {
357    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
358        match self {
359            Error::CompletelyBusted | Error::HeaderTooShort(_) => write!(f, "header too short"),
360            Error::QuestionTooShort(_) => write!(f, "question too short"),
361            Error::ResourceRecordTooShort(_) => write!(f, "resource record too short"),
362            Error::ResourceRecordInvalid(_) => write!(
363                f,
364                "resource record RDLENGTH field does not match parsed RDATA length"
365            ),
366            Error::DomainTooShort(_) => write!(f, "domain name too short"),
367            Error::DomainTooLong(_) => write!(f, "domain name too long"),
368            Error::DomainPointerInvalid(_) => write!(f, "domain name compression pointer invalid"),
369            Error::DomainLabelInvalid(_) => write!(f, "domain label invalid"),
370        }
371    }
372}
373
374impl std::error::Error for Error {
375    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
376        None
377    }
378}
379
380impl Error {
381    pub fn id(self) -> Option<u16> {
382        match self {
383            Error::CompletelyBusted => None,
384            Error::HeaderTooShort(id) => Some(id),
385            Error::QuestionTooShort(id) => Some(id),
386            Error::ResourceRecordTooShort(id) => Some(id),
387            Error::ResourceRecordInvalid(id) => Some(id),
388            Error::DomainTooShort(id) => Some(id),
389            Error::DomainTooLong(id) => Some(id),
390            Error::DomainPointerInvalid(id) => Some(id),
391            Error::DomainLabelInvalid(id) => Some(id),
392        }
393    }
394}
395
/// A cursor over a message's octets which is advanced as parsing proceeds.
///
/// Every read either succeeds and moves `position` past the octets it
/// returned, or returns `None` and leaves `position` untouched.
/// Multi-octet integers are read in network (big-endian) order.
struct ConsumableBuffer<'a> {
    octets: &'a [u8],
    position: usize,
}

impl<'a> ConsumableBuffer<'a> {
    /// Wrap `octets`, positioned at the first octet.
    fn new(octets: &'a [u8]) -> Self {
        Self {
            octets,
            position: 0,
        }
    }

    /// Consume one octet.
    fn next_u8(&mut self) -> Option<u8> {
        self.take(1).map(|bytes| bytes[0])
    }

    /// Consume two octets as a big-endian `u16`.
    fn next_u16(&mut self) -> Option<u16> {
        self.take(2).map(|bytes| u16::from_be_bytes([bytes[0], bytes[1]]))
    }

    /// Consume four octets as a big-endian `u32`.
    fn next_u32(&mut self) -> Option<u32> {
        self.take(4)
            .map(|bytes| u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
    }

    /// Consume exactly `size` octets, borrowing them from the underlying
    /// message rather than from this cursor.
    fn take(&mut self, size: usize) -> Option<&'a [u8]> {
        let whole: &'a [u8] = self.octets;
        let end = self.position + size;
        let slice = whole.get(self.position..end)?;
        self.position = end;
        Some(slice)
    }

    /// A fresh, independent cursor over the same octets, positioned at
    /// `position`; this cursor is not moved.
    fn at_offset(&self, position: usize) -> ConsumableBuffer<'a> {
        Self { position, ..*self }
    }
}