1use std::collections::HashSet;
2
3use std::net::{IpAddr, Ipv4Addr};
4use std::str::FromStr;
5
6use crate::hosts::types::*;
7use crate::protocol::types::*;
8
9impl Hosts {
10 pub fn deserialise(data: &str) -> Result<Self, Error> {
16 let mut hosts = Self::new();
17 for line in data.lines() {
18 if let Some((address, new_names)) = parse_line(line)? {
19 for name in new_names {
20 match address {
21 IpAddr::V4(ip) => {
22 hosts.v4.insert(name, ip);
23 }
24 IpAddr::V6(ip) => {
25 hosts.v6.insert(name, ip);
26 }
27 }
28 }
29 }
30 }
31 Ok(hosts)
32 }
33}
34
35fn parse_line(line: &str) -> Result<Option<(IpAddr, HashSet<DomainName>)>, Error> {
41 let mut state = State::SkipToAddress;
42 let mut address = IpAddr::V4(Ipv4Addr::LOCALHOST);
43 let mut new_names = HashSet::new();
44
45 for (i, octet) in line.chars().enumerate() {
46 if !octet.is_ascii() {
47 return Err(Error::ExpectedAscii { octet });
48 }
49
50 state = match (&state, octet) {
51 (_, '#') => State::CommentToEndOfLine,
52 (State::CommentToEndOfLine, _) => break,
53
54 (State::SkipToAddress, c) if c.is_whitespace() => state,
55 (State::SkipToAddress, _) => State::ReadingAddress { start: i },
56
57 (State::ReadingAddress { .. }, '%') => break,
58 (State::ReadingAddress { start }, c) if c.is_whitespace() => {
59 let addr_str = &line[*start..i];
60 match IpAddr::from_str(addr_str) {
61 Ok(addr) => address = addr,
62 Err(_) => {
63 return Err(Error::CouldNotParseAddress {
64 address: addr_str.into(),
65 })
66 }
67 }
68 State::SkipToName
69 }
70 (State::ReadingAddress { .. }, _) => state,
71
72 (State::SkipToName, c) if c.is_whitespace() => state,
73 (State::SkipToName, _) => State::ReadingName { start: i },
74
75 (State::ReadingName { start }, c) if c.is_whitespace() => {
76 let name_str = &line[*start..i];
77 match DomainName::from_relative_dotted_string(&DomainName::root_domain(), name_str)
78 {
79 Some(name) => {
80 new_names.insert(name);
81 }
82 None => {
83 return Err(Error::CouldNotParseName {
84 name: name_str.into(),
85 })
86 }
87 }
88 State::SkipToName
89 }
90 (State::ReadingName { .. }, _) => state,
91 }
92 }
93
94 if let State::ReadingName { start } = state {
95 let name_str = &line[start..];
96 match DomainName::from_relative_dotted_string(&DomainName::root_domain(), name_str) {
97 Some(name) => {
98 new_names.insert(name);
99 }
100 None => {
101 return Err(Error::CouldNotParseName {
102 name: name_str.into(),
103 })
104 }
105 }
106 }
107
108 if new_names.is_empty() {
109 Ok(None)
110 } else {
111 Ok(Some((address, new_names)))
112 }
113}
114
/// Errors produced while parsing a hosts file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A non-ASCII character was encountered while parsing a line.
    ExpectedAscii { octet: char },
    /// The first field of a line could not be parsed as an IPv4 or IPv6
    /// address; `address` is the offending text.
    CouldNotParseAddress { address: String },
    /// A name on a line could not be parsed as a domain name; `name` is the
    /// offending text.
    CouldNotParseName { name: String },
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Debug` formatting of `char` and `String` already adds quotes (and
        // escapes control characters), so no extra quotes are written here.
        match self {
            Error::ExpectedAscii { octet } => write!(f, "expected ASCII octet, not {octet:?}"),
            Error::CouldNotParseAddress { address } => {
                write!(f, "could not parse address {address:?}")
            }
            Error::CouldNotParseName { name } => {
                write!(f, "could not parse domain name {name:?}")
            }
        }
    }
}

/// No variant wraps an underlying error, so the default `source` (`None`) is
/// used.
impl std::error::Error for Error {}
142
/// States of the character-by-character line parser in `parse_line`.
///
/// The `start` fields record the index in the line at which the field
/// currently being read began; `parse_line` slices the line from `start` up to
/// the terminating character.
enum State {
    /// Skipping whitespace before the address field.
    SkipToAddress,
    /// Inside the address field, which began at `start`.
    ReadingAddress { start: usize },
    /// Skipping whitespace after the address or between names.
    SkipToName,
    /// Inside a name, which began at `start`.
    ReadingName { start: usize },
    /// A `#` has been seen; the rest of the line is a comment.
    CommentToEndOfLine,
}
151
#[cfg(test)]
mod tests {
    use std::net::Ipv6Addr;

