use bytes::Bytes;
use std::net::{Ipv4Addr, Ipv6Addr};
use crate::protocol::types::*;
impl Message {
pub fn from_octets(octets: &[u8]) -> Result<Self, Error> {
Self::deserialise(&mut ConsumableBuffer::new(octets))
}
fn deserialise(buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let header = Header::deserialise(buffer)?;
let qdcount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
let ancount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
let nscount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
let arcount = buffer.next_u16().ok_or(Error::HeaderTooShort(header.id))?;
let mut questions = Vec::with_capacity(qdcount.into());
let mut answers = Vec::with_capacity(ancount.into());
let mut authority = Vec::with_capacity(nscount.into());
let mut additional = Vec::with_capacity(arcount.into());
for _ in 0..qdcount {
questions.push(Question::deserialise(header.id, buffer)?);
}
for _ in 0..ancount {
answers.push(ResourceRecord::deserialise(header.id, buffer)?);
}
for _ in 0..nscount {
authority.push(ResourceRecord::deserialise(header.id, buffer)?);
}
for _ in 0..arcount {
additional.push(ResourceRecord::deserialise(header.id, buffer)?);
}
Ok(Self {
header,
questions,
answers,
authority,
additional,
})
}
}
impl Header {
fn deserialise(buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let id = buffer.next_u16().ok_or(Error::CompletelyBusted)?;
let flags1 = buffer.next_u8().ok_or(Error::HeaderTooShort(id))?;
let flags2 = buffer.next_u8().ok_or(Error::HeaderTooShort(id))?;
Ok(Self {
id,
is_response: flags1 & HEADER_MASK_QR != 0,
opcode: Opcode::from((flags1 & HEADER_MASK_OPCODE) >> HEADER_OFFSET_OPCODE),
is_authoritative: flags1 & HEADER_MASK_AA != 0,
is_truncated: flags1 & HEADER_MASK_TC != 0,
recursion_desired: flags1 & HEADER_MASK_RD != 0,
recursion_available: flags2 & HEADER_MASK_RA != 0,
rcode: Rcode::from((flags2 & HEADER_MASK_RCODE) >> HEADER_OFFSET_RCODE),
})
}
}
impl Question {
    /// Reads one question entry: a domain name, then the query type, then
    /// the query class. `id` is the message ID attached to any error.
    fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
        // Field initialisers run in the order written, which is wire order.
        Ok(Self {
            name: DomainName::deserialise(id, buffer)?,
            qtype: QueryType::deserialise(id, buffer)?,
            qclass: QueryClass::deserialise(id, buffer)?,
        })
    }
}
impl ResourceRecord {
fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let name = DomainName::deserialise(id, buffer)?;
let rtype = RecordType::deserialise(id, buffer)?;
let rclass = RecordClass::deserialise(id, buffer)?;
let ttl = buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?;
let rdlength = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
let rdata_start = buffer.position;
let mut raw_rdata = || {
if let Some(octets) = buffer.take(rdlength as usize) {
Ok(Bytes::copy_from_slice(octets))
} else {
Err(Error::ResourceRecordTooShort(id))
}
};
let rtype_with_data = match rtype {
RecordType::A => RecordTypeWithData::A {
address: Ipv4Addr::from(
buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
),
},
RecordType::NS => RecordTypeWithData::NS {
nsdname: DomainName::deserialise(id, buffer)?,
},
RecordType::MD => RecordTypeWithData::MD {
madname: DomainName::deserialise(id, buffer)?,
},
RecordType::MF => RecordTypeWithData::MF {
madname: DomainName::deserialise(id, buffer)?,
},
RecordType::CNAME => RecordTypeWithData::CNAME {
cname: DomainName::deserialise(id, buffer)?,
},
RecordType::SOA => RecordTypeWithData::SOA {
mname: DomainName::deserialise(id, buffer)?,
rname: DomainName::deserialise(id, buffer)?,
serial: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
refresh: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
retry: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
expire: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
minimum: buffer.next_u32().ok_or(Error::ResourceRecordTooShort(id))?,
},
RecordType::MB => RecordTypeWithData::MB {
madname: DomainName::deserialise(id, buffer)?,
},
RecordType::MG => RecordTypeWithData::MG {
mdmname: DomainName::deserialise(id, buffer)?,
},
RecordType::MR => RecordTypeWithData::MR {
newname: DomainName::deserialise(id, buffer)?,
},
RecordType::NULL => RecordTypeWithData::NULL {
octets: raw_rdata()?,
},
RecordType::WKS => RecordTypeWithData::WKS {
octets: raw_rdata()?,
},
RecordType::PTR => RecordTypeWithData::PTR {
ptrdname: DomainName::deserialise(id, buffer)?,
},
RecordType::HINFO => RecordTypeWithData::HINFO {
octets: raw_rdata()?,
},
RecordType::MINFO => RecordTypeWithData::MINFO {
rmailbx: DomainName::deserialise(id, buffer)?,
emailbx: DomainName::deserialise(id, buffer)?,
},
RecordType::MX => RecordTypeWithData::MX {
preference: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
exchange: DomainName::deserialise(id, buffer)?,
},
RecordType::TXT => RecordTypeWithData::TXT {
octets: raw_rdata()?,
},
RecordType::AAAA => RecordTypeWithData::AAAA {
address: Ipv6Addr::new(
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
),
},
RecordType::SRV => RecordTypeWithData::SRV {
priority: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
weight: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
port: buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?,
target: DomainName::deserialise(id, buffer)?,
},
RecordType::Unknown(tag) => RecordTypeWithData::Unknown {
tag,
octets: raw_rdata()?,
},
};
let rdata_stop = buffer.position;
if rdata_stop == rdata_start + (rdlength as usize) {
Ok(Self {
name,
rtype_with_data,
rclass,
ttl,
})
} else {
Err(Error::ResourceRecordInvalid(id))
}
}
}
impl DomainName {
#[allow(clippy::missing_panics_doc)]
fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let mut len = 0;
let mut labels = Vec::<Label>::with_capacity(5);
let start = buffer.position;
'outer: loop {
let size = buffer.next_u8().ok_or(Error::DomainTooShort(id))?;
if usize::from(size) <= LABEL_MAX_LEN {
len += 1;
if size == 0 {
labels.push(Label::new());
break 'outer;
}
if let Some(os) = buffer.take(size as usize) {
let label = Label::try_from(os).unwrap();
len += label.len() as usize;
labels.push(label);
} else {
return Err(Error::DomainTooShort(id));
}
if len > DOMAINNAME_MAX_LEN {
break 'outer;
}
} else if size >= 192 {
let hi = size & 0b0011_1111;
let lo = buffer.next_u8().ok_or(Error::DomainTooShort(id))?;
let ptr = u16::from_be_bytes([hi, lo]).into();
if ptr >= start {
return Err(Error::DomainPointerInvalid(id));
}
let mut other = DomainName::deserialise(id, &mut buffer.at_offset(ptr))?;
len += other.len;
labels.append(&mut other.labels);
break 'outer;
} else {
return Err(Error::DomainLabelInvalid(id));
}
}
if len <= DOMAINNAME_MAX_LEN {
Ok(DomainName { labels, len })
} else {
Err(Error::DomainTooLong(id))
}
}
}
impl QueryType {
fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let value = buffer.next_u16().ok_or(Error::QuestionTooShort(id))?;
Ok(Self::from(value))
}
}
impl QueryClass {
fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let value = buffer.next_u16().ok_or(Error::QuestionTooShort(id))?;
Ok(Self::from(value))
}
}
impl RecordType {
fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let value = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
Ok(Self::from(value))
}
}
impl RecordClass {
fn deserialise(id: u16, buffer: &mut ConsumableBuffer) -> Result<Self, Error> {
let value = buffer.next_u16().ok_or(Error::ResourceRecordTooShort(id))?;
Ok(Self::from(value))
}
}
/// Reasons a message can fail to deserialise.
///
/// Every variant except `CompletelyBusted` carries the ID of the message
/// being parsed.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Error {
    /// The input ended before the message ID could be read.
    CompletelyBusted,
    /// The header ended after the ID but before it was complete.
    HeaderTooShort(u16),
    /// A question ended before its type or class could be read.
    QuestionTooShort(u16),
    /// A resource record ended before all of its fields could be read.
    ResourceRecordTooShort(u16),
    /// The octets consumed as RDATA did not match the RDLENGTH field.
    ResourceRecordInvalid(u16),
    /// The input ended in the middle of a domain name.
    DomainTooShort(u16),
    /// A domain name exceeded the maximum permitted length.
    DomainTooLong(u16),
    /// A compression pointer in a domain name was not acceptable.
    DomainPointerInvalid(u16),
    /// A label in a domain name was not acceptable.
    DomainLabelInvalid(u16),
}

