1use bytes::{BufMut, BytesMut};
5use std::collections::HashMap;
6
7use crate::protocol::types::*;
8
9impl Message {
10 pub fn to_octets(&self) -> Result<BytesMut, Error> {
15 let mut buffer = WritableBuffer::default();
16 self.serialise(&mut buffer)?;
17 Ok(buffer.octets)
18 }
19
20 fn serialise(&self, buffer: &mut WritableBuffer) -> Result<(), Error> {
25 let qdcount = usize_to_u16(self.questions.len())?;
26 let ancount = usize_to_u16(self.answers.len())?;
27 let nscount = usize_to_u16(self.authority.len())?;
28 let arcount = usize_to_u16(self.additional.len())?;
29
30 self.header.serialise(buffer);
31 buffer.write_u16(qdcount);
32 buffer.write_u16(ancount);
33 buffer.write_u16(nscount);
34 buffer.write_u16(arcount);
35
36 for question in &self.questions {
37 question.serialise(buffer);
38 }
39 for rr in &self.answers {
40 rr.serialise(buffer)?;
41 }
42 for rr in &self.authority {
43 rr.serialise(buffer)?;
44 }
45 for rr in &self.additional {
46 rr.serialise(buffer)?;
47 }
48
49 Ok(())
50 }
51}
52
53impl Header {
54 fn serialise(&self, buffer: &mut WritableBuffer) {
55 let flag_qr = if self.is_response { HEADER_MASK_QR } else { 0 };
57 let field_opcode = HEADER_MASK_OPCODE & (u8::from(self.opcode) << HEADER_OFFSET_OPCODE);
58 let flag_aa = if self.is_authoritative {
59 HEADER_MASK_AA
60 } else {
61 0
62 };
63 let flag_tc = if self.is_truncated { HEADER_MASK_TC } else { 0 };
64 let flag_rd = if self.recursion_desired {
65 HEADER_MASK_RD
66 } else {
67 0
68 };
69 let flag_ra = if self.recursion_available {
71 HEADER_MASK_RA
72 } else {
73 0
74 };
75 let field_rcode = HEADER_MASK_RCODE & (u8::from(self.rcode) << HEADER_OFFSET_RCODE);
76
77 buffer.write_u16(self.id);
78 buffer.write_u8(flag_qr | field_opcode | flag_aa | flag_tc | flag_rd);
79 buffer.write_u8(flag_ra | field_rcode);
80 }
81}
82
83impl Question {
84 fn serialise(&self, buffer: &mut WritableBuffer) {
85 self.name.serialise(buffer, true);
86 self.qtype.serialise(buffer);
87 self.qclass.serialise(buffer);
88 }
89}
90
91impl ResourceRecord {
92 fn serialise(&self, buffer: &mut WritableBuffer) -> Result<(), Error> {
96 self.name.serialise(buffer, true);
97 self.rtype_with_data.rtype().serialise(buffer);
98 self.rclass.serialise(buffer);
99 buffer.write_u32(self.ttl);
100
101 let rdlength_index = buffer.index();
103 buffer.write_u16(0);
104
105 match &self.rtype_with_data {
106 RecordTypeWithData::A { address } => buffer.write_octets(&address.octets()),
107 RecordTypeWithData::NS { nsdname } => nsdname.serialise(buffer, false),
108 RecordTypeWithData::MD { madname } => madname.serialise(buffer, false),
109 RecordTypeWithData::MF { madname } => madname.serialise(buffer, false),
110 RecordTypeWithData::CNAME { cname } => cname.serialise(buffer, false),
111 RecordTypeWithData::SOA {
112 mname,
113 rname,
114 serial,
115 refresh,
116 retry,
117 expire,
118 minimum,
119 } => {
120 mname.serialise(buffer, false);
121 rname.serialise(buffer, false);
122 buffer.write_u32(*serial);
123 buffer.write_u32(*refresh);
124 buffer.write_u32(*retry);
125 buffer.write_u32(*expire);
126 buffer.write_u32(*minimum);
127 }
128 RecordTypeWithData::MB { madname } => madname.serialise(buffer, false),
129 RecordTypeWithData::MG { mdmname } => mdmname.serialise(buffer, false),
130 RecordTypeWithData::MR { newname } => newname.serialise(buffer, false),
131 RecordTypeWithData::NULL { octets } => buffer.write_octets(octets),
132 RecordTypeWithData::WKS { octets } => buffer.write_octets(octets),
133 RecordTypeWithData::PTR { ptrdname } => ptrdname.serialise(buffer, false),
134 RecordTypeWithData::HINFO { octets } => buffer.write_octets(octets),
135 RecordTypeWithData::MINFO { rmailbx, emailbx } => {
136 rmailbx.serialise(buffer, false);
137 emailbx.serialise(buffer, false);
138 }
139 RecordTypeWithData::MX {
140 preference,
141 exchange,
142 } => {
143 buffer.write_u16(*preference);
144 exchange.serialise(buffer, false);
145 }
146 RecordTypeWithData::TXT { octets } => buffer.write_octets(octets),
147 RecordTypeWithData::AAAA { address } => buffer.write_octets(&address.octets()),
148 RecordTypeWithData::SRV {
149 priority,
150 weight,
151 port,
152 target,
153 } => {
154 buffer.write_u16(*priority);
155 buffer.write_u16(*weight);
156 buffer.write_u16(*port);
157 target.serialise(buffer, false);
158 }
159 RecordTypeWithData::Unknown { octets, .. } => buffer.write_octets(octets),
160 }
161
162 let rdlength = usize_to_u16(buffer.index() - rdlength_index - 2)?;
164 let [hi, lo] = rdlength.to_be_bytes();
165 buffer.octets[rdlength_index] = hi;
166 buffer.octets[rdlength_index + 1] = lo;
167
168 Ok(())
169 }
170}
171
172impl DomainName {
173 fn serialise(&self, buffer: &mut WritableBuffer, compress: bool) {
174 if compress {
175 if let Some(ptr) = buffer.name_pointer(self) {
176 buffer.write_u16(ptr);
177 return;
178 }
179 }
180
181 buffer.memoise_name(self);
182 for label in &self.labels {
183 buffer.write_u8(label.len());
184 buffer.write_octets(label.octets());
185 }
186 }
187}
188
189impl QueryType {
190 fn serialise(self, buffer: &mut WritableBuffer) {
191 buffer.write_u16(self.into());
192 }
193}
194
195impl QueryClass {
196 fn serialise(self, buffer: &mut WritableBuffer) {
197 buffer.write_u16(self.into());
198 }
199}
200
201impl RecordType {
202 fn serialise(self, buffer: &mut WritableBuffer) {
203 buffer.write_u16(self.into());
204 }
205}
206
207impl RecordClass {
208 fn serialise(self, buffer: &mut WritableBuffer) {
209 buffer.write_u16(self.into());
210 }
211}
212
/// An error that can occur while serialising a message.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Error {
    /// The value `counter` does not fit in an unsigned integer that
    /// is `bits` bits wide.
    CounterTooLarge { counter: usize, bits: u32 },
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::CounterTooLarge { counter, bits } => {
                write!(f, "'{counter}' cannot be converted to a u{bits}")
            }
        }
    }
}

