1use bytes::Bytes;
6use std::net::{Ipv4Addr, Ipv6Addr};
7
8use crate::protocol::types::*;
9
10impl Message {
11 pub fn from_octets(octets: &[u8]) -> Result<Self, Error> {
15 Self::deserialise(&mut ConsumableBuffer::new(octets))
16 }
17
18 fn deserialise(buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
22 let header = Header::deserialise(buffer)?;
23 let qdcount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
24 let ancount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
25 let nscount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
26 let arcount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
27
28 let mut questions = Vec::with_capacity(qdcount.into());
29 let mut answers = Vec::with_capacity(ancount.into());
30 let mut authority = Vec::with_capacity(nscount.into());
31 let mut additional = Vec::with_capacity(arcount.into());
32
33 for _ in 0..qdcount {
34 questions.push(Question::deserialise(header.id, buffer)?);
35 }
36 for _ in 0..ancount {
37 answers.push(ResourceRecord::deserialise(header.id, buffer)?);
38 }
39 for _ in 0..nscount {
40 authority.push(ResourceRecord::deserialise(header.id, buffer)?);
41 }
42 for _ in 0..arcount {
43 additional.push(ResourceRecord::deserialise(header.id, buffer)?);
44 }
45
46 Ok(Self {
47 header,
48 questions,
49 answers,
50 authority,
51 additional,
52 })
53 }
54}
55
56impl Header {
57 fn deserialise(buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
61 let id = buffer.next_u16().ok_or(Error::CompletelyBusted)?;
62 let flags1 = buffer.next_u8().ok_or(Error::HeaderTooShort(id))?;
63 let flags2 = buffer.next_u8().ok_or(Error::HeaderTooShort(id))?;
64
65 Ok(Self {
66 id,
67 is_response: flags1 & HEADER_MASK_QR != 0,
68 opcode: Opcode::from((flags1 & HEADER_MASK_OPCODE) >> HEADER_OFFSET_OPCODE),
69 is_authoritative: flags1 & HEADER_MASK_AA != 0,
70 is_truncated: flags1 & HEADER_MASK_TC != 0,
71 recursion_desired: flags1 & HEADER_MASK_RD != 0,
72 recursion_available: flags2 & HEADER_MASK_RA != 0,
73 rcode: Rcode::from((flags2 & HEADER_MASK_RCODE) >> HEADER_OFFSET_RCODE),
74 })
75 }
76}
77
78impl Question {
79 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
83 let name = DomainName::deserialise(id, buffer)?;
84 let qtype = QueryType::deserialise(id, buffer)?;
85 let qclass = QueryClass::deserialise(id, buffer)?;
86
87 Ok(Self {
88 name,
89 qtype,
90 qclass,
91 })
92 }
93}
94
95impl ResourceRecord {
96 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
100 let name = DomainName::deserialise(id, buffer)?;
101 let rtype = RecordType::deserialise(id, buffer)?;
102 let rclass = RecordClass::deserialise(id, buffer)?;
103 let ttl = buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?;
104 let rdlength = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
105
106 let rdata_start = buffer.position;
107
108 let mut raw_rdata = || {
109 if let Some(octets) = buffer.take(rdlength as usize) {
110 Ok(Bytes::copy_from_slice(octets))
111 } else {
112 Err(Error::ResourceRecordTooShort(id))
113 }
114 };
115
116 let rtype_with_data = match rtype {
119 RecordType::A => RecordTypeWithData::A {
120 address: Ipv4Addr::from(
121 buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
122 ),
123 },
124 RecordType::NS => RecordTypeWithData::NS {
125 nsdname: DomainName::deserialise(id, buffer)?,
126 },
127 RecordType::MD => RecordTypeWithData::MD {
128 madname: DomainName::deserialise(id, buffer)?,
129 },
130 RecordType::MF => RecordTypeWithData::MF {
131 madname: DomainName::deserialise(id, buffer)?,
132 },
133 RecordType::CNAME => RecordTypeWithData::CNAME {
134 cname: DomainName::deserialise(id, buffer)?,
135 },
136 RecordType::SOA => RecordTypeWithData::SOA {
137 mname: DomainName::deserialise(id, buffer)?,
138 rname: DomainName::deserialise(id, buffer)?,
139 serial: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
140 refresh: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
141 retry: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
142 expire: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
143 minimum: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
144 },
145 RecordType::MB => RecordTypeWithData::MB {
146 madname: DomainName::deserialise(id, buffer)?,
147 },
148 RecordType::MG => RecordTypeWithData::MG {
149 mdmname: DomainName::deserialise(id, buffer)?,
150 },
151 RecordType::MR => RecordTypeWithData::MR {
152 newname: DomainName::deserialise(id, buffer)?,
153 },
154 RecordType::NULL => RecordTypeWithData::NULL {
155 octets: raw_rdata()?,
156 },
157 RecordType::WKS => RecordTypeWithData::WKS {
158 octets: raw_rdata()?,
159 },
160 RecordType::PTR => RecordTypeWithData::PTR {
161 ptrdname: DomainName::deserialise(id, buffer)?,
162 },
163 RecordType::HINFO => RecordTypeWithData::HINFO {
164 octets: raw_rdata()?,
165 },
166 RecordType::MINFO => RecordTypeWithData::MINFO {
167 rmailbx: DomainName::deserialise(id, buffer)?,
168 emailbx: DomainName::deserialise(id, buffer)?,
169 },
170 RecordType::MX => RecordTypeWithData::MX {
171 preference: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
172 exchange: DomainName::deserialise(id, buffer)?,
173 },
174 RecordType::TXT => RecordTypeWithData::TXT {
175 octets: raw_rdata()?,
176 },
177 RecordType::AAAA => RecordTypeWithData::AAAA {
178 address: Ipv6Addr::new(
179 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
180 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
181 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
182 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
183 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
184 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
185 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
186 buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
187 ),
188 },
189 RecordType::SRV => RecordTypeWithData::SRV {
190 priority: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
191 weight: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
192 port: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
193 target: DomainName::deserialise(id, buffer)?,
194 },
195 RecordType::Unknown(tag) => RecordTypeWithData::Unknown {
196 tag,
197 octets: raw_rdata()?,
198 },
199 };
200
201 let rdata_stop = buffer.position;
202
203 if rdata_stop == rdata_start + (rdlength as usize) {
204 Ok(Self {
205 name,
206 rtype_with_data,
207 rclass,
208 ttl,
209 })
210 } else {
211 Err(Error::ResourceRecordInvalid(id))
212 }
213 }
214}
215
216impl DomainName {
217 #[allow(clippy::missing_panics_doc)]
221 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
222 let mut len = 0;
223 let mut labels = Vec::<Label>::with_capacity(5);
224 let start = buffer.position;
225
226 'outer: loop {
227 let size = buffer.next_u8().ok_or(Error::DomainTooShort(id))?;
228
229 if usize::from(size) <= LABEL_MAX_LEN {
230 len += 1;
231
232 if size == 0 {
233 labels.push(Label::new());
234 break 'outer;
235 }
236
237 if let Some(os) = buffer.take(size as usize) {
238 let label = Label::try_from(os).unwrap();
240 len += label.len() as usize;
241 labels.push(label);
242 } else {
243 return Err(Error::DomainTooShort(id));
244 }
245
246 if len > DOMAINNAME_MAX_LEN {
247 break 'outer;
248 }
249 } else if size >= 192 {
250 let hi = size & 0b0011_1111;
253 let lo = buffer.next_u8().ok_or(Error::DomainTooShort(id))?;
254 let ptr = u16::from_be_bytes([hi, lo]).into();
255
256 if ptr >= start {
260 return Err(Error::DomainPointerInvalid(id));
261 }
262
263 let mut other = DomainName::deserialise(id, &mut buffer.at_offset(ptr))?;
264 len += other.len;
265 labels.append(&mut other.labels);
266 break 'outer;
267 } else {
268 return Err(Error::DomainLabelInvalid(id));
269 }
270 }
271
272 if len <= DOMAINNAME_MAX_LEN {
273 Ok(DomainName { labels, len })
274 } else {
275 Err(Error::DomainTooLong(id))
276 }
277 }
278}
279
280impl QueryType {
281 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
285 let value = buffer.next_u16().ok_or(Error::QuestionTooShort(id))?;
286 Ok(Self::from(value))
287 }
288}
289
290impl QueryClass {
291 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
295 let value = buffer.next_u16().ok_or(Error::QuestionTooShort(id))?;
296 Ok(Self::from(value))
297 }
298}
299
300impl RecordType {
301 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
305 let value = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
306 Ok(Self::from(value))
307 }
308}
309
310impl RecordClass {
311 fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
315 let value = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
316 Ok(Self::from(value))
317 }
318}
319
/// Reasons a message can fail to deserialise.
///
/// Every variant except `CompletelyBusted` carries the 16-bit message ID read
/// from the header, which `Error::id` returns.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Error {
    /// The message ended before even the 16-bit ID could be read.
    CompletelyBusted,

    /// The header ended early: a flags octet or a section count is missing.
    HeaderTooShort(u16),

    /// A question's QTYPE or QCLASS field is missing.
    QuestionTooShort(u16),

    /// A resource record's fixed fields or RDATA run past the end of the message.
    ResourceRecordTooShort(u16),

    /// Decoding a record's RDATA consumed a different number of octets than
    /// its RDLENGTH field declares.
    ResourceRecordInvalid(u16),

    /// A domain name runs past the end of the message.
    DomainTooShort(u16),

    /// A domain name's encoded length exceeds `DOMAINNAME_MAX_LEN`.
    DomainTooLong(u16),

    /// A compression pointer does not point strictly before the start of the
    /// name being read.
    DomainPointerInvalid(u16),

    /// A domain name contains an invalid label, such as a length octet that is
    /// neither a valid label length nor a compression pointer.
    DomainLabelInvalid(u16),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let description = match self {
            Error::CompletelyBusted | Error::HeaderTooShort(_) => "header too short",
            Error::QuestionTooShort(_) => "question too short",
            Error::ResourceRecordTooShort(_) => "resource record too short",
            Error::ResourceRecordInvalid(_) => {
                "resource record RDLENGTH field does not match parsed RDATA length"
            }
            Error::DomainTooShort(_) => "domain name too short",
            Error::DomainTooLong(_) => "domain name too long",
            Error::DomainPointerInvalid(_) => "domain name compression pointer invalid",
            Error::DomainLabelInvalid(_) => "domain label invalid",
        };
        f.write_str(description)
    }
}

