dns_resolver/
cache.rs

1use priority_queue::PriorityQueue;
2use std::cmp::Eq;
3use std::cmp::Reverse;
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::marker::Copy;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10use dns_types::protocol::types::*;
11
/// A convenience wrapper around a `Cache` which lets it be shared
/// between threads.
///
/// Invoking `clone` on a `SharedCache` gives a new instance which
/// refers to the same underlying `Cache` object.
#[derive(Debug, Clone)]
pub struct SharedCache {
    // The shared state: every clone holds an `Arc` to the same mutex-guarded
    // cache, so all clones observe each other's inserts and prunes.
    cache: Arc<Mutex<Cache>>,
}
21
/// Panic message used when locking the cache mutex fails.  A poisoned mutex
/// means another thread panicked while mutating the cache, so its internal
/// invariants can no longer be trusted and aborting is the only safe option.
const MUTEX_POISON_MESSAGE: &str =
    "[INTERNAL ERROR] cache mutex poisoned, cannot recover from this - aborting";
24
25impl SharedCache {
26    /// Make a new, empty, shared cache.
27    pub fn new() -> Self {
28        SharedCache {
29            cache: Arc::new(Mutex::new(Cache::new())),
30        }
31    }
32
33    /// Create a new cache with the given desired size.
34    pub fn with_desired_size(desired_size: usize) -> Self {
35        SharedCache {
36            cache: Arc::new(Mutex::new(Cache::with_desired_size(desired_size))),
37        }
38    }
39
40    /// Get an entry from the cache.
41    ///
42    /// The TTL in the returned `ResourceRecord` is relative to the
43    /// current time - not when the record was inserted into the
44    /// cache.
45    ///
46    /// # Panics
47    ///
48    /// If the mutex has been poisoned.
49    pub fn get(&self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
50        self.cache
51            .lock()
52            .expect(MUTEX_POISON_MESSAGE)
53            .get(name, qtype)
54    }
55
56    /// Like `get`, but may return expired entries.
57    ///
58    /// Consumers MUST check that the TTL of a record is nonzero
59    /// before using it!
60    ///
61    /// # Panics
62    ///
63    /// If the mutex has been poisoned.
64    pub fn get_without_checking_expiration(
65        &self,
66        name: &DomainName,
67        qtype: QueryType,
68    ) -> Vec<ResourceRecord> {
69        self.cache
70            .lock()
71            .expect(MUTEX_POISON_MESSAGE)
72            .get_without_checking_expiration(name, qtype)
73    }
74
75    /// Insert an entry into the cache.
76    ///
77    /// It is not inserted if its TTL is zero or negative.
78    ///
79    /// This may make the cache grow beyond the desired size.
80    ///
81    /// # Panics
82    ///
83    /// If the mutex has been poisoned.
84    pub fn insert(&self, record: &ResourceRecord) {
85        if record.ttl > 0 {
86            let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
87            cache.insert(record);
88        }
89    }
90
91    /// Insert multiple entries into the cache.
92    ///
93    /// This is more efficient than calling `insert` multiple times, as it locks
94    /// the cache just once.
95    ///
96    /// Records with a TTL of zero or negative are skipped.
97    ///
98    /// This may make the cache grow beyond the desired size.
99    ///
100    /// # Panics
101    ///
102    /// If the mutex has been poisoned.
103    pub fn insert_all(&self, records: &[ResourceRecord]) {
104        let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
105        for record in records {
106            if record.ttl > 0 {
107                cache.insert(record);
108            }
109        }
110    }
111
112    /// Atomically clears expired entries and, if the cache has grown
113    /// beyond its desired size, prunes entries to get down to size.
114    ///
115    /// Returns `(has overflowed?, current size, num expired, num pruned)`.
116    ///
117    /// # Panics
118    ///
119    /// If the mutex has been poisoned.
120    pub fn prune(&self) -> (bool, usize, usize, usize) {
121        self.cache.lock().expect(MUTEX_POISON_MESSAGE).prune()
122    }
123}
124
125impl Default for SharedCache {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
/// Caching for `ResourceRecord`s.
///
/// You probably want to use `SharedCache` instead.
#[derive(Debug, Clone)]
pub struct Cache {
    // Records are partitioned by domain name, then sub-divided by record
    // type; the stored value is the record type plus its data.
    inner: PartitionedCache<DomainName, RecordType, RecordTypeWithData>,
}
138
139impl Default for Cache {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl Cache {
146    /// Create a new cache with a default desired size.
147    pub fn new() -> Self {
148        Self {
149            inner: PartitionedCache::new(),
150        }
151    }
152
153    /// Create a new cache with the given desired size.
154    ///
155    /// The `prune` method will remove expired entries, and also enough entries
156    /// (in least-recently-used order) to get down to this size.
157    pub fn with_desired_size(desired_size: usize) -> Self {
158        Self {
159            inner: PartitionedCache::with_desired_size(desired_size),
160        }
161    }
162
163    /// Get RRs from the cache.
164    ///
165    /// The TTL in the returned `ResourceRecord` is relative to the
166    /// current time - not when the record was inserted into the
167    /// cache.
168    pub fn get(&mut self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
169        let mut rrs = self.get_without_checking_expiration(name, qtype);
170        rrs.retain(|rr| rr.ttl > 0);
171        rrs
172    }
173
174    /// Like `get`, but may return expired RRs.
175    ///
176    /// Consumers MUST check that the TTL of a record is nonzero before using
177    /// it!
178    pub fn get_without_checking_expiration(
179        &mut self,
180        name: &DomainName,
181        qtype: QueryType,
182    ) -> Vec<ResourceRecord> {
183        let now = Instant::now();
184        let mut rrs = Vec::new();
185        match qtype {
186            QueryType::Wildcard => {
187                if let Some(records) = self.inner.get_partition_without_checking_expiration(name) {
188                    for tuples in records.values() {
189                        to_rrs(name, now, tuples, &mut rrs);
190                    }
191                }
192            }
193            QueryType::Record(rtype) => {
194                if let Some(tuples) = self.inner.get_without_checking_expiration(name, &rtype) {
195                    to_rrs(name, now, tuples, &mut rrs);
196                }
197            }
198            _ => (),
199        }
200
201        rrs
202    }
203
204    /// Insert an RR into the cache.
205    pub fn insert(&mut self, record: &ResourceRecord) {
206        self.inner.upsert(
207            record.name.clone(),
208            record.rtype_with_data.rtype(),
209            record.rtype_with_data.clone(),
210            Duration::from_secs(record.ttl.into()),
211        );
212    }
213
214    /// Clear expired RRs and, if the cache has grown beyond its desired size,
215    /// prunes domains to get down to size.
216    ///
217    /// Returns `(has overflowed?, current size, num expired, num pruned)`.
218    pub fn prune(&mut self) -> (bool, usize, usize, usize) {
219        self.inner.prune()
220    }
221}
222
223/// Helper for `get_without_checking_expiration`: converts the cached
224/// record tuples into RRs.
225fn to_rrs(
226    name: &DomainName,
227    now: Instant,
228    tuples: &[(RecordTypeWithData, Instant)],
229    rrs: &mut Vec<ResourceRecord>,
230) {
231    for (rtype, expires) in tuples {
232        rrs.push(ResourceRecord {
233            name: name.clone(),
234            rtype_with_data: rtype.clone(),
235            rclass: RecordClass::IN,
236            ttl: expires
237                .saturating_duration_since(now)
238                .as_secs()
239                .try_into()
240                .unwrap_or(u32::MAX),
241        });
242    }
243}
244
/// A generic cache of values with per-value expiry times, partitioned by a
/// primary key (`K1`) and sub-divided by a secondary key (`K2`).
///
/// Partitions are tracked in two priority queues - one by last access, one
/// by soonest expiry - so that `prune` can evict expired values first and
/// then whole least-recently-used partitions.
#[derive(Debug, Clone)]
pub struct PartitionedCache<K1: Eq + Hash, K2: Eq + Hash, V> {
    /// Cached entries, indexed by partition key.
    partitions: HashMap<K1, Partition<K2, V>>,

    /// Priority queue of partition keys ordered by access times.
    ///
    /// When the cache is full and there are no expired records to prune,
    /// partitions will instead be pruned in LRU order.
    ///
    /// INVARIANT: the keys in here are exactly the keys in `partitions`.
    access_priority: PriorityQueue<K1, Reverse<Instant>>,

    /// Priority queue of partition keys ordered by expiry time.
    ///
    /// When the cache is pruned, expired records are removed first.
    ///
    /// INVARIANT: the keys in here are exactly the keys in `partitions`.
    expiry_priority: PriorityQueue<K1, Reverse<Instant>>,

    /// The number of records in the cache, across all partitions.
    ///
    /// INVARIANT: this is the sum of the `size` fields of the `partitions`.
    current_size: usize,

    /// The desired maximum number of records in the cache.
    desired_size: usize,
}
273
/// The cached records for a domain (one entry per partition key in a
/// `PartitionedCache`).
#[derive(Debug, Clone, Eq, PartialEq)]
struct Partition<K: Eq + Hash, V> {
    /// The time this partition was last read at.
    last_read: Instant,

    /// When the next record expires.
    ///
    /// INVARIANT: this is the minimum of the expiry times of the `records`.
    next_expiry: Instant,

    /// How many records there are.
    ///
    /// INVARIANT: this is the sum of the vector lengths in `records`.
    size: usize,

    /// The records, further divided by record key.  Each record is stored
    /// with the `Instant` at which it expires.
    records: HashMap<K, Vec<(V, Instant)>>,
}
293
294impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> Default
295    for PartitionedCache<K1, K2, V>
296{
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> PartitionedCache<K1, K2, V> {
303    /// Create a new cache with a default desired size.
304    pub fn new() -> Self {
305        Self::with_desired_size(512)
306    }
307
308    /// Create a new cache with the given desired size.
309    ///
310    /// The `prune` method will remove expired records, and also enough records
311    /// (in least-recently-used order) to get down to this size.
312    pub fn with_desired_size(desired_size: usize) -> Self {
313        Self {
314            // `desired_size / 2` is a compromise: most partitions will have
315            // more than one record, so `desired_size` would be too big for the
316            // `partitions`.
317            partitions: HashMap::with_capacity(desired_size / 2),
318            access_priority: PriorityQueue::with_capacity(desired_size),
319            expiry_priority: PriorityQueue::with_capacity(desired_size),
320            current_size: 0,
321            desired_size,
322        }
323    }
324
325    /// Get all records for the given partition key from the cache, along with
326    /// their expiration times.
327    ///
328    /// These records may have expired if `prune` has not been called recently.
329    pub fn get_partition_without_checking_expiration(
330        &mut self,
331        partition_key: &K1,
332    ) -> Option<&HashMap<K2, Vec<(V, Instant)>>> {
333        if let Some(partition) = self.partitions.get_mut(partition_key) {
334            partition.last_read = Instant::now();
335            self.access_priority
336                .change_priority(partition_key, Reverse(partition.last_read));
337            return Some(&partition.records);
338        }
339
340        None
341    }
342
343    /// Get all records for the given partition and record key from the cache,
344    /// along with their expiration times.
345    ///
346    /// These records may have expired if `prune` has not been called recently.
347    pub fn get_without_checking_expiration(
348        &mut self,
349        partition_key: &K1,
350        record_key: &K2,
351    ) -> Option<&[(V, Instant)]> {
352        if let Some(partition) = self.partitions.get_mut(partition_key) {
353            if let Some(tuples) = partition.records.get(record_key) {
354                partition.last_read = Instant::now();
355                self.access_priority
356                    .change_priority(partition_key, Reverse(partition.last_read));
357                return Some(tuples);
358            }
359        }
360
361        None
362    }
363
364    /// Insert a record into the cache, or reset the expiry time if already
365    /// present.
366    pub fn upsert(&mut self, partition_key: K1, record_key: K2, value: V, ttl: Duration) {
367        let now = Instant::now();
368        let expiry = now + ttl;
369        let tuple = (value, expiry);
370        if let Some(partition) = self.partitions.get_mut(&partition_key) {
371            if let Some(tuples) = partition.records.get_mut(&record_key) {
372                let mut duplicate_expires_at = None;
373                for i in 0..tuples.len() {
374                    let t = &tuples[i];
375                    if t.0 == tuple.0 {
376                        duplicate_expires_at = Some(t.1);
377                        tuples.swap_remove(i);
378                        break;
379                    }
380                }
381
382                tuples.push(tuple);
383
384                if let Some(dup_expiry) = duplicate_expires_at {
385                    partition.size -= 1;
386                    self.current_size -= 1;
387
388                    if dup_expiry == partition.next_expiry {
389                        let mut new_next_expiry = expiry;
390                        for (_, e) in tuples {
391                            if *e < new_next_expiry {
392                                new_next_expiry = *e;
393                            }
394                        }
395                        partition.next_expiry = new_next_expiry;
396                        self.expiry_priority
397                            .change_priority(&partition_key, Reverse(partition.next_expiry));
398                    }
399                }
400            } else {
401                partition.records.insert(record_key, vec![tuple]);
402            }
403            partition.last_read = now;
404            partition.size += 1;
405            self.access_priority
406                .change_priority(&partition_key, Reverse(partition.last_read));
407            if expiry < partition.next_expiry {
408                partition.next_expiry = expiry;
409                self.expiry_priority
410                    .change_priority(&partition_key, Reverse(partition.next_expiry));
411            }
412        } else {
413            let mut records = HashMap::new();
414            records.insert(record_key, vec![tuple]);
415            let partition = Partition {
416                last_read: now,
417                next_expiry: expiry,
418                size: 1,
419                records,
420            };
421            self.access_priority
422                .push(partition_key.clone(), Reverse(partition.last_read));
423            self.expiry_priority
424                .push(partition_key.clone(), Reverse(partition.next_expiry));
425            self.partitions.insert(partition_key, partition);
426        }
427
428        self.current_size += 1;
429    }
430
431    /// Delete all expired records.
432    ///
433    /// Returns the number of records deleted.
434    pub fn remove_expired(&mut self) -> usize {
435        let mut pruned = 0;
436
437        loop {
438            let before = pruned;
439            pruned += self.remove_expired_step();
440            if before == pruned {
441                break;
442            }
443        }
444
445        pruned
446    }
447
448    /// Delete all expired records, and then enough
449    /// least-recently-used records to reduce the cache to the desired
450    /// size.
451    ///
452    /// Returns `(has overflowed?, current size, num expired, num pruned)`.
453    pub fn prune(&mut self) -> (bool, usize, usize, usize) {
454        let has_overflowed = self.current_size > self.desired_size;
455        let num_expired = self.remove_expired();
456        let mut num_pruned = 0;
457
458        while self.current_size > self.desired_size {
459            num_pruned += self.remove_least_recently_used();
460        }
461
462        (has_overflowed, self.current_size, num_expired, num_pruned)
463    }
464
465    /// Helper for `remove_expired`: looks at the next-to-expire
466    /// domain and cleans up expired records from it.  This may delete
467    /// more than one record, and may even delete the whole domain.
468    ///
469    /// Returns the number of records removed.
470    fn remove_expired_step(&mut self) -> usize {
471        if let Some((partition_key, Reverse(expiry))) = self.expiry_priority.pop() {
472            let now = Instant::now();
473
474            if expiry > now {
475                self.expiry_priority.push(partition_key, Reverse(expiry));
476                return 0;
477            }
478
479            if let Some(partition) = self.partitions.get_mut(&partition_key) {
480                let mut pruned = 0;
481
482                let record_keys = partition.records.keys().copied().collect::<Vec<K2>>();
483                let mut next_expiry = None;
484                for rkey in record_keys {
485                    if let Some(tuples) = partition.records.get_mut(&rkey) {
486                        let len = tuples.len();
487                        tuples.retain(|(_, expiry)| expiry > &now);
488                        pruned += len - tuples.len();
489                        for (_, expiry) in tuples {
490                            match next_expiry {
491                                None => next_expiry = Some(*expiry),
492                                Some(t) if *expiry < t => next_expiry = Some(*expiry),
493                                _ => (),
494                            }
495                        }
496                    }
497                }
498
499                partition.size -= pruned;
500
501                if let Some(ne) = next_expiry {
502                    partition.next_expiry = ne;
503                    self.expiry_priority.push(partition_key, Reverse(ne));
504                } else {
505                    self.partitions.remove(&partition_key);
506                    self.access_priority.remove(&partition_key);
507                }
508
509                self.current_size -= pruned;
510                pruned
511            } else {
512                self.access_priority.remove(&partition_key);
513                0
514            }
515        } else {
516            0
517        }
518    }
519
520    /// Helper for `prune`: deletes all records associated with the
521    /// least recently used domain.
522    ///
523    /// Returns the number of records removed.
524    fn remove_least_recently_used(&mut self) -> usize {
525        if let Some((partition_key, _)) = self.access_priority.pop() {
526            self.expiry_priority.remove(&partition_key);
527
528            if let Some(partition) = self.partitions.remove(&partition_key) {
529                let pruned = partition.size;
530                self.current_size -= pruned;
531                pruned
532            } else {
533                0
534            }
535        } else {
536            0
537        }
538    }
539}
540
#[cfg(test)]
mod tests {
    use dns_types::protocol::types::test_util::*;

    use super::test_util::*;
    use super::*;

    // Round-trip: an inserted record is returned both for its concrete
    // record-type query and for a wildcard query.
    #[test]
    fn cache_put_can_get() {
        for _ in 0..100 {
            let mut cache = Cache::new();
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            cache.insert(&rr);

            assert_cache_response(
                &rr,
                &cache.get_without_checking_expiration(
                    &rr.name,
                    QueryType::Record(rr.rtype_with_data.rtype()),
                ),
            );
            assert_cache_response(
                &rr,
                &cache.get_without_checking_expiration(&rr.name, QueryType::Wildcard),
            );
        }
    }

    // Inserting the same record twice must overwrite, not duplicate.
    #[test]
    fn cache_put_deduplicates_and_maintains_invariants() {
        let mut cache = Cache::new();
        let mut rr = arbitrary_resourcerecord();
        rr.rclass = RecordClass::IN;

        cache.insert(&rr);
        cache.insert(&rr);

        assert_eq!(1, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_maintains_invariants() {
        let mut cache = Cache::new();

        for _ in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            cache.insert(&rr);
        }

        assert_invariants(&cache);
    }

    // Reads update the access-priority queue; the invariants must survive
    // a mix of inserts followed by lookups.
    #[test]
    fn cache_put_then_get_maintains_invariants() {
        let mut cache = Cache::new();
        let mut queries = Vec::new();

        for _ in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            cache.insert(&rr);
            queries.push((
                rr.name.clone(),
                QueryType::Record(rr.rtype_with_data.rtype()),
            ));
        }
        for (name, qtype) in queries {
            cache.get_without_checking_expiration(&name, qtype);
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_prune_maintains_invariants() {
        let mut cache = Cache::with_desired_size(25);

        for _ in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            rr.ttl = 300; // this case isn't testing expiration
            cache.insert(&rr);
        }

        // might be more than 75 because the size is measured in
        // records, but pruning is done on whole domains
        let (overflow, current_size, expired, pruned) = cache.prune();
        assert!(overflow);
        assert_eq!(0, expired);
        assert!(pruned >= 75);
        assert!(cache.inner.current_size <= 25);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    // Give records at i = 2, 4, ..., 98 (49 of them) a TTL of zero so they
    // expire immediately; the remaining 51 stay live at TTL 300.
    #[test]
    fn cache_put_then_expire_maintains_invariants() {
        let mut cache = Cache::new();

        for i in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            rr.ttl = if i > 0 && i % 2 == 0 { 0 } else { 300 };
            cache.insert(&rr);
        }

        assert_eq!(49, cache.inner.remove_expired());
        assert_eq!(51, cache.inner.current_size);
        assert_invariants(&cache);
    }

    // With a desired size of 99 and 49 expired records, expiry alone brings
    // the cache under size, so no LRU pruning should happen.
    #[test]
    fn cache_prune_expires_all() {
        let mut cache = Cache::with_desired_size(99);

        for i in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            rr.ttl = if i > 0 && i % 2 == 0 { 0 } else { 300 };
            cache.insert(&rr);
        }

        let (overflow, current_size, expired, pruned) = cache.prune();
        assert!(overflow);
        assert_eq!(49, expired);
        assert_eq!(0, pruned);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    // Checks every documented INVARIANT on `PartitionedCache` and
    // `Partition`: size counters, queue/partition key agreement, per-key
    // record types, and `next_expiry` minimality.  The priority queues are
    // rebuilt from scratch and compared against the real ones.
    fn assert_invariants(cache: &Cache) {
        assert_eq!(
            cache.inner.current_size,
            cache
                .inner
                .partitions
                .values()
                .map(|e| e.size)
                .sum::<usize>()
        );

        assert_eq!(
            cache.inner.partitions.len(),
            cache.inner.access_priority.len()
        );
        assert_eq!(
            cache.inner.partitions.len(),
            cache.inner.expiry_priority.len()
        );

        let mut access_priority = PriorityQueue::new();
        let mut expiry_priority = PriorityQueue::new();

        for (name, partition) in &cache.inner.partitions {
            assert_eq!(
                partition.size,
                partition.records.values().map(Vec::len).sum::<usize>()
            );

            let mut min_expires = None;
            for (rtype, tuples) in &partition.records {
                for (rtype_with_data, expires) in tuples {
                    assert_eq!(*rtype, rtype_with_data.rtype());

                    if let Some(e) = min_expires {
                        if *expires < e {
                            min_expires = Some(*expires);
                        }
                    } else {
                        min_expires = Some(*expires);
                    }
                }
            }

            assert_eq!(Some(partition.next_expiry), min_expires);

            access_priority.push(name.clone(), Reverse(partition.last_read));
            expiry_priority.push(name.clone(), Reverse(partition.next_expiry));
        }

        assert_eq!(cache.inner.access_priority, access_priority);
        assert_eq!(cache.inner.expiry_priority, expiry_priority);
    }
}
728
#[cfg(test)]
#[allow(clippy::missing_panics_doc)]
pub mod test_util {
    use super::*;

    /// Assert that the cache response has exactly one element and
    /// that it matches the original (all fields equal except TTL,
    /// where the original is >=, since time passes while cached).
    pub fn assert_cache_response(original: &ResourceRecord, response: &[ResourceRecord]) {
        assert_eq!(1, response.len());
        // Borrow instead of cloning: the fields are only read for comparison.
        let cached = &response[0];

        assert_eq!(original.name, cached.name);
        assert_eq!(original.rtype_with_data, cached.rtype_with_data);
        assert_eq!(RecordClass::IN, cached.rclass);
        assert!(original.ttl >= cached.ttl);
    }
}
746}