    use super::*;

    use crate::protocol::types::test_util::*;
    use crate::zones::types::*;

    #[test]
    fn parses_all() {
        // Every physical line ends in `\` so the string holds exactly the lines
        // shown, without the source indentation.
        let hosts_data = "# hark, a comment!\n\
                          1.2.3.4 one two three four\n\
                          0.0.0.0 blocked\n\
                          \n\
                          127.0.0.1 localhost.\n\
                          ::1 localhost";

        let hosts = Hosts::deserialise(hosts_data).unwrap();

        let expected_a_records = &[
            ("one.", Ipv4Addr::new(1, 2, 3, 4)),
            ("two.", Ipv4Addr::new(1, 2, 3, 4)),
            ("three.", Ipv4Addr::new(1, 2, 3, 4)),
            ("four.", Ipv4Addr::new(1, 2, 3, 4)),
            ("blocked.", Ipv4Addr::new(0, 0, 0, 0)),
            ("localhost.", Ipv4Addr::new(127, 0, 0, 1)),
        ];

        let expected_aaaa_records = &[("localhost.", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))];

        for (name, addr) in expected_a_records {
            let mut rr = a_record(name, *addr);
            rr.ttl = TTL;
            assert_eq!(
                Some(ZoneResult::Answer { rrs: vec![rr] }),
                Zone::from(hosts.clone()).resolve(&domain(name), QueryType::Record(RecordType::A))
            );
        }

        for (name, addr) in expected_aaaa_records {
            let mut rr = aaaa_record(name, *addr);
            rr.ttl = TTL;
            assert_eq!(
                Some(ZoneResult::Answer { rrs: vec![rr] }),
                Zone::from(hosts.clone())
                    .resolve(&domain(name), QueryType::Record(RecordType::AAAA))
            );
        }
    }

    #[test]
    fn parse_line_ignores_iface_address() {
        assert_eq!(Ok(None), parse_line("fe80::1%lo0 localhost"));
    }

    #[test]
    fn parse_line_ignores_blank_lines() {
        assert_eq!(Ok(None), parse_line(""));
        assert_eq!(Ok(None), parse_line(" \t "));
    }

    #[test]
    fn parse_line_ignores_comment_lines() {
        assert_eq!(Ok(None), parse_line("# just a comment"));
        assert_eq!(Ok(None), parse_line("   # indented comment"));
    }

    #[test]
    fn parse_line_ignores_trailing_comment() {
        assert_eq!(
            Ok(Some((
                IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
                [domain("foo.")].into_iter().collect()
            ))),
            parse_line("1.2.3.4 foo # bar baz")
        );
    }

    #[test]
    fn parse_line_parses_ipv4_with_names() {
        assert_eq!(
            Ok(Some((
                IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
                [domain("foo."), domain("bar.")].into_iter().collect()
            ))),
            parse_line("1.2.3.4 foo bar")
        );
    }

    #[test]
    fn parse_line_parses_ipv4_without_names() {
        assert_eq!(Ok(None), parse_line("1.2.3.4"));
    }

    #[test]
    fn parse_line_parses_ipv6_with_names() {
        assert_eq!(
            Ok(Some((
                IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 1, 2, 3)),
                [domain("foo."), domain("bar.")].into_iter().collect()
            ))),
            parse_line("::1:2:3 foo bar")
        );
    }

    #[test]
    fn parse_line_parses_ipv6_without_names() {
        assert_eq!(Ok(None), parse_line("::1"));
    }

    #[test]
    fn parse_line_rejects_non_ascii() {
        // `super::Error` is spelled out to avoid any clash with glob imports.
        assert_eq!(
            Err(super::Error::ExpectedAscii { octet: 'é' }),
            parse_line("1.2.3.4 café")
        );
    }

    #[test]
    fn parse_line_rejects_bad_address() {
        assert_eq!(
            Err(super::Error::CouldNotParseAddress {
                address: "1.2.3".into()
            }),
            parse_line("1.2.3 foo")
        );
    }
}