/// These errors wrap no underlying cause, so the default `source()` (which
/// returns `None`) is kept.
impl std::error::Error for Error {}

impl Error {
    /// The ID of the message that failed to parse.
    ///
    /// Returns `None` only for `CompletelyBusted`, where no ID could be read.
    pub fn id(self) -> Option<u16> {
        match self {
            Error::CompletelyBusted => None,
            Error::HeaderTooShort(id)
            | Error::QuestionTooShort(id)
            | Error::ResourceRecordTooShort(id)
            | Error::ResourceRecordInvalid(id)
            | Error::DomainTooShort(id)
            | Error::DomainTooLong(id)
            | Error::DomainPointerInvalid(id)
            | Error::DomainLabelInvalid(id) => Some(id),
        }
    }
}
395
/// A read cursor over a message's octets.
///
/// Every read advances `position` only when it succeeds; a failed read leaves
/// the cursor where it was. `octets` always refers to the whole message, so
/// `at_offset` can start a new cursor anywhere within it.
struct ConsumableBuffer<'a> {
    octets: &'a [u8],
    position: usize,
}

impl<'a> ConsumableBuffer<'a> {
    /// Create a cursor positioned at the first octet.
    fn new(octets: &'a [u8]) -> Self {
        ConsumableBuffer {
            octets,
            position: 0,
        }
    }

    /// Read one octet.
    fn next_u8(&mut self) -> Option<u8> {
        let octet = *self.octets.get(self.position)?;
        self.position += 1;
        Some(octet)
    }

    /// Read a big-endian (network order) 16-bit integer.
    fn next_u16(&mut self) -> Option<u16> {
        self.take(2).map(|b| u16::from_be_bytes([b[0], b[1]]))
    }

    /// Read a big-endian (network order) 32-bit integer.
    fn next_u32(&mut self) -> Option<u32> {
        self.take(4)
            .map(|b| u32::from_be_bytes([b[0], b[1], b[2], b[3]]))
    }

    /// Borrow the next `size` octets, or return `None` if fewer remain.
    fn take(&mut self, size: usize) -> Option<&'a [u8]> {
        let end = self.position + size;
        let slice = self.octets.get(self.position..end)?;
        self.position = end;
        Some(slice)
    }

    /// A fresh cursor over the same octets, starting at `position`.
    ///
    /// This cursor is not moved.
    fn at_offset(&self, position: usize) -> ConsumableBuffer<'a> {
        ConsumableBuffer {
            octets: self.octets,
            position,
        }
    }
}
460}