dns_resolver/util/
net.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
use bytes::BytesMut;
use std::io;
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};

/// Read a DNS message from a TCP stream.
///
/// A DNS TCP message is slightly different to a DNS UDP message: it
/// has a big-endian u16 prefix giving the total length of the
/// message.  This is redundant (since the header is fixed-size and
/// says how many fields there are, and the fields contain length
/// information), but it means the entire message can be read before
/// parsing begins.
///
/// # Errors
///
/// If reading from the stream fails or returns an incomplete message.
pub async fn read_tcp_bytes(stream: &mut TcpStream) -> Result<BytesMut, TcpError> {
    match stream.read_u16().await {
        Ok(size) => {
            let expected = size as usize;
            let mut bytes = BytesMut::with_capacity(expected);
            while bytes.len() < expected {
                match stream.read_buf(&mut bytes).await {
                    Ok(0) if bytes.len() < expected => {
                        let id = if bytes.len() >= 2 {
                            Some(u16::from_be_bytes([bytes[0], bytes[1]]))
                        } else {
                            None
                        };
                        return Err(TcpError::TooShort {
                            id,
                            expected,
                            actual: bytes.len(),
                        });
                    }
                    Err(err) => {
                        let id = if bytes.len() >= 2 {
                            Some(u16::from_be_bytes([bytes[0], bytes[1]]))
                        } else {
                            None
                        };
                        return Err(TcpError::IO { id, error: err });
                    }
                    _ => (),
                }
            }
            Ok(bytes)
        }
        Err(err) => Err(TcpError::IO {
            id: None,
            error: err,
        }),
    }
}

/// An error that can occur when reading a DNS TCP message.
///
/// Both variants carry the message ID when enough of the message (at
/// least two bytes) had been read to recover it.
#[derive(Debug)]
pub enum TcpError {
    /// The stream ended before the advertised number of bytes arrived.
    TooShort {
        /// Message ID, if the first two bytes were received.
        id: Option<u16>,
        /// Length announced by the two-byte prefix.
        expected: usize,
        /// Number of bytes actually received.
        actual: usize,
    },
    /// The underlying stream reported an I/O error.
    IO {
        /// Message ID, if the first two bytes were received.
        id: Option<u16>,
        /// The error returned by the stream.
        error: io::Error,
    },
}

impl std::fmt::Display for TcpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = match self {
            TcpError::TooShort {
                id,
                expected,
                actual,
            } => {
                write!(
                    f,
                    "DNS TCP message too short: expected {} bytes, got {}",
                    expected, actual
                )?;
                id
            }
            // The wrapped error is exposed through `source()` rather
            // than repeated here, so error-chain printers do not show
            // it twice.
            TcpError::IO { id, .. } => {
                f.write_str("I/O error while reading DNS TCP message")?;
                id
            }
        };

        match id {
            Some(id) => write!(f, " (message id {})", id),
            None => Ok(()),
        }
    }
}

impl std::error::Error for TcpError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TcpError::TooShort { .. } => None,
            TcpError::IO { error, .. } => Some(error),
        }
    }
}

/// Write a serialised message to a UDP channel.  This sets or clears
/// the TC flag as appropriate, and sends at most the first 512 bytes
/// of the message.
///
/// # Errors
///
/// If sending the message fails.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
pub async fn send_udp_bytes(sock: &UdpSocket, bytes: &mut [u8]) -> Result<(), io::Error> {
    sock.send(prepare_udp_payload(bytes)).await?;

    Ok(())
}

/// Like `send_udp_bytes` but sends to the given address
///
/// # Errors
///
/// If sending the message fails.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
pub async fn send_udp_bytes_to(
    sock: &UdpSocket,
    target: SocketAddr,
    bytes: &mut [u8],
) -> Result<(), io::Error> {
    sock.send_to(prepare_udp_payload(bytes), target).await?;

    Ok(())
}

/// Shared preparation step for the UDP senders: check the message is
/// at least a full header, set the TC flag and truncate if it exceeds
/// the UDP size limit (clearing the flag otherwise), and return the
/// slice that should actually be put on the wire.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
fn prepare_udp_payload(bytes: &mut [u8]) -> &[u8] {
    /// Smallest message accepted; the `# Panics` contract above.
    const HEADER_LEN: usize = 12;
    /// Largest number of bytes sent in one datagram.
    const MAX_UDP_LEN: usize = 512;
    /// Mask of the TC bit within the third header byte.
    const TC_FLAG: u8 = 0b0000_0010;

    if bytes.len() < HEADER_LEN {
        tracing::error!(length = %bytes.len(), "message too short");
        panic!("expected complete message");
    }

    if bytes.len() > MAX_UDP_LEN {
        bytes[2] |= TC_FLAG;
        &bytes[..MAX_UDP_LEN]
    } else {
        bytes[2] &= !TC_FLAG;
        bytes
    }
}

/// Send a serialised message over a TCP channel.
///
/// The message is preceded by its length as a big-endian u16.  If the
/// message is too long for that prefix to describe, the TC flag is set
/// and only the first `u16::MAX` bytes are sent; otherwise the TC flag
/// is cleared and the whole message is sent.
///
/// # Errors
///
/// If sending the message fails.
///
/// # Panics
///
/// If given an incomplete (< 12 byte) message.
pub async fn send_tcp_bytes(stream: &mut TcpStream, bytes: &mut [u8]) -> Result<(), io::Error> {
    /// Mask of the TC bit within the third header byte.
    const TC_FLAG: u8 = 0b0000_0010;

    if bytes.len() < 12 {
        tracing::error!(length = %bytes.len(), "message too short");
        panic!("expected complete message");
    }

    let prefix = match u16::try_from(bytes.len()) {
        Ok(exact) => {
            bytes[2] &= !TC_FLAG;
            exact
        }
        Err(_) => {
            bytes[2] |= TC_FLAG;
            u16::MAX
        }
    };
    let payload = &bytes[..usize::from(prefix)];

    stream.write_all(&prefix.to_be_bytes()).await?;
    stream.write_all(payload).await?;

    Ok(())
}