1use priority_queue::PriorityQueue;
2use std::cmp::Eq;
3use std::cmp::Reverse;
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::marker::Copy;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10use dns_types::protocol::types::*;
11
/// A `Cache` held behind an `Arc<Mutex<_>>` so that it can be shared.
///
/// Cloning a `SharedCache` only clones the `Arc`, so every clone refers to the
/// same underlying cache. The methods on this type lock the mutex for the
/// duration of a single call.
#[derive(Debug, Clone)]
pub struct SharedCache {
    /// The shared, mutex-protected cache.
    cache: Arc<Mutex<Cache>>,
}
21
/// Panic message passed to `expect` on `Mutex::lock` for the cache mutex.
///
/// `lock()` only returns an error when the mutex is poisoned (a thread
/// panicked while holding the lock); no recovery is attempted in that case.
const MUTEX_POISON_MESSAGE: &str =
    "[INTERNAL ERROR] cache mutex poisoned, cannot recover from this - aborting";
24
25impl SharedCache {
26 pub fn new() -> Self {
28 SharedCache {
29 cache: Arc::new(Mutex::new(Cache::new())),
30 }
31 }
32
33 pub fn with_desired_size(desired_size: usize) -> Self {
35 SharedCache {
36 cache: Arc::new(Mutex::new(Cache::with_desired_size(desired_size))),
37 }
38 }
39
40 pub fn get(&self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
50 self.cache
51 .lock()
52 .expect(MUTEX_POISON_MESSAGE)
53 .get(name, qtype)
54 }
55
56 pub fn get_without_checking_expiration(
65 &self,
66 name: &DomainName,
67 qtype: QueryType,
68 ) -> Vec<ResourceRecord> {
69 self.cache
70 .lock()
71 .expect(MUTEX_POISON_MESSAGE)
72 .get_without_checking_expiration(name, qtype)
73 }
74
75 pub fn insert(&self, record: &ResourceRecord) {
85 if record.ttl > 0 {
86 let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
87 cache.insert(record);
88 }
89 }
90
91 pub fn insert_all(&self, records: &[ResourceRecord]) {
104 let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
105 for record in records {
106 if record.ttl > 0 {
107 cache.insert(record);
108 }
109 }
110 }
111
112 pub fn prune(&self) -> (bool, usize, usize, usize) {
121 self.cache.lock().expect(MUTEX_POISON_MESSAGE).prune()
122 }
123}
124
125impl Default for SharedCache {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
/// A DNS record cache, keyed first by domain name and then by record type.
///
/// Methods take `&mut self`; `SharedCache` wraps this type in a mutex for
/// shared use.
#[derive(Debug, Clone)]
pub struct Cache {
    /// The underlying storage: partition keys are domain names, record keys
    /// are record types, and the stored values are the record data.
    inner: PartitionedCache<DomainName, RecordType, RecordTypeWithData>,
}
138
139impl Default for Cache {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl Cache {
146 pub fn new() -> Self {
148 Self {
149 inner: PartitionedCache::new(),
150 }
151 }
152
153 pub fn with_desired_size(desired_size: usize) -> Self {
158 Self {
159 inner: PartitionedCache::with_desired_size(desired_size),
160 }
161 }
162
163 pub fn get(&mut self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
169 let mut rrs = self.get_without_checking_expiration(name, qtype);
170 rrs.retain(|rr| rr.ttl > 0);
171 rrs
172 }
173
174 pub fn get_without_checking_expiration(
179 &mut self,
180 name: &DomainName,
181 qtype: QueryType,
182 ) -> Vec<ResourceRecord> {
183 let now = Instant::now();
184 let mut rrs = Vec::new();
185 match qtype {
186 QueryType::Wildcard => {
187 if let Some(records) = self.inner.get_partition_without_checking_expiration(name) {
188 for tuples in records.values() {
189 to_rrs(name, now, tuples, &mut rrs);
190 }
191 }
192 }
193 QueryType::Record(rtype) => {
194 if let Some(tuples) = self.inner.get_without_checking_expiration(name, &rtype) {
195 to_rrs(name, now, tuples, &mut rrs);
196 }
197 }
198 _ => (),
199 }
200
201 rrs
202 }
203
204 pub fn insert(&mut self, record: &ResourceRecord) {
206 self.inner.upsert(
207 record.name.clone(),
208 record.rtype_with_data.rtype(),
209 record.rtype_with_data.clone(),
210 Duration::from_secs(record.ttl.into()),
211 );
212 }
213
214 pub fn prune(&mut self) -> (bool, usize, usize, usize) {
219 self.inner.prune()
220 }
221}
222
223fn to_rrs(
226 name: &DomainName,
227 now: Instant,
228 tuples: &[(RecordTypeWithData, Instant)],
229 rrs: &mut Vec<ResourceRecord>,
230) {
231 for (rtype, expires) in tuples {
232 let ttl = if let Ok(ttl) = expires.saturating_duration_since(now).as_secs().try_into() {
233 ttl
234 } else {
235 u32::MAX
236 };
237
238 rrs.push(ResourceRecord {
239 name: name.clone(),
240 rtype_with_data: rtype.clone(),
241 rclass: RecordClass::IN,
242 ttl,
243 });
244 }
245}
246
/// A two-level cache.
///
/// Values are grouped into partitions by `K1`, and within a partition into
/// lists keyed by `K2`. Every value carries its own expiry time.
///
/// Expired values are removed one partition at a time, in order of each
/// partition's earliest expiry. When the cache holds more than `desired_size`
/// values, whole partitions are evicted, least recently used first.
#[derive(Debug, Clone)]
pub struct PartitionedCache<K1: Eq + Hash, K2: Eq + Hash, V> {
    /// The cached data, grouped by partition key.
    partitions: HashMap<K1, Partition<K2, V>>,

    /// Partition keys prioritised by `Reverse(last_read)`.
    ///
    /// `Reverse` makes the oldest read the highest priority, so `pop` yields
    /// the least recently used partition.
    access_priority: PriorityQueue<K1, Reverse<Instant>>,

    /// Partition keys prioritised by `Reverse(next_expiry)`.
    ///
    /// `pop` yields the partition whose earliest value expires soonest.
    expiry_priority: PriorityQueue<K1, Reverse<Instant>>,

    /// The total number of values across all partitions (the sum of every
    /// partition's `size`).
    current_size: usize,

    /// The size that `prune` evicts partitions down to.
    desired_size: usize,
}
275
/// The values cached under a single partition key, plus the bookkeeping that
/// `PartitionedCache` uses to decide when to expire or evict them.
#[derive(Debug, Clone, Eq, PartialEq)]
struct Partition<K: Eq + Hash, V> {
    /// When this partition was last read, or written by `upsert`.
    last_read: Instant,

    /// The earliest expiry time of any value in `records`.
    next_expiry: Instant,

    /// How many values are stored in `records`, summed over every key.
    size: usize,

    /// The cached values grouped by record key, each paired with the time at
    /// which it expires.
    records: HashMap<K, Vec<(V, Instant)>>,
}
295
296impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> Default
297 for PartitionedCache<K1, K2, V>
298{
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
304impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> PartitionedCache<K1, K2, V> {
305 pub fn new() -> Self {
307 Self::with_desired_size(512)
308 }
309
310 pub fn with_desired_size(desired_size: usize) -> Self {
315 Self {
316 partitions: HashMap::with_capacity(desired_size / 2),
320 access_priority: PriorityQueue::with_capacity(desired_size),
321 expiry_priority: PriorityQueue::with_capacity(desired_size),
322 current_size: 0,
323 desired_size,
324 }
325 }
326
327 pub fn get_partition_without_checking_expiration(
332 &mut self,
333 partition_key: &K1,
334 ) -> Option<&HashMap<K2, Vec<(V, Instant)>>> {
335 if let Some(partition) = self.partitions.get_mut(partition_key) {
336 partition.last_read = Instant::now();
337 self.access_priority
338 .change_priority(partition_key, Reverse(partition.last_read));
339 return Some(&partition.records);
340 }
341
342 None
343 }
344
345 pub fn get_without_checking_expiration(
350 &mut self,
351 partition_key: &K1,
352 record_key: &K2,
353 ) -> Option<&[(V, Instant)]> {
354 if let Some(partition) = self.partitions.get_mut(partition_key) {
355 if let Some(tuples) = partition.records.get(record_key) {
356 partition.last_read = Instant::now();
357 self.access_priority
358 .change_priority(partition_key, Reverse(partition.last_read));
359 return Some(tuples);
360 }
361 }
362
363 None
364 }
365
366 pub fn upsert(&mut self, partition_key: K1, record_key: K2, value: V, ttl: Duration) {
369 let now = Instant::now();
370 let expiry = now + ttl;
371 let tuple = (value, expiry);
372 if let Some(partition) = self.partitions.get_mut(&partition_key) {
373 if let Some(tuples) = partition.records.get_mut(&record_key) {
374 let mut duplicate_expires_at = None;
375 for i in 0..tuples.len() {
376 let t = &tuples[i];
377 if t.0 == tuple.0 {
378 duplicate_expires_at = Some(t.1);
379 tuples.swap_remove(i);
380 break;
381 }
382 }
383
384 tuples.push(tuple);
385
386 if let Some(dup_expiry) = duplicate_expires_at {
387 partition.size -= 1;
388 self.current_size -= 1;
389
390 if dup_expiry == partition.next_expiry {
391 let mut new_next_expiry = expiry;
392 for (_, e) in tuples {
393 if *e < new_next_expiry {
394 new_next_expiry = *e;
395 }
396 }
397 partition.next_expiry = new_next_expiry;
398 self.expiry_priority
399 .change_priority(&partition_key, Reverse(partition.next_expiry));
400 }
401 }
402 } else {
403 partition.records.insert(record_key, vec![tuple]);
404 }
405 partition.last_read = now;
406 partition.size += 1;
407 self.access_priority
408 .change_priority(&partition_key, Reverse(partition.last_read));
409 if expiry < partition.next_expiry {
410 partition.next_expiry = expiry;
411 self.expiry_priority
412 .change_priority(&partition_key, Reverse(partition.next_expiry));
413 }
414 } else {
415 let mut records = HashMap::new();
416 records.insert(record_key, vec![tuple]);
417 let partition = Partition {
418 last_read: now,
419 next_expiry: expiry,
420 size: 1,
421 records,
422 };
423 self.access_priority
424 .push(partition_key.clone(), Reverse(partition.last_read));
425 self.expiry_priority
426 .push(partition_key.clone(), Reverse(partition.next_expiry));
427 self.partitions.insert(partition_key, partition);
428 }
429
430 self.current_size += 1;
431 }
432
433 pub fn remove_expired(&mut self) -> usize {
437 let mut pruned = 0;
438
439 loop {
440 let before = pruned;
441 pruned += self.remove_expired_step();
442 if before == pruned {
443 break;
444 }
445 }
446
447 pruned
448 }
449
450 pub fn prune(&mut self) -> (bool, usize, usize, usize) {
456 let has_overflowed = self.current_size > self.desired_size;
457 let num_expired = self.remove_expired();
458 let mut num_pruned = 0;
459
460 while self.current_size > self.desired_size {
461 num_pruned += self.remove_least_recently_used();
462 }
463
464 (has_overflowed, self.current_size, num_expired, num_pruned)
465 }
466
467 fn remove_expired_step(&mut self) -> usize {
473 if let Some((partition_key, Reverse(expiry))) = self.expiry_priority.pop() {
474 let now = Instant::now();
475
476 if expiry > now {
477 self.expiry_priority.push(partition_key, Reverse(expiry));
478 return 0;
479 }
480
481 if let Some(partition) = self.partitions.get_mut(&partition_key) {
482 let mut pruned = 0;
483
484 let record_keys = partition.records.keys().copied().collect::<Vec<K2>>();
485 let mut next_expiry = None;
486 for rkey in record_keys {
487 if let Some(tuples) = partition.records.get_mut(&rkey) {
488 let len = tuples.len();
489 tuples.retain(|(_, expiry)| expiry > &now);
490 pruned += len - tuples.len();
491 for (_, expiry) in tuples {
492 match next_expiry {
493 None => next_expiry = Some(*expiry),
494 Some(t) if *expiry < t => next_expiry = Some(*expiry),
495 _ => (),
496 }
497 }
498 }
499 }
500
501 partition.size -= pruned;
502
503 if let Some(ne) = next_expiry {
504 partition.next_expiry = ne;
505 self.expiry_priority.push(partition_key, Reverse(ne));
506 } else {
507 self.partitions.remove(&partition_key);
508 self.access_priority.remove(&partition_key);
509 }
510
511 self.current_size -= pruned;
512 pruned
513 } else {
514 self.access_priority.remove(&partition_key);
515 0
516 }
517 } else {
518 0
519 }
520 }
521
522 fn remove_least_recently_used(&mut self) -> usize {
527 if let Some((partition_key, _)) = self.access_priority.pop() {
528 self.expiry_priority.remove(&partition_key);
529
530 if let Some(partition) = self.partitions.remove(&partition_key) {
531 let pruned = partition.size;
532 self.current_size -= pruned;
533 pruned
534 } else {
535 0
536 }
537 } else {
538 0
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use dns_types::protocol::types::test_util::*;

    use super::test_util::*;
    use super::*;

    /// An arbitrary resource record, forced into the `IN` class (the only
    /// class the cache hands back).
    fn arbitrary_in_record() -> ResourceRecord {
        let mut rr = arbitrary_resourcerecord();
        rr.rclass = RecordClass::IN;
        rr
    }

    #[test]
    fn cache_put_can_get() {
        for _ in 0..100 {
            let mut cache = Cache::new();
            let rr = arbitrary_in_record();
            cache.insert(&rr);

            let by_type = cache.get_without_checking_expiration(
                &rr.name,
                QueryType::Record(rr.rtype_with_data.rtype()),
            );
            assert_cache_response(&rr, &by_type);

            let by_wildcard = cache.get_without_checking_expiration(&rr.name, QueryType::Wildcard);
            assert_cache_response(&rr, &by_wildcard);
        }
    }

    #[test]
    fn cache_put_deduplicates_and_maintains_invariants() {
        let mut cache = Cache::new();
        let rr = arbitrary_in_record();

        for _ in 0..2 {
            cache.insert(&rr);
        }

        assert_eq!(1, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_maintains_invariants() {
        let mut cache = Cache::new();

        for rr in (0..100).map(|_| arbitrary_in_record()) {
            cache.insert(&rr);
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_get_maintains_invariants() {
        let mut cache = Cache::new();
        let mut lookups = Vec::with_capacity(100);

        for _ in 0..100 {
            let rr = arbitrary_in_record();
            cache.insert(&rr);
            let qtype = QueryType::Record(rr.rtype_with_data.rtype());
            lookups.push((rr.name.clone(), qtype));
        }

        for (name, qtype) in lookups {
            cache.get_without_checking_expiration(&name, qtype);
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_prune_maintains_invariants() {
        let mut cache = Cache::with_desired_size(25);

        for _ in 0..100 {
            let mut rr = arbitrary_in_record();
            rr.ttl = 300;
            cache.insert(&rr);
        }

        let (overflow, current_size, expired, pruned) = cache.prune();

        assert!(overflow);
        assert_eq!(0, expired);
        assert!(pruned >= 75);
        assert!(cache.inner.current_size <= 25);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_expire_maintains_invariants() {
        let mut cache = Cache::new();

        for i in 0..100 {
            let mut rr = arbitrary_in_record();
            // Every even index except 0 is inserted already-expired.
            rr.ttl = if i == 0 || i % 2 != 0 { 300 } else { 0 };
            cache.insert(&rr);
        }

        assert_eq!(49, cache.inner.remove_expired());
        assert_eq!(51, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_prune_expires_all() {
        let mut cache = Cache::with_desired_size(99);

        for i in 0..100 {
            let mut rr = arbitrary_in_record();
            // Every even index except 0 is inserted already-expired.
            rr.ttl = if i == 0 || i % 2 != 0 { 300 } else { 0 };
            cache.insert(&rr);
        }

        let (overflow, current_size, expired, pruned) = cache.prune();

        assert!(overflow);
        assert_eq!(49, expired);
        assert_eq!(0, pruned);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    /// Check the cache's internal bookkeeping against its contents:
    /// - sizes add up, per partition and in total;
    /// - every record is filed under its own record type;
    /// - each partition's `next_expiry` is the minimum expiry it holds;
    /// - both priority queues hold exactly one correct entry per partition.
    fn assert_invariants(cache: &Cache) {
        let inner = &cache.inner;

        let total_size: usize = inner.partitions.values().map(|p| p.size).sum();
        assert_eq!(inner.current_size, total_size);

        assert_eq!(inner.partitions.len(), inner.access_priority.len());
        assert_eq!(inner.partitions.len(), inner.expiry_priority.len());

        let mut expected_access = PriorityQueue::new();
        let mut expected_expiry = PriorityQueue::new();

        for (name, partition) in &inner.partitions {
            let record_count: usize = partition.records.values().map(Vec::len).sum();
            assert_eq!(partition.size, record_count);

            for (rtype, tuples) in &partition.records {
                for (rtype_with_data, _) in tuples {
                    assert_eq!(*rtype, rtype_with_data.rtype());
                }
            }

            let min_expires = partition
                .records
                .values()
                .flatten()
                .map(|(_, expires)| *expires)
                .min();
            assert_eq!(Some(partition.next_expiry), min_expires);

            expected_access.push(name.clone(), Reverse(partition.last_read));
            expected_expiry.push(name.clone(), Reverse(partition.next_expiry));
        }

        assert_eq!(inner.access_priority, expected_access);
        assert_eq!(inner.expiry_priority, expected_expiry);
    }
}
730
#[cfg(test)]
#[allow(clippy::missing_panics_doc)]
pub mod test_util {
    use super::*;

    /// Assert that a cache lookup returned exactly one record, and that it
    /// matches `original`:
    /// - same name and record data;
    /// - class `IN`;
    /// - a TTL no larger than the original's (time may have passed since
    ///   insertion).
    pub fn assert_cache_response(original: &ResourceRecord, response: &[ResourceRecord]) {
        assert_eq!(1, response.len());
        let only = &response[0];

        assert_eq!(original.name, only.name);
        assert_eq!(original.rtype_with_data, only.rtype_with_data);
        assert_eq!(RecordClass::IN, only.rclass);
        assert!(original.ttl >= only.ttl);
    }
}