dns_resolver/cache.rs

use priority_queue::PriorityQueue;
use std::cmp::Eq;
use std::cmp::Reverse;
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::Copy;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use dns_types::protocol::types::*;

/// A convenience wrapper around a `Cache` which lets it be shared
/// between threads.
///
/// Invoking `clone` on a `SharedCache` gives a new instance which
/// refers to the same underlying `Cache` object.
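///
/// For example, a sketch of sharing one cache with a background pruner
/// thread (the loop body and interval are illustrative, not part of this
/// crate):
///
/// ```ignore
/// let cache = SharedCache::new();
///
/// let pruner = {
///     let cache = cache.clone(); // refers to the same underlying `Cache`
///     std::thread::spawn(move || loop {
///         let (_overflowed, _size, _expired, _pruned) = cache.prune();
///         std::thread::sleep(std::time::Duration::from_secs(60));
///     })
/// };
/// ```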
#[derive(Debug, Clone)]
pub struct SharedCache {
    cache: Arc<Mutex<Cache>>,
}

const MUTEX_POISON_MESSAGE: &str =
    "[INTERNAL ERROR] cache mutex poisoned, cannot recover from this - aborting";

impl SharedCache {
    /// Make a new, empty, shared cache.
    pub fn new() -> Self {
        SharedCache {
            cache: Arc::new(Mutex::new(Cache::new())),
        }
    }

    /// Create a new cache with the given desired size.
    pub fn with_desired_size(desired_size: usize) -> Self {
        SharedCache {
            cache: Arc::new(Mutex::new(Cache::with_desired_size(desired_size))),
        }
    }

    /// Get an entry from the cache.
    ///
    /// The TTL in the returned `ResourceRecord` is relative to the
    /// current time - not when the record was inserted into the
    /// cache.
    ///
    /// # Panics
    ///
    /// If the mutex has been poisoned.
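    ///
    /// For example (a sketch: `name` is a `DomainName` obtained elsewhere):
    ///
    /// ```ignore
    /// // `QueryType::Wildcard` returns records of every cached type for the
    /// // name; expired records are filtered out, so every returned TTL is
    /// // nonzero.
    /// for rr in cache.get(&name, QueryType::Wildcard) {
    ///     assert!(rr.ttl > 0);
    /// }
    /// ```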
    pub fn get(&self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
        self.cache
            .lock()
            .expect(MUTEX_POISON_MESSAGE)
            .get(name, qtype)
    }

    /// Like `get`, but may return expired entries.
    ///
    /// Consumers MUST check that the TTL of a record is nonzero
    /// before using it!
    ///
    /// # Panics
    ///
    /// If the mutex has been poisoned.
    pub fn get_without_checking_expiration(
        &self,
        name: &DomainName,
        qtype: QueryType,
    ) -> Vec<ResourceRecord> {
        self.cache
            .lock()
            .expect(MUTEX_POISON_MESSAGE)
            .get_without_checking_expiration(name, qtype)
    }

    /// Insert an entry into the cache.
    ///
    /// It is not inserted if its TTL is zero or negative.
    ///
    /// This may make the cache grow beyond the desired size.
    ///
    /// # Panics
    ///
    /// If the mutex has been poisoned.
    pub fn insert(&self, record: &ResourceRecord) {
        if record.ttl > 0 {
            let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
            cache.insert(record);
        }
    }

    /// Insert multiple entries into the cache.
    ///
    /// This is more efficient than calling `insert` multiple times, as it locks
    /// the cache just once.
    ///
    /// Records with a TTL of zero or negative are skipped.
    ///
    /// This may make the cache grow beyond the desired size.
    ///
    /// # Panics
    ///
    /// If the mutex has been poisoned.
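    ///
    /// For example (a sketch: `response` and its `answers` field are
    /// illustrative, not part of this module):
    ///
    /// ```ignore
    /// // cache every answer from a response in one go, under one lock
    /// cache.insert_all(&response.answers);
    /// ```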
    pub fn insert_all(&self, records: &[ResourceRecord]) {
        let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
        for record in records {
            if record.ttl > 0 {
                cache.insert(record);
            }
        }
    }

    /// Atomically clears expired entries and, if the cache has grown
    /// beyond its desired size, prunes entries to get down to size.
    ///
    /// Returns `(has overflowed?, current size, num expired, num pruned)`.
    ///
    /// # Panics
    ///
    /// If the mutex has been poisoned.
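    ///
    /// A sketch of how the returned tuple might be used (the log macro is
    /// illustrative):
    ///
    /// ```ignore
    /// let (overflowed, size, expired, pruned) = cache.prune();
    /// if overflowed {
    ///     log::warn!("cache overflowed: {size} records left ({expired} expired, {pruned} pruned)");
    /// }
    /// ```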
    pub fn prune(&self) -> (bool, usize, usize, usize) {
        self.cache.lock().expect(MUTEX_POISON_MESSAGE).prune()
    }
}

impl Default for SharedCache {
    fn default() -> Self {
        Self::new()
    }
}

/// Caching for `ResourceRecord`s.
///
/// You probably want to use `SharedCache` instead.
#[derive(Debug, Clone)]
pub struct Cache {
    inner: PartitionedCache<DomainName, RecordType, RecordTypeWithData>,
}

impl Default for Cache {
    fn default() -> Self {
        Self::new()
    }
}

impl Cache {
    /// Create a new cache with a default desired size.
    pub fn new() -> Self {
        Self {
            inner: PartitionedCache::new(),
        }
    }

    /// Create a new cache with the given desired size.
    ///
    /// The `prune` method will remove expired entries, and also enough entries
    /// (in least-recently-used order) to get down to this size.
    pub fn with_desired_size(desired_size: usize) -> Self {
        Self {
            inner: PartitionedCache::with_desired_size(desired_size),
        }
    }

    /// Get RRs from the cache.
    ///
    /// The TTL in the returned `ResourceRecord` is relative to the
    /// current time - not when the record was inserted into the
    /// cache.
    pub fn get(&mut self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
        let mut rrs = self.get_without_checking_expiration(name, qtype);
        rrs.retain(|rr| rr.ttl > 0);
        rrs
    }

    /// Like `get`, but may return expired RRs.
    ///
    /// Consumers MUST check that the TTL of a record is nonzero before using
    /// it!
    pub fn get_without_checking_expiration(
        &mut self,
        name: &DomainName,
        qtype: QueryType,
    ) -> Vec<ResourceRecord> {
        let now = Instant::now();
        let mut rrs = Vec::new();
        match qtype {
            QueryType::Wildcard => {
                if let Some(records) = self.inner.get_partition_without_checking_expiration(name) {
                    for tuples in records.values() {
                        to_rrs(name, now, tuples, &mut rrs);
                    }
                }
            }
            QueryType::Record(rtype) => {
                if let Some(tuples) = self.inner.get_without_checking_expiration(name, &rtype) {
                    to_rrs(name, now, tuples, &mut rrs);
                }
            }
            _ => (),
        }

        rrs
    }

    /// Insert an RR into the cache.
    pub fn insert(&mut self, record: &ResourceRecord) {
        self.inner.upsert(
            record.name.clone(),
            record.rtype_with_data.rtype(),
            record.rtype_with_data.clone(),
            Duration::from_secs(record.ttl.into()),
        );
    }

    /// Clears expired RRs and, if the cache has grown beyond its desired
    /// size, prunes domains to get down to size.
    ///
    /// Returns `(has overflowed?, current size, num expired, num pruned)`.
    pub fn prune(&mut self) -> (bool, usize, usize, usize) {
        self.inner.prune()
    }
}

/// Helper for `get_without_checking_expiration`: converts the cached
/// record tuples into RRs.
fn to_rrs(
    name: &DomainName,
    now: Instant,
    tuples: &[(RecordTypeWithData, Instant)],
    rrs: &mut Vec<ResourceRecord>,
) {
    for (rtype, expires) in tuples {
        // The remaining lifetime, saturating to `u32::MAX` if it does not
        // fit in the `u32` TTL field.
        let ttl = expires
            .saturating_duration_since(now)
            .as_secs()
            .try_into()
            .unwrap_or(u32::MAX);

        rrs.push(ResourceRecord {
            name: name.clone(),
            rtype_with_data: rtype.clone(),
            rclass: RecordClass::IN,
            ttl,
        });
    }
}

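/// A two-level cache: values are grouped into partitions by a partition key
/// `K1` (for `Cache`, the domain name), and within a partition by a record
/// key `K2` (the record type). Each record carries its own expiry time;
/// last-read times and pruning are tracked per partition.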
#[derive(Debug, Clone)]
pub struct PartitionedCache<K1: Eq + Hash, K2: Eq + Hash, V> {
    /// Cached entries, indexed by partition key.
    partitions: HashMap<K1, Partition<K2, V>>,

    /// Priority queue of partition keys ordered by access times.
    ///
    /// When the cache is full and there are no expired records to prune,
    /// partitions will instead be pruned in LRU order.
    ///
    /// INVARIANT: the keys in here are exactly the keys in `partitions`.
    access_priority: PriorityQueue<K1, Reverse<Instant>>,

    /// Priority queue of partition keys ordered by expiry time.
    ///
    /// When the cache is pruned, expired records are removed first.
    ///
    /// INVARIANT: the keys in here are exactly the keys in `partitions`.
    expiry_priority: PriorityQueue<K1, Reverse<Instant>>,

    /// The number of records in the cache, across all partitions.
    ///
    /// INVARIANT: this is the sum of the `size` fields of the `partitions`.
    current_size: usize,

    /// The desired maximum number of records in the cache.
    desired_size: usize,
}

/// The cached records for a single partition key (for `Cache`, a domain).
#[derive(Debug, Clone, Eq, PartialEq)]
struct Partition<K: Eq + Hash, V> {
    /// The time this partition was last read at.
    last_read: Instant,

    /// When the next record expires.
    ///
    /// INVARIANT: this is the minimum of the expiry times of the `records`.
    next_expiry: Instant,

    /// How many records there are.
    ///
    /// INVARIANT: this is the sum of the vector lengths in `records`.
    size: usize,

    /// The records, further divided by record key.
    records: HashMap<K, Vec<(V, Instant)>>,
}

impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> Default
    for PartitionedCache<K1, K2, V>
{
    fn default() -> Self {
        Self::new()
    }
}

impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> PartitionedCache<K1, K2, V> {
    /// Create a new cache with a default desired size.
    pub fn new() -> Self {
        Self::with_desired_size(512)
    }

    /// Create a new cache with the given desired size.
    ///
    /// The `prune` method will remove expired records, and also enough records
    /// (in least-recently-used order) to get down to this size.
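    ///
    /// For example, a sketch with stand-in key and value types (the types,
    /// keys, and TTLs here are illustrative only):
    ///
    /// ```ignore
    /// let mut cache: PartitionedCache<String, u8, String> =
    ///     PartitionedCache::with_desired_size(2);
    /// cache.upsert("a".into(), 0, "v1".into(), Duration::from_secs(60));
    /// cache.upsert("b".into(), 0, "v2".into(), Duration::from_secs(60));
    /// cache.upsert("c".into(), 0, "v3".into(), Duration::from_secs(60));
    ///
    /// // Nothing has expired, so one whole partition (the least recently
    /// // used) is dropped to get back down to the desired size.
    /// let (overflowed, size, expired, pruned) = cache.prune();
    /// assert!(overflowed);
    /// assert_eq!(0, expired);
    /// assert_eq!(1, pruned);
    /// assert!(size <= 2);
    /// ```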
    pub fn with_desired_size(desired_size: usize) -> Self {
        Self {
            // `desired_size / 2` is a compromise: most partitions will have
            // more than one record, so `desired_size` would be too big for the
            // `partitions`.
            partitions: HashMap::with_capacity(desired_size / 2),
            access_priority: PriorityQueue::with_capacity(desired_size),
            expiry_priority: PriorityQueue::with_capacity(desired_size),
            current_size: 0,
            desired_size,
        }
    }

    /// Get all records for the given partition key from the cache, along with
    /// their expiration times.
    ///
    /// These records may have expired if `prune` has not been called recently.
    pub fn get_partition_without_checking_expiration(
        &mut self,
        partition_key: &K1,
    ) -> Option<&HashMap<K2, Vec<(V, Instant)>>> {
        if let Some(partition) = self.partitions.get_mut(partition_key) {
            partition.last_read = Instant::now();
            self.access_priority
                .change_priority(partition_key, Reverse(partition.last_read));
            return Some(&partition.records);
        }

        None
    }

    /// Get all records for the given partition and record key from the cache,
    /// along with their expiration times.
    ///
    /// These records may have expired if `prune` has not been called recently.
    pub fn get_without_checking_expiration(
        &mut self,
        partition_key: &K1,
        record_key: &K2,
    ) -> Option<&[(V, Instant)]> {
        if let Some(partition) = self.partitions.get_mut(partition_key) {
            if let Some(tuples) = partition.records.get(record_key) {
                partition.last_read = Instant::now();
                self.access_priority
                    .change_priority(partition_key, Reverse(partition.last_read));
                return Some(tuples);
            }
        }

        None
    }

    /// Insert a record into the cache, or reset the expiry time if already
    /// present.
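    ///
    /// For example (a sketch with stand-in types; upserting an equal value
    /// does not grow the cache, it only refreshes the expiry):
    ///
    /// ```ignore
    /// let mut cache: PartitionedCache<String, u8, String> = PartitionedCache::new();
    /// cache.upsert("example.com".into(), 1, "value".into(), Duration::from_secs(300));
    /// cache.upsert("example.com".into(), 1, "value".into(), Duration::from_secs(600));
    /// // still one record, now expiring ~600s from the second upsert
    /// ```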
    pub fn upsert(&mut self, partition_key: K1, record_key: K2, value: V, ttl: Duration) {
        let now = Instant::now();
        let expiry = now + ttl;
        let tuple = (value, expiry);
        if let Some(partition) = self.partitions.get_mut(&partition_key) {
            if let Some(tuples) = partition.records.get_mut(&record_key) {
                // If this value is already cached, drop the old copy but
                // remember its expiry so the `next_expiry` invariant can be
                // repaired below.
                let duplicate_expires_at = tuples
                    .iter()
                    .position(|t| t.0 == tuple.0)
                    .map(|i| tuples.swap_remove(i).1);

                tuples.push(tuple);

                if let Some(dup_expiry) = duplicate_expires_at {
                    partition.size -= 1;
                    self.current_size -= 1;

                    if dup_expiry == partition.next_expiry {
                        let mut new_next_expiry = expiry;
                        for (_, e) in tuples {
                            if *e < new_next_expiry {
                                new_next_expiry = *e;
                            }
                        }
                        partition.next_expiry = new_next_expiry;
                        self.expiry_priority
                            .change_priority(&partition_key, Reverse(partition.next_expiry));
                    }
                }
            } else {
                partition.records.insert(record_key, vec![tuple]);
            }
            partition.last_read = now;
            partition.size += 1;
            self.access_priority
                .change_priority(&partition_key, Reverse(partition.last_read));
            if expiry < partition.next_expiry {
                partition.next_expiry = expiry;
                self.expiry_priority
                    .change_priority(&partition_key, Reverse(partition.next_expiry));
            }
        } else {
            let mut records = HashMap::new();
            records.insert(record_key, vec![tuple]);
            let partition = Partition {
                last_read: now,
                next_expiry: expiry,
                size: 1,
                records,
            };
            self.access_priority
                .push(partition_key.clone(), Reverse(partition.last_read));
            self.expiry_priority
                .push(partition_key.clone(), Reverse(partition.next_expiry));
            self.partitions.insert(partition_key, partition);
        }

        self.current_size += 1;
    }

    /// Delete all expired records.
    ///
    /// Returns the number of records deleted.
    pub fn remove_expired(&mut self) -> usize {
        let mut pruned = 0;

        loop {
            let before = pruned;
            pruned += self.remove_expired_step();
            if before == pruned {
                break;
            }
        }

        pruned
    }

    /// Delete all expired records, and then enough
    /// least-recently-used records to reduce the cache to the desired
    /// size.
    ///
    /// Returns `(has overflowed?, current size, num expired, num pruned)`.
    pub fn prune(&mut self) -> (bool, usize, usize, usize) {
        let has_overflowed = self.current_size > self.desired_size;
        let num_expired = self.remove_expired();
        let mut num_pruned = 0;

        while self.current_size > self.desired_size {
            num_pruned += self.remove_least_recently_used();
        }

        (has_overflowed, self.current_size, num_expired, num_pruned)
    }

    /// Helper for `remove_expired`: looks at the next-to-expire partition
    /// and cleans up expired records from it. This may delete more than one
    /// record, and may even delete the whole partition.
    ///
    /// Returns the number of records removed.
    fn remove_expired_step(&mut self) -> usize {
        if let Some((partition_key, Reverse(expiry))) = self.expiry_priority.pop() {
            let now = Instant::now();

            if expiry > now {
                self.expiry_priority.push(partition_key, Reverse(expiry));
                return 0;
            }

            if let Some(partition) = self.partitions.get_mut(&partition_key) {
                let mut pruned = 0;

                let record_keys = partition.records.keys().copied().collect::<Vec<K2>>();
                let mut next_expiry = None;
                for rkey in record_keys {
                    if let Some(tuples) = partition.records.get_mut(&rkey) {
                        let len = tuples.len();
                        tuples.retain(|(_, expiry)| expiry > &now);
                        pruned += len - tuples.len();
                        for (_, expiry) in tuples {
                            match next_expiry {
                                None => next_expiry = Some(*expiry),
                                Some(t) if *expiry < t => next_expiry = Some(*expiry),
                                _ => (),
                            }
                        }
                    }
                }

                partition.size -= pruned;

                if let Some(ne) = next_expiry {
                    partition.next_expiry = ne;
                    self.expiry_priority.push(partition_key, Reverse(ne));
                } else {
                    self.partitions.remove(&partition_key);
                    self.access_priority.remove(&partition_key);
                }

                self.current_size -= pruned;
                pruned
            } else {
                self.access_priority.remove(&partition_key);
                0
            }
        } else {
            0
        }
    }

    /// Helper for `prune`: deletes all records associated with the least
    /// recently used partition.
    ///
    /// Returns the number of records removed.
    fn remove_least_recently_used(&mut self) -> usize {
        if let Some((partition_key, _)) = self.access_priority.pop() {
            self.expiry_priority.remove(&partition_key);

            if let Some(partition) = self.partitions.remove(&partition_key) {
                let pruned = partition.size;
                self.current_size -= pruned;
                pruned
            } else {
                0
            }
        } else {
            0
        }
    }
}

#[cfg(test)]
mod tests {
    use dns_types::protocol::types::test_util::*;

    use super::test_util::*;
    use super::*;

    #[test]
    fn cache_put_can_get() {
        for _ in 0..100 {
            let mut cache = Cache::new();
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            cache.insert(&rr);

            assert_cache_response(
                &rr,
                &cache.get_without_checking_expiration(
                    &rr.name,
                    QueryType::Record(rr.rtype_with_data.rtype()),
                ),
            );
            assert_cache_response(
                &rr,
                &cache.get_without_checking_expiration(&rr.name, QueryType::Wildcard),
            );
        }
    }

    #[test]
    fn cache_put_deduplicates_and_maintains_invariants() {
        let mut cache = Cache::new();
        let mut rr = arbitrary_resourcerecord();
        rr.rclass = RecordClass::IN;

        cache.insert(&rr);
        cache.insert(&rr);

        assert_eq!(1, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_maintains_invariants() {
        let mut cache = Cache::new();

        for _ in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            cache.insert(&rr);
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_get_maintains_invariants() {
        let mut cache = Cache::new();
        let mut queries = Vec::new();

        for _ in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            cache.insert(&rr);
            queries.push((
                rr.name.clone(),
                QueryType::Record(rr.rtype_with_data.rtype()),
            ));
        }
        for (name, qtype) in queries {
            cache.get_without_checking_expiration(&name, qtype);
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_prune_maintains_invariants() {
        let mut cache = Cache::with_desired_size(25);

        for _ in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            rr.ttl = 300; // this case isn't testing expiration
            cache.insert(&rr);
        }

        // might be more than 75 because the size is measured in
        // records, but pruning is done on whole domains
        let (overflow, current_size, expired, pruned) = cache.prune();
        assert!(overflow);
        assert_eq!(0, expired);
        assert!(pruned >= 75);
        assert!(cache.inner.current_size <= 25);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_expire_maintains_invariants() {
        let mut cache = Cache::new();

        for i in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            rr.ttl = if i > 0 && i % 2 == 0 { 0 } else { 300 };
            cache.insert(&rr);
        }

        assert_eq!(49, cache.inner.remove_expired());
        assert_eq!(51, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_prune_expires_all() {
        let mut cache = Cache::with_desired_size(99);

        for i in 0..100 {
            let mut rr = arbitrary_resourcerecord();
            rr.rclass = RecordClass::IN;
            rr.ttl = if i > 0 && i % 2 == 0 { 0 } else { 300 };
            cache.insert(&rr);
        }

        let (overflow, current_size, expired, pruned) = cache.prune();
        assert!(overflow);
        assert_eq!(49, expired);
        assert_eq!(0, pruned);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    fn assert_invariants(cache: &Cache) {
        assert_eq!(
            cache.inner.current_size,
            cache
                .inner
                .partitions
                .values()
                .map(|e| e.size)
                .sum::<usize>()
        );

        assert_eq!(
            cache.inner.partitions.len(),
            cache.inner.access_priority.len()
        );
        assert_eq!(
            cache.inner.partitions.len(),
            cache.inner.expiry_priority.len()
        );

        let mut access_priority = PriorityQueue::new();
        let mut expiry_priority = PriorityQueue::new();

        for (name, partition) in &cache.inner.partitions {
            assert_eq!(
                partition.size,
                partition.records.values().map(Vec::len).sum::<usize>()
            );

            let mut min_expires = None;
            for (rtype, tuples) in &partition.records {
                for (rtype_with_data, expires) in tuples {
                    assert_eq!(*rtype, rtype_with_data.rtype());

                    if let Some(e) = min_expires {
                        if *expires < e {
                            min_expires = Some(*expires);
                        }
                    } else {
                        min_expires = Some(*expires);
                    }
                }
            }

            assert_eq!(Some(partition.next_expiry), min_expires);

            access_priority.push(name.clone(), Reverse(partition.last_read));
            expiry_priority.push(name.clone(), Reverse(partition.next_expiry));
        }

        assert_eq!(cache.inner.access_priority, access_priority);
        assert_eq!(cache.inner.expiry_priority, expiry_priority);
    }
}

#[cfg(test)]
#[allow(clippy::missing_panics_doc)]
pub mod test_util {
    use super::*;

    /// Assert that the cache response has exactly one element and
    /// that it matches the original (all fields equal except TTL,
    /// where the original is >=).
    pub fn assert_cache_response(original: &ResourceRecord, response: &[ResourceRecord]) {
        assert_eq!(1, response.len());
        let cached = response[0].clone();

        assert_eq!(original.name, cached.name);
        assert_eq!(original.rtype_with_data, cached.rtype_with_data);
        assert_eq!(RecordClass::IN, cached.rclass);
        assert!(original.ttl >= cached.ttl);
    }
}