// No variant wraps another error, so the trait's default `source`
// (which returns `None`) is exactly what is wanted.
impl std::error::Error for Error {}
235
236struct WritableBuffer {
238 octets: BytesMut,
239 name_pointers: HashMap<DomainName, u16>,
240}
241
242impl Default for WritableBuffer {
243 fn default() -> Self {
244 Self {
245 octets: BytesMut::with_capacity(512),
246 name_pointers: HashMap::new(),
247 }
248 }
249}
250
251impl WritableBuffer {
252 fn index(&self) -> usize {
253 self.octets.len()
254 }
255
256 fn memoise_name(&mut self, name: &DomainName) {
257 if !name.is_root() && !self.name_pointers.contains_key(name) {
258 if let Ok(index) = u16::try_from(self.index()) {
259 let [hi, lo] = index.to_be_bytes();
260 self.name_pointers
261 .insert(name.clone(), u16::from_be_bytes([hi | 0b1100_0000, lo]));
262 }
263 }
264 }
265
266 fn name_pointer(&self, name: &DomainName) -> Option<u16> {
267 self.name_pointers.get(name).copied()
268 }
269
270 fn write_u8(&mut self, octet: u8) {
271 self.octets.put_u8(octet);
272 }
273
274 fn write_u16(&mut self, value: u16) {
275 self.write_octets(&value.to_be_bytes());
276 }
277
278 fn write_u32(&mut self, value: u32) {
279 self.write_octets(&value.to_be_bytes());
280 }
281
282 fn write_octets(&mut self, octets: &[u8]) {
283 self.octets.put_slice(octets);
284 }
285}
286
287fn usize_to_u16(counter: usize) -> Result<u16, Error> {
293 if let Ok(t) = u16::try_from(counter) {
294 Ok(t)
295 } else {
296 Err(Error::CounterTooLarge {
297 counter,
298 bits: u16::BITS,
299 })
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::test_util::*;

    /// A buffer that already holds four filler octets, so that the
    /// names written by the tests do not start at offset zero and the
    /// pointer offsets in the expected output are non-trivial.
    fn prefixed_buffer() -> WritableBuffer {
        let mut buf = WritableBuffer::default();
        buf.write_octets(&[1, 2, 3, 4]);
        buf
    }

    /// A repeated name is replaced by a pointer when compression is
    /// requested.
    #[test]
    #[rustfmt::skip]
    fn test_name_compression_opt_in() {
        let mut buf = prefixed_buffer();
        domain("www.example.com.").serialise(&mut buf, true);
        domain("www.example.com.").serialise(&mut buf, true);

        assert_eq!(
            vec![
                // filler
                1, 2, 3, 4,
                // first occurrence, written in full
                3, 119, 119, 119,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
                // second occurrence: pointer to offset 4
                0b1100_0000, 0b0000_0100,
            ],
            buf.octets,
        );
    }

    /// A repeated name is written in full when compression is not
    /// requested.
    #[test]
    #[rustfmt::skip]
    fn test_name_compression_opt_out() {
        let mut buf = prefixed_buffer();
        domain("www.example.com.").serialise(&mut buf, true);
        domain("www.example.com.").serialise(&mut buf, false);

        assert_eq!(
            vec![
                // filler
                1, 2, 3, 4,
                // first occurrence
                3, 119, 119, 119,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
                // second occurrence, also in full
                3, 119, 119, 119,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
            ],
            buf.octets,
        );
    }

    /// Record owner names are compressed (including against names
    /// first seen inside earlier record data), while names inside
    /// record data are always written in full.
    #[test]
    #[rustfmt::skip]
    fn test_name_compression_records() {
        let mut buf = prefixed_buffer();

        Question {
            name: domain("www.example.com."),
            qtype: QueryType::Wildcard,
            qclass: QueryClass::Wildcard,
        }.serialise(&mut buf);

        // Unwrap rather than discard: a serialisation error should
        // fail the test directly instead of as a byte mismatch.
        ResourceRecord {
            name: domain("www.example.com."),
            rtype_with_data: RecordTypeWithData::MX {
                preference: 32,
                exchange: domain("mx.example.com."),
            },
            rclass: RecordClass::IN,
            ttl: 300,
        }.serialise(&mut buf).unwrap();

        ResourceRecord {
            name: domain("mx.example.com."),
            rtype_with_data: RecordTypeWithData::CNAME {
                cname: domain("www.example.com."),
            },
            rclass: RecordClass::IN,
            ttl: 300,
        }.serialise(&mut buf).unwrap();

        assert_eq!(
            vec![
                // filler
                1, 2, 3, 4,
                // question: name
                3, 119, 119, 119,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
                // question: type
                0, 255,
                // question: class
                0, 255,
                // first record: name (pointer to offset 4)
                0b1100_0000, 0b0000_0100,
                // first record: type
                0b0000_0000, 0b0000_1111,
                // first record: class
                0b0000_0000, 0b0000_0001,
                // first record: TTL
                0b0000_0000, 0b0000_0000, 0b0000_0001, 0b0010_1100,
                // first record: data length (18)
                0b0000_0000, 0b0001_0010,
                // first record: preference
                0, 32,
                // first record: exchange, in full (starts at offset 39)
                2, 109, 120,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
                // second record: name (pointer to offset 39)
                0b1100_0000, 0b0010_0111,
                // second record: type
                0b0000_0000, 0b0000_0101,
                // second record: class
                0b0000_0000, 0b0000_0001,
                // second record: TTL
                0b0000_0000, 0b0000_0000, 0b0000_0001, 0b0010_1100,
                // second record: data length (17)
                0b0000_0000, 0b0001_0001,
                // second record: cname, in full
                3, 119, 119, 119,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
            ],
            buf.octets,
        );
    }

    /// The data length placeholder is overwritten with the real
    /// length of the record data.
    #[test]
    #[rustfmt::skip]
    fn test_sets_rdlength() {
        let mut buf = prefixed_buffer();

        let rr = ResourceRecord {
            name: domain("www.example.com."),
            rtype_with_data: RecordTypeWithData::MX {
                preference: 32,
                exchange: domain("mx.example.com."),
            },
            rclass: RecordClass::IN,
            ttl: 300,
        };
        rr.serialise(&mut buf).unwrap();

        assert_eq!(
            vec![
                // filler
                1, 2, 3, 4,
                // name
                3, 119, 119, 119,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
                // type
                0b0000_0000, 0b0000_1111,
                // class
                0b0000_0000, 0b0000_0001,
                // TTL
                0b0000_0000, 0b0000_0000, 0b0000_0001, 0b0010_1100,
                // data length (18)
                0b0000_0000, 0b0001_0010,
                // preference
                0, 32,
                // exchange
                2, 109, 120,
                7, 101, 120, 97, 109, 112, 108, 101,
                3, 99, 111, 109,
                0,
            ],
            buf.octets,
        );
    }
}