dns_types/protocol/serialise.rs

//! Serialisation of DNS messages to the wire format.  See the `types`
//! module for details of the format.
3
4use bytes::{BufMut, BytesMut};
5use std::collections::HashMap;
6
7use crate::protocol::types::*;
8
9impl Message {
10    /// # Errors
11    ///
12    /// If the message is invalid (the `Message` type permits more
13    /// states than strictly allowed).
14    pub fn to_octets(&self) -> Result<BytesMut, Error> {
15        let mut buffer = WritableBuffer::default();
16        self.serialise(&mut buffer)?;
17        Ok(buffer.octets)
18    }
19
20    /// # Errors
21    ///
22    /// If the message is invalid (the `Message` type permits more
23    /// states than strictly allowed).
24    fn serialise(&self, buffer: &mut WritableBuffer) -> Result<(), Error> {
25        let qdcount = usize_to_u16(self.questions.len())?;
26        let ancount = usize_to_u16(self.answers.len())?;
27        let nscount = usize_to_u16(self.authority.len())?;
28        let arcount = usize_to_u16(self.additional.len())?;
29
30        self.header.serialise(buffer);
31        buffer.write_u16(qdcount);
32        buffer.write_u16(ancount);
33        buffer.write_u16(nscount);
34        buffer.write_u16(arcount);
35
36        for question in &self.questions {
37            question.serialise(buffer);
38        }
39        for rr in &self.answers {
40            rr.serialise(buffer)?;
41        }
42        for rr in &self.authority {
43            rr.serialise(buffer)?;
44        }
45        for rr in &self.additional {
46            rr.serialise(buffer)?;
47        }
48
49        Ok(())
50    }
51}
52
53impl Header {
54    fn serialise(&self, buffer: &mut WritableBuffer) {
55        // octet 1
56        let flag_qr = if self.is_response { HEADER_MASK_QR } else { 0 };
57        let field_opcode = HEADER_MASK_OPCODE & (u8::from(self.opcode) << HEADER_OFFSET_OPCODE);
58        let flag_aa = if self.is_authoritative {
59            HEADER_MASK_AA
60        } else {
61            0
62        };
63        let flag_tc = if self.is_truncated { HEADER_MASK_TC } else { 0 };
64        let flag_rd = if self.recursion_desired {
65            HEADER_MASK_RD
66        } else {
67            0
68        };
69        // octet 2
70        let flag_ra = if self.recursion_available {
71            HEADER_MASK_RA
72        } else {
73            0
74        };
75        let field_rcode = HEADER_MASK_RCODE & (u8::from(self.rcode) << HEADER_OFFSET_RCODE);
76
77        buffer.write_u16(self.id);
78        buffer.write_u8(flag_qr | field_opcode | flag_aa | flag_tc | flag_rd);
79        buffer.write_u8(flag_ra | field_rcode);
80    }
81}
82
83impl Question {
84    fn serialise(&self, buffer: &mut WritableBuffer) {
85        self.name.serialise(buffer, true);
86        self.qtype.serialise(buffer);
87        self.qclass.serialise(buffer);
88    }
89}
90
91impl ResourceRecord {
92    /// # Errors
93    ///
94    /// If the RDATA is too long.
95    fn serialise(&self, buffer: &mut WritableBuffer) -> Result<(), Error> {
96        self.name.serialise(buffer, true);
97        self.rtype_with_data.rtype().serialise(buffer);
98        self.rclass.serialise(buffer);
99        buffer.write_u32(self.ttl);
100
101        // filled in below
102        let rdlength_index = buffer.index();
103        buffer.write_u16(0);
104
105        match &self.rtype_with_data {
106            RecordTypeWithData::A { address } => buffer.write_octets(&address.octets()),
107            RecordTypeWithData::NS { nsdname } => nsdname.serialise(buffer, false),
108            RecordTypeWithData::MD { madname } => madname.serialise(buffer, false),
109            RecordTypeWithData::MF { madname } => madname.serialise(buffer, false),
110            RecordTypeWithData::CNAME { cname } => cname.serialise(buffer, false),
111            RecordTypeWithData::SOA {
112                mname,
113                rname,
114                serial,
115                refresh,
116                retry,
117                expire,
118                minimum,
119            } => {
120                mname.serialise(buffer, false);
121                rname.serialise(buffer, false);
122                buffer.write_u32(*serial);
123                buffer.write_u32(*refresh);
124                buffer.write_u32(*retry);
125                buffer.write_u32(*expire);
126                buffer.write_u32(*minimum);
127            }
128            RecordTypeWithData::MB { madname } => madname.serialise(buffer, false),
129            RecordTypeWithData::MG { mdmname } => mdmname.serialise(buffer, false),
130            RecordTypeWithData::MR { newname } => newname.serialise(buffer, false),
131            RecordTypeWithData::NULL { octets } => buffer.write_octets(octets),
132            RecordTypeWithData::WKS { octets } => buffer.write_octets(octets),
133            RecordTypeWithData::PTR { ptrdname } => ptrdname.serialise(buffer, false),
134            RecordTypeWithData::HINFO { octets } => buffer.write_octets(octets),
135            RecordTypeWithData::MINFO { rmailbx, emailbx } => {
136                rmailbx.serialise(buffer, false);
137                emailbx.serialise(buffer, false);
138            }
139            RecordTypeWithData::MX {
140                preference,
141                exchange,
142            } => {
143                buffer.write_u16(*preference);
144                exchange.serialise(buffer, false);
145            }
146            RecordTypeWithData::TXT { octets } => buffer.write_octets(octets),
147            RecordTypeWithData::AAAA { address } => buffer.write_octets(&address.octets()),
148            RecordTypeWithData::SRV {
149                priority,
150                weight,
151                port,
152                target,
153            } => {
154                buffer.write_u16(*priority);
155                buffer.write_u16(*weight);
156                buffer.write_u16(*port);
157                target.serialise(buffer, false);
158            }
159            RecordTypeWithData::Unknown { octets, .. } => buffer.write_octets(octets),
160        }
161
162        // -2 so we don't also include the 2 octets for the rdlength
163        let rdlength = usize_to_u16(buffer.index() - rdlength_index - 2)?;
164        let [hi, lo] = rdlength.to_be_bytes();
165        buffer.octets[rdlength_index] = hi;
166        buffer.octets[rdlength_index + 1] = lo;
167
168        Ok(())
169    }
170}
171
172impl DomainName {
173    fn serialise(&self, buffer: &mut WritableBuffer, compress: bool) {
174        if compress {
175            if let Some(ptr) = buffer.name_pointer(self) {
176                buffer.write_u16(ptr);
177                return;
178            }
179        }
180
181        buffer.memoise_name(self);
182        for label in &self.labels {
183            buffer.write_u8(label.len());
184            buffer.write_octets(label.octets());
185        }
186    }
187}
188
189impl QueryType {
190    fn serialise(self, buffer: &mut WritableBuffer) {
191        buffer.write_u16(self.into());
192    }
193}
194
195impl QueryClass {
196    fn serialise(self, buffer: &mut WritableBuffer) {
197        buffer.write_u16(self.into());
198    }
199}
200
201impl RecordType {
202    fn serialise(self, buffer: &mut WritableBuffer) {
203        buffer.write_u16(self.into());
204    }
205}
206
207impl RecordClass {
208    fn serialise(self, buffer: &mut WritableBuffer) {
209        buffer.write_u16(self.into());
210    }
211}
212
/// Errors which can occur while serialising a message.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// A counter does not fit in the desired width.
    CounterTooLarge {
        /// The value which did not fit.
        counter: usize,
        /// The width, in bits, it needed to fit in.
        bits: u32,
    },
}
219
220impl std::fmt::Display for Error {
221    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
222        match self {
223            Error::CounterTooLarge { counter, bits } => {
224                write!(f, "'{counter}' cannot be converted to a u{bits}")
225            }
226        }
227    }
228}
229
230impl std::error::Error for Error {
231    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
232        None
233    }
234}
235
236/// A buffer which can be written to, for serialisation purposes.
237struct WritableBuffer {
238    octets: BytesMut,
239    name_pointers: HashMap<DomainName, u16>,
240}
241
242impl Default for WritableBuffer {
243    fn default() -> Self {
244        Self {
245            octets: BytesMut::with_capacity(512),
246            name_pointers: HashMap::new(),
247        }
248    }
249}
250
251impl WritableBuffer {
252    fn index(&self) -> usize {
253        self.octets.len()
254    }
255
256    fn memoise_name(&mut self, name: &DomainName) {
257        if !name.is_root() && !self.name_pointers.contains_key(name) {
258            if let Ok(index) = u16::try_from(self.index()) {
259                let [hi, lo] = index.to_be_bytes();
260                self.name_pointers
261                    .insert(name.clone(), u16::from_be_bytes([hi | 0b1100_0000, lo]));
262            }
263        }
264    }
265
266    fn name_pointer(&self, name: &DomainName) -> Option<u16> {
267        self.name_pointers.get(name).copied()
268    }
269
270    fn write_u8(&mut self, octet: u8) {
271        self.octets.put_u8(octet);
272    }
273
274    fn write_u16(&mut self, value: u16) {
275        self.write_octets(&value.to_be_bytes());
276    }
277
278    fn write_u32(&mut self, value: u32) {
279        self.write_octets(&value.to_be_bytes());
280    }
281
282    fn write_octets(&mut self, octets: &[u8]) {
283        self.octets.put_slice(octets);
284    }
285}
286
287/// Helper function to convert a `usize` into a `u16` (or return an error).
288///
289/// # Errors
290///
291/// If the value cannot be converted.
292fn usize_to_u16(counter: usize) -> Result<u16, Error> {
293    if let Ok(t) = u16::try_from(counter) {
294        Ok(t)
295    } else {
296        Err(Error::CounterTooLarge {
297            counter,
298            bits: u16::BITS,
299        })
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::test_util::*;

    /// A repeated name is replaced by a pointer when compression is requested.
    #[test]
    #[rustfmt::skip]
    fn test_name_compression_opt_in() {
        let mut buf = WritableBuffer::default();
        buf.write_u8(1);
        buf.write_u8(2);
        buf.write_u8(3);
        buf.write_u8(4);
        domain("www.example.com.").serialise(&mut buf, true);
        domain("www.example.com.").serialise(&mut buf, true);

        assert_eq!(
            vec![
                1, 2, 3, 4,
                // domain 1
                3, 119, 119, 119, // "www"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
                // domain 2
                0b1100_0000, 0b0000_0100 // pointer
            ],
            buf.octets,
        );
    }

    /// A repeated name is written in full when compression is not requested.
    #[test]
    #[rustfmt::skip]
    fn test_name_compression_opt_out() {
        let mut buf = WritableBuffer::default();
        buf.write_u8(1);
        buf.write_u8(2);
        buf.write_u8(3);
        buf.write_u8(4);
        domain("www.example.com.").serialise(&mut buf, true);
        domain("www.example.com.").serialise(&mut buf, false);

        assert_eq!(
            vec![
                1, 2, 3, 4,
                // domain 1
                3, 119, 119, 119, // "www"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
                // domain 2
                3, 119, 119, 119, // "www"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
            ],
            buf.octets,
        );
    }

    /// Owner names are compressed against earlier names (including names in
    /// RDATA), while names inside RDATA are always written in full.
    #[test]
    #[rustfmt::skip]
    fn test_name_compression_records() {
        let mut buf = WritableBuffer::default();
        buf.write_u8(1);
        buf.write_u8(2);
        buf.write_u8(3);
        buf.write_u8(4);

        Question {
            name: domain("www.example.com."),
            qtype: QueryType::Wildcard,
            qclass: QueryClass::Wildcard,
        }.serialise(&mut buf);

        ResourceRecord {
            name: domain("www.example.com."),
            rtype_with_data: RecordTypeWithData::MX {
                preference: 32,
                exchange: domain("mx.example.com."),
            },
            rclass: RecordClass::IN,
            ttl: 300,
        }.serialise(&mut buf).unwrap();

        ResourceRecord {
            name: domain("mx.example.com."),
            rtype_with_data: RecordTypeWithData::CNAME {
                cname: domain("www.example.com."),
            },
            rclass: RecordClass::IN,
            ttl: 300,
        }.serialise(&mut buf).unwrap();

        assert_eq!(
            vec![
                1, 2, 3, 4,
                // QNAME
                3, 119, 119, 119, // "www"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
                // QTYPE
                0, 255,
                // QCLASS
                0, 255,
                // NAME
                0b1100_0000, 0b0000_0100, // pointer to "www.example.com"
                // TYPE
                0b0000_0000, 0b0000_1111, // MX
                // CLASS
                0b0000_0000, 0b0000_0001, // IN
                // TTL
                0b0000_0000, 0b0000_0000, 0b0000_0001, 0b0010_1100, // 300
                // RDLENGTH
                0b0000_0000, 0b0001_0010, // 18 octets
                // RDATA
                0, 32, // preference
                2, 109, 120, // "mx"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
                // NAME
                0b1100_0000, 0b0010_0111, // pointer to "mx.example.com"
                // TYPE
                0b0000_0000, 0b0000_0101, // CNAME
                // CLASS
                0b0000_0000, 0b0000_0001, // IN
                // TTL
                0b0000_0000, 0b0000_0000, 0b0000_0001, 0b0010_1100, // 300
                // RDLENGTH
                0b0000_0000, 0b0001_0001, // 17 octets
                // RDATA
                3, 119, 119, 119, // "www"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
            ],
            buf.octets,
        );
    }

    /// The RDLENGTH placeholder is back-filled with the RDATA length.
    #[test]
    #[rustfmt::skip]
    fn test_sets_rdlength() {
        let mut buf = WritableBuffer::default();
        buf.write_u8(1);
        buf.write_u8(2);
        buf.write_u8(3);
        buf.write_u8(4);

        let rr = ResourceRecord {
            name: domain("www.example.com."),
            rtype_with_data: RecordTypeWithData::MX {
                preference: 32,
                exchange: domain("mx.example.com."),
            },
            rclass: RecordClass::IN,
            ttl: 300,
        };
        rr.serialise(&mut buf).unwrap();

        assert_eq!(
            vec![
                1, 2, 3, 4,
                // NAME
                3, 119, 119, 119, // "www"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
                // TYPE
                0b0000_0000, 0b0000_1111, // MX
                // CLASS
                0b0000_0000, 0b0000_0001, // IN
                // TTL
                0b0000_0000, 0b0000_0000, 0b0000_0001, 0b0010_1100, // 300
                // RDLENGTH
                0b0000_0000, 0b0001_0010, // 18 octets
                // RDATA
                0, 32, // preference
                2, 109, 120, // "mx"
                7, 101, 120, 97, 109, 112, 108, 101, // "example"
                3, 99, 111, 109, 0, // "com"
            ],
            buf.octets,
        );
    }
}