dns_resolver/util/net.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
use bytes::BytesMut;
use std::io;
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
/// Read a DNS message from a TCP stream.
///
/// A DNS TCP message is slightly different to a DNS UDP message: it
/// has a big-endian u16 prefix giving the total length of the
/// message. This is redundant (since the header is fixed-size and
/// says how many fields there are, and the fields contain length
/// information), but it means the entire message can be read before
/// parsing begins.
///
/// # Errors
///
/// If reading from the stream fails or returns an incomplete message.
pub async fn read_tcp_bytes(stream: &mut TcpStream) -> Result<BytesMut, TcpError> {
match stream.read_u16().await {
Ok(size) => {
let expected = size as usize;
let mut bytes = BytesMut::with_capacity(expected);
while bytes.len() < expected {
match stream.read_buf(&mut bytes).await {
Ok(0) if bytes.len() < expected => {
let id = if bytes.len() >= 2 {
Some(u16::from_be_bytes([bytes[0], bytes[1]]))
} else {
None
};
return Err(TcpError::TooShort {
id,
expected,
actual: bytes.len(),
});
}
Err(err) => {
let id = if bytes.len() >= 2 {
Some(u16::from_be_bytes([bytes[0], bytes[1]]))
} else {
None
};
return Err(TcpError::IO { id, error: err });
}
_ => (),
}
}
Ok(bytes)
}
Err(err) => Err(TcpError::IO {
id: None,
error: err,
}),
}
}
/// An error that can occur when reading a DNS TCP message.
///
/// In both variants, `id` is the DNS message ID when at least the first
/// two bytes of the message body were read before the failure, and
/// `None` otherwise.
#[derive(Debug)]
pub enum TcpError {
    /// The stream ended before the whole message was read: `expected`
    /// is the length from the two-byte prefix, `actual` is how many
    /// body bytes were received.
    TooShort {
        id: Option<u16>,
        expected: usize,
        actual: usize,
    },
    /// An I/O error occurred while reading the length prefix or the
    /// message body.
    IO {
        id: Option<u16>,
        error: io::Error,
    },
}

impl std::fmt::Display for TcpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TcpError::TooShort {
                expected, actual, ..
            } => write!(
                f,
                "TCP message too short: expected {} bytes, got {}",
                expected, actual
            ),
            TcpError::IO { error, .. } => write!(f, "I/O error reading TCP message: {}", error),
        }
    }
}

impl std::error::Error for TcpError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TcpError::TooShort { .. } => None,
            TcpError::IO { error, .. } => Some(error),
        }
    }
}
/// Write a serialised message to a UDP channel. This sets or clears
/// the TC flag as appropriate.
///
/// # Errors
///
/// If sending the message fails.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
pub async fn send_udp_bytes(sock: &UdpSocket, bytes: &mut [u8]) -> Result<(), io::Error> {
if bytes.len() < 12 {
tracing::error!(length = %bytes.len(), "message too short");
panic!("expected complete message");
}
if bytes.len() > 512 {
bytes[2] |= 0b0000_0010;
sock.send(&bytes[..512]).await?;
} else {
bytes[2] &= 0b1111_1101;
sock.send(bytes).await?;
}
Ok(())
}
/// Like `send_udp_bytes` but sends to the given address
///
/// # Errors
///
/// If sending the message fails.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
pub async fn send_udp_bytes_to(
sock: &UdpSocket,
target: SocketAddr,
bytes: &mut [u8],
) -> Result<(), io::Error> {
// TODO: see if this can be combined with `send_udp_bytes`
if bytes.len() < 12 {
tracing::error!(length = %bytes.len(), "message too short");
panic!("expected complete message");
}
if bytes.len() > 512 {
bytes[2] |= 0b0000_0010;
sock.send_to(&bytes[..512], target).await?;
} else {
bytes[2] &= 0b1111_1101;
sock.send_to(bytes, target).await?;
}
Ok(())
}
/// Write a serialised message to a TCP channel. This sends a
/// two-byte length prefix (big-endian u16) and sets or clears the TC
/// flag as appropriate.
///
/// Messages that fit in a u16 length are sent whole with the TC flag
/// cleared. Longer messages are cut to their first `u16::MAX` bytes
/// with the TC flag set.
///
/// The length prefix and the message are written with a single
/// `write_all` call, following the RFC 7766 §8 recommendation. This
/// makes it more likely they travel in one TCP segment.
///
/// # Errors
///
/// If sending the message fails.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
pub async fn send_tcp_bytes(stream: &mut TcpStream, bytes: &mut [u8]) -> Result<(), io::Error> {
    if bytes.len() < 12 {
        tracing::error!(length = %bytes.len(), "message too short");
        panic!("expected complete message");
    }

    // TC is bit 1 of the third header byte.
    let len = match u16::try_from(bytes.len()) {
        Ok(len) => {
            bytes[2] &= 0b1111_1101;
            len
        }
        Err(_) => {
            bytes[2] |= 0b0000_0010;
            u16::MAX
        }
    };

    let body = &bytes[..usize::from(len)];
    let mut framed = Vec::with_capacity(body.len() + 2);
    framed.extend_from_slice(&len.to_be_bytes());
    framed.extend_from_slice(body);
    stream.write_all(&framed).await?;
    Ok(())
}