impl std::fmt::Display for Error {
    /// Writes a fixed, human-readable description; the message ID is not
    /// included in the text.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let description = match self {
            Error::CompletelyBusted | Error::HeaderTooShort(_) => "header too short",
            Error::QuestionTooShort(_) => "question too short",
            Error::ResourceRecordTooShort(_) => "resource record too short",
            Error::ResourceRecordInvalid(_) => {
                "resource record RDLENGTH field does not match parsed RDATA length"
            }
            Error::DomainTooShort(_) => "domain name too short",
            Error::DomainTooLong(_) => "domain name too long",
            Error::DomainPointerInvalid(_) => "domain name compression pointer invalid",
            Error::DomainLabelInvalid(_) => "domain label invalid",
        };
        f.write_str(description)
    }
}

// No variant wraps another error, so the trait's default `source` (which
// returns `None`) is exactly what is wanted.
impl std::error::Error for Error {}

impl Error {
    /// Returns the ID of the message that failed to parse, or `None` when
    /// the input was too short to contain one.
    pub fn id(self) -> Option<u16> {
        match self {
            Error::CompletelyBusted => None,
            Error::HeaderTooShort(id)
            | Error::QuestionTooShort(id)
            | Error::ResourceRecordTooShort(id)
            | Error::ResourceRecordInvalid(id)
            | Error::DomainTooShort(id)
            | Error::DomainTooLong(id)
            | Error::DomainPointerInvalid(id)
            | Error::DomainLabelInvalid(id) => Some(id),
        }
    }
}
/// A forward-reading cursor over a borrowed octet slice.
///
/// Every read either succeeds and advances `position` past what it read, or
/// returns `None` and leaves `position` untouched.
struct ConsumableBuffer<'a> {
    /// The complete input. It is never narrowed, so offsets measured from
    /// its start remain usable with `at_offset`.
    octets: &'a [u8],
    /// Index into `octets` of the next octet to be read.
    position: usize,
}

