dns_resolver/util/
net.rs

1use bytes::BytesMut;
2use std::io;
3use std::net::SocketAddr;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::{TcpStream, UdpSocket};
6
7/// Read a DNS message from a TCP stream.
8///
9/// A DNS TCP message is slightly different to a DNS UDP message: it
10/// has a big-endian u16 prefix giving the total length of the
11/// message.  This is redundant (since the header is fixed-size and
12/// says how many fields there are, and the fields contain length
13/// information), but it means the entire message can be read before
14/// parsing begins.
15///
16/// # Errors
17///
18/// If reading from the stream fails or returns an incomplete message.
19pub async fn read_tcp_bytes(stream: &mut TcpStream) -> Result<BytesMut, TcpError> {
20    match stream.read_u16().await {
21        Ok(size) => {
22            let expected = size as usize;
23            let mut bytes = BytesMut::with_capacity(expected);
24            while bytes.len() < expected {
25                match stream.read_buf(&mut bytes).await {
26                    Ok(0) if bytes.len() < expected => {
27                        let id = if bytes.len() >= 2 {
28                            Some(u16::from_be_bytes([bytes[0], bytes[1]]))
29                        } else {
30                            None
31                        };
32                        return Err(TcpError::TooShort {
33                            id,
34                            expected,
35                            actual: bytes.len(),
36                        });
37                    }
38                    Err(err) => {
39                        let id = if bytes.len() >= 2 {
40                            Some(u16::from_be_bytes([bytes[0], bytes[1]]))
41                        } else {
42                            None
43                        };
44                        return Err(TcpError::IO { id, error: err });
45                    }
46                    _ => (),
47                }
48            }
49            Ok(bytes)
50        }
51        Err(err) => Err(TcpError::IO {
52            id: None,
53            error: err,
54        }),
55    }
56}
57
/// An error that can occur when reading a DNS TCP message.
///
/// Both variants carry an optional `id`: the first two bytes of the
/// message as a big-endian u16, when at least that much had been
/// received before the failure.
#[derive(Debug)]
pub enum TcpError {
    /// The stream ended before the number of bytes announced by the
    /// length prefix had been received.
    TooShort {
        id: Option<u16>,
        /// Message length announced by the length prefix.
        expected: usize,
        /// Number of message bytes actually received.
        actual: usize,
    },
    /// The underlying stream reported an I/O error.
    IO {
        id: Option<u16>,
        error: io::Error,
    },
}

impl std::fmt::Display for TcpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TcpError::TooShort {
                expected, actual, ..
            } => write!(
                f,
                "incomplete TCP message: expected {} bytes, got {}",
                expected, actual
            ),
            // The underlying error is exposed through `source()`, so it
            // is not repeated in this message.
            TcpError::IO { .. } => write!(f, "I/O error while reading TCP message"),
        }
    }
}

impl std::error::Error for TcpError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TcpError::TooShort { .. } => None,
            TcpError::IO { error, .. } => Some(error),
        }
    }
}
71
72/// Write a serialised message to a UDP channel.  This sets or clears
73/// the TC flag as appropriate.
74///
75/// # Errors
76///
77/// If sending the message fails.
78///
79/// # Panics
80///
81/// If given an incomplete (< 12 byte) message.
82pub async fn send_udp_bytes(sock: &UdpSocket, bytes: &mut [u8]) -> Result<(), io::Error> {
83    if bytes.len() < 12 {
84        tracing::error!(length = %bytes.len(), "message too short");
85        panic!("expected complete message");
86    }
87
88    if bytes.len() > 512 {
89        bytes[2] |= 0b0000_0010;
90        sock.send(&bytes[..512]).await?;
91    } else {
92        bytes[2] &= 0b1111_1101;
93        sock.send(bytes).await?;
94    }
95
96    Ok(())
97}
98
99/// Like `send_udp_bytes` but sends to the given address
100///
101/// # Errors
102///
103/// If sending the message fails.
104///
105/// # Panics
106///
107/// If given an incomplete (< 12 byte) message.
108pub async fn send_udp_bytes_to(
109    sock: &UdpSocket,
110    target: SocketAddr,
111    bytes: &mut [u8],
112) -> Result<(), io::Error> {
113    // TODO: see if this can be combined with `send_udp_bytes`
114
115    if bytes.len() < 12 {
116        tracing::error!(length = %bytes.len(), "message too short");
117        panic!("expected complete message");
118    }
119
120    if bytes.len() > 512 {
121        bytes[2] |= 0b0000_0010;
122        sock.send_to(&bytes[..512], target).await?;
123    } else {
124        bytes[2] &= 0b1111_1101;
125        sock.send_to(bytes, target).await?;
126    }
127
128    Ok(())
129}
130
131/// Write a serialised message to a TCP channel.  This sends a
132/// two-byte length prefix (big-endian u16) and sets or clears the TC
133/// flag as appropriate.
134///
135/// # Errors
136///
137/// If sending the message fails.
138///
139/// # Panics
140///
141/// If given an incomplete (< 12 byte) message.
142pub async fn send_tcp_bytes(stream: &mut TcpStream, bytes: &mut [u8]) -> Result<(), io::Error> {
143    if bytes.len() < 12 {
144        tracing::error!(length = %bytes.len(), "message too short");
145        panic!("expected complete message");
146    }
147
148    let len = if let Ok(len) = bytes.len().try_into() {
149        bytes[2] &= 0b1111_1101;
150        len
151    } else {
152        bytes[2] |= 0b0000_0010;
153        u16::MAX
154    };
155
156    stream.write_all(&len.to_be_bytes()).await?;
157    stream.write_all(&bytes[..(len as usize)]).await?;
158
159    Ok(())
160}