impl<'a> ConsumableBuffer<'a> {
    /// Creates a cursor positioned at the first octet.
    fn new(octets: &'a [u8]) -> Self {
        ConsumableBuffer {
            octets,
            position: 0,
        }
    }

    /// Reads one octet.
    fn next_u8(&mut self) -> Option<u8> {
        let octet = *self.octets.get(self.position)?;
        self.position += 1;
        Some(octet)
    }

    /// Reads two octets as a big-endian `u16`.
    fn next_u16(&mut self) -> Option<u16> {
        let end = self.position + 2;
        let raw = self.octets.get(self.position..end)?;
        self.position = end;
        Some(u16::from_be_bytes([raw[0], raw[1]]))
    }

    /// Reads four octets as a big-endian `u32`.
    fn next_u32(&mut self) -> Option<u32> {
        let end = self.position + 4;
        let raw = self.octets.get(self.position..end)?;
        self.position = end;
        Some(u32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]))
    }

    /// Returns the next `size` octets as a slice borrowed from the input.
    fn take(&mut self, size: usize) -> Option<&'a [u8]> {
        let end = self.position + size;
        let taken = self.octets.get(self.position..end)?;
        self.position = end;
        Some(taken)
    }

    /// Returns an independent cursor over the same input, placed at
    /// `position`. This cursor is not moved.
    fn at_offset(&self, position: usize) -> ConsumableBuffer<'a> {
        ConsumableBuffer {
            octets: self.octets,
            position,
        }
    }
}