1use priority_queue::PriorityQueue;
2use std::cmp::Eq;
3use std::cmp::Reverse;
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::marker::Copy;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10use dns_types::protocol::types::*;
11
12#[derive(Debug, Clone)]
18pub struct SharedCache {
19 cache: Arc<Mutex<Cache>>,
20}
21
22const MUTEX_POISON_MESSAGE: &str =
23 "[INTERNAL ERROR] cache mutex poisoned, cannot recover from this - aborting";
24
25impl SharedCache {
26 pub fn new() -> Self {
28 SharedCache {
29 cache: Arc::new(Mutex::new(Cache::new())),
30 }
31 }
32
33 pub fn with_desired_size(desired_size: usize) -> Self {
35 SharedCache {
36 cache: Arc::new(Mutex::new(Cache::with_desired_size(desired_size))),
37 }
38 }
39
40 pub fn get(&self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
50 self.cache
51 .lock()
52 .expect(MUTEX_POISON_MESSAGE)
53 .get(name, qtype)
54 }
55
56 pub fn get_without_checking_expiration(
65 &self,
66 name: &DomainName,
67 qtype: QueryType,
68 ) -> Vec<ResourceRecord> {
69 self.cache
70 .lock()
71 .expect(MUTEX_POISON_MESSAGE)
72 .get_without_checking_expiration(name, qtype)
73 }
74
75 pub fn insert(&self, record: &ResourceRecord) {
85 if record.ttl > 0 {
86 let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
87 cache.insert(record);
88 }
89 }
90
91 pub fn insert_all(&self, records: &[ResourceRecord]) {
104 let mut cache = self.cache.lock().expect(MUTEX_POISON_MESSAGE);
105 for record in records {
106 if record.ttl > 0 {
107 cache.insert(record);
108 }
109 }
110 }
111
112 pub fn prune(&self) -> (bool, usize, usize, usize) {
121 self.cache.lock().expect(MUTEX_POISON_MESSAGE).prune()
122 }
123}
124
125impl Default for SharedCache {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131#[derive(Debug, Clone)]
135pub struct Cache {
136 inner: PartitionedCache<DomainName, RecordType, RecordTypeWithData>,
137}
138
139impl Default for Cache {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl Cache {
146 pub fn new() -> Self {
148 Self {
149 inner: PartitionedCache::new(),
150 }
151 }
152
153 pub fn with_desired_size(desired_size: usize) -> Self {
158 Self {
159 inner: PartitionedCache::with_desired_size(desired_size),
160 }
161 }
162
163 pub fn get(&mut self, name: &DomainName, qtype: QueryType) -> Vec<ResourceRecord> {
169 let mut rrs = self.get_without_checking_expiration(name, qtype);
170 rrs.retain(|rr| rr.ttl > 0);
171 rrs
172 }
173
174 pub fn get_without_checking_expiration(
179 &mut self,
180 name: &DomainName,
181 qtype: QueryType,
182 ) -> Vec<ResourceRecord> {
183 let now = Instant::now();
184 let mut rrs = Vec::new();
185 match qtype {
186 QueryType::Wildcard => {
187 if let Some(records) = self.inner.get_partition_without_checking_expiration(name) {
188 for tuples in records.values() {
189 to_rrs(name, now, tuples, &mut rrs);
190 }
191 }
192 }
193 QueryType::Record(rtype) => {
194 if let Some(tuples) = self.inner.get_without_checking_expiration(name, &rtype) {
195 to_rrs(name, now, tuples, &mut rrs);
196 }
197 }
198 _ => (),
199 }
200
201 rrs
202 }
203
204 pub fn insert(&mut self, record: &ResourceRecord) {
206 self.inner.upsert(
207 record.name.clone(),
208 record.rtype_with_data.rtype(),
209 record.rtype_with_data.clone(),
210 Duration::from_secs(record.ttl.into()),
211 );
212 }
213
214 pub fn prune(&mut self) -> (bool, usize, usize, usize) {
219 self.inner.prune()
220 }
221}
222
223fn to_rrs(
226 name: &DomainName,
227 now: Instant,
228 tuples: &[(RecordTypeWithData, Instant)],
229 rrs: &mut Vec<ResourceRecord>,
230) {
231 for (rtype, expires) in tuples {
232 rrs.push(ResourceRecord {
233 name: name.clone(),
234 rtype_with_data: rtype.clone(),
235 rclass: RecordClass::IN,
236 ttl: expires
237 .saturating_duration_since(now)
238 .as_secs()
239 .try_into()
240 .unwrap_or(u32::MAX),
241 });
242 }
243}
244
245#[derive(Debug, Clone)]
246pub struct PartitionedCache<K1: Eq + Hash, K2: Eq + Hash, V> {
247 partitions: HashMap<K1, Partition<K2, V>>,
249
250 access_priority: PriorityQueue<K1, Reverse<Instant>>,
257
258 expiry_priority: PriorityQueue<K1, Reverse<Instant>>,
264
265 current_size: usize,
269
270 desired_size: usize,
272}
273
274#[derive(Debug, Clone, Eq, PartialEq)]
276struct Partition<K: Eq + Hash, V> {
277 last_read: Instant,
279
280 next_expiry: Instant,
284
285 size: usize,
289
290 records: HashMap<K, Vec<(V, Instant)>>,
292}
293
294impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> Default
295 for PartitionedCache<K1, K2, V>
296{
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
302impl<K1: Clone + Eq + Hash, K2: Copy + Eq + Hash, V: PartialEq> PartitionedCache<K1, K2, V> {
303 pub fn new() -> Self {
305 Self::with_desired_size(512)
306 }
307
308 pub fn with_desired_size(desired_size: usize) -> Self {
313 Self {
314 partitions: HashMap::with_capacity(desired_size / 2),
318 access_priority: PriorityQueue::with_capacity(desired_size),
319 expiry_priority: PriorityQueue::with_capacity(desired_size),
320 current_size: 0,
321 desired_size,
322 }
323 }
324
325 pub fn get_partition_without_checking_expiration(
330 &mut self,
331 partition_key: &K1,
332 ) -> Option<&HashMap<K2, Vec<(V, Instant)>>> {
333 if let Some(partition) = self.partitions.get_mut(partition_key) {
334 partition.last_read = Instant::now();
335 self.access_priority
336 .change_priority(partition_key, Reverse(partition.last_read));
337 return Some(&partition.records);
338 }
339
340 None
341 }
342
343 pub fn get_without_checking_expiration(
348 &mut self,
349 partition_key: &K1,
350 record_key: &K2,
351 ) -> Option<&[(V, Instant)]> {
352 if let Some(partition) = self.partitions.get_mut(partition_key) {
353 if let Some(tuples) = partition.records.get(record_key) {
354 partition.last_read = Instant::now();
355 self.access_priority
356 .change_priority(partition_key, Reverse(partition.last_read));
357 return Some(tuples);
358 }
359 }
360
361 None
362 }
363
364 pub fn upsert(&mut self, partition_key: K1, record_key: K2, value: V, ttl: Duration) {
367 let now = Instant::now();
368 let expiry = now + ttl;
369 let tuple = (value, expiry);
370 if let Some(partition) = self.partitions.get_mut(&partition_key) {
371 if let Some(tuples) = partition.records.get_mut(&record_key) {
372 let mut duplicate_expires_at = None;
373 for i in 0..tuples.len() {
374 let t = &tuples[i];
375 if t.0 == tuple.0 {
376 duplicate_expires_at = Some(t.1);
377 tuples.swap_remove(i);
378 break;
379 }
380 }
381
382 tuples.push(tuple);
383
384 if let Some(dup_expiry) = duplicate_expires_at {
385 partition.size -= 1;
386 self.current_size -= 1;
387
388 if dup_expiry == partition.next_expiry {
389 let mut new_next_expiry = expiry;
390 for (_, e) in tuples {
391 if *e < new_next_expiry {
392 new_next_expiry = *e;
393 }
394 }
395 partition.next_expiry = new_next_expiry;
396 self.expiry_priority
397 .change_priority(&partition_key, Reverse(partition.next_expiry));
398 }
399 }
400 } else {
401 partition.records.insert(record_key, vec![tuple]);
402 }
403 partition.last_read = now;
404 partition.size += 1;
405 self.access_priority
406 .change_priority(&partition_key, Reverse(partition.last_read));
407 if expiry < partition.next_expiry {
408 partition.next_expiry = expiry;
409 self.expiry_priority
410 .change_priority(&partition_key, Reverse(partition.next_expiry));
411 }
412 } else {
413 let mut records = HashMap::new();
414 records.insert(record_key, vec![tuple]);
415 let partition = Partition {
416 last_read: now,
417 next_expiry: expiry,
418 size: 1,
419 records,
420 };
421 self.access_priority
422 .push(partition_key.clone(), Reverse(partition.last_read));
423 self.expiry_priority
424 .push(partition_key.clone(), Reverse(partition.next_expiry));
425 self.partitions.insert(partition_key, partition);
426 }
427
428 self.current_size += 1;
429 }
430
431 pub fn remove_expired(&mut self) -> usize {
435 let mut pruned = 0;
436
437 loop {
438 let before = pruned;
439 pruned += self.remove_expired_step();
440 if before == pruned {
441 break;
442 }
443 }
444
445 pruned
446 }
447
448 pub fn prune(&mut self) -> (bool, usize, usize, usize) {
454 let has_overflowed = self.current_size > self.desired_size;
455 let num_expired = self.remove_expired();
456 let mut num_pruned = 0;
457
458 while self.current_size > self.desired_size {
459 num_pruned += self.remove_least_recently_used();
460 }
461
462 (has_overflowed, self.current_size, num_expired, num_pruned)
463 }
464
465 fn remove_expired_step(&mut self) -> usize {
471 if let Some((partition_key, Reverse(expiry))) = self.expiry_priority.pop() {
472 let now = Instant::now();
473
474 if expiry > now {
475 self.expiry_priority.push(partition_key, Reverse(expiry));
476 return 0;
477 }
478
479 if let Some(partition) = self.partitions.get_mut(&partition_key) {
480 let mut pruned = 0;
481
482 let record_keys = partition.records.keys().copied().collect::<Vec<K2>>();
483 let mut next_expiry = None;
484 for rkey in record_keys {
485 if let Some(tuples) = partition.records.get_mut(&rkey) {
486 let len = tuples.len();
487 tuples.retain(|(_, expiry)| expiry > &now);
488 pruned += len - tuples.len();
489 for (_, expiry) in tuples {
490 match next_expiry {
491 None => next_expiry = Some(*expiry),
492 Some(t) if *expiry < t => next_expiry = Some(*expiry),
493 _ => (),
494 }
495 }
496 }
497 }
498
499 partition.size -= pruned;
500
501 if let Some(ne) = next_expiry {
502 partition.next_expiry = ne;
503 self.expiry_priority.push(partition_key, Reverse(ne));
504 } else {
505 self.partitions.remove(&partition_key);
506 self.access_priority.remove(&partition_key);
507 }
508
509 self.current_size -= pruned;
510 pruned
511 } else {
512 self.access_priority.remove(&partition_key);
513 0
514 }
515 } else {
516 0
517 }
518 }
519
520 fn remove_least_recently_used(&mut self) -> usize {
525 if let Some((partition_key, _)) = self.access_priority.pop() {
526 self.expiry_priority.remove(&partition_key);
527
528 if let Some(partition) = self.partitions.remove(&partition_key) {
529 let pruned = partition.size;
530 self.current_size -= pruned;
531 pruned
532 } else {
533 0
534 }
535 } else {
536 0
537 }
538 }
539}
540
#[cfg(test)]
mod tests {
    use dns_types::protocol::types::test_util::*;

    use super::test_util::*;
    use super::*;

    /// An arbitrary record with its class forced to `IN`, the class that
    /// cache lookups report for every record.
    fn arbitrary_in_record() -> ResourceRecord {
        let mut rr = arbitrary_resourcerecord();
        rr.rclass = RecordClass::IN;
        rr
    }

    /// Like `arbitrary_in_record`, but with a fixed TTL.
    fn arbitrary_in_record_with_ttl(ttl: u32) -> ResourceRecord {
        let mut rr = arbitrary_in_record();
        rr.ttl = ttl;
        rr
    }

    /// Insert 100 arbitrary records. Every even-numbered record except the
    /// first gets a TTL of 0 (49 records in all); the rest get a TTL of 300.
    fn insert_half_expired(cache: &mut Cache) {
        for i in 0..100 {
            let ttl = if i != 0 && i % 2 == 0 { 0 } else { 300 };
            cache.insert(&arbitrary_in_record_with_ttl(ttl));
        }
    }

    #[test]
    fn cache_put_can_get() {
        for _ in 0..100 {
            let mut cache = Cache::new();
            let rr = arbitrary_in_record();
            cache.insert(&rr);

            let qtype = QueryType::Record(rr.rtype_with_data.rtype());
            let exact = cache.get_without_checking_expiration(&rr.name, qtype);
            assert_cache_response(&rr, &exact);

            let wildcard = cache.get_without_checking_expiration(&rr.name, QueryType::Wildcard);
            assert_cache_response(&rr, &wildcard);
        }
    }

    #[test]
    fn cache_put_deduplicates_and_maintains_invariants() {
        let mut cache = Cache::new();
        let rr = arbitrary_in_record();

        for _ in 0..2 {
            cache.insert(&rr);
        }

        assert_eq!(1, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_maintains_invariants() {
        let mut cache = Cache::new();

        for _ in 0..100 {
            cache.insert(&arbitrary_in_record());
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_get_maintains_invariants() {
        let mut cache = Cache::new();

        let queries: Vec<_> = (0..100)
            .map(|_| {
                let rr = arbitrary_in_record();
                cache.insert(&rr);
                (
                    rr.name.clone(),
                    QueryType::Record(rr.rtype_with_data.rtype()),
                )
            })
            .collect();

        for (name, qtype) in queries {
            cache.get_without_checking_expiration(&name, qtype);
        }

        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_prune_maintains_invariants() {
        let mut cache = Cache::with_desired_size(25);

        for _ in 0..100 {
            cache.insert(&arbitrary_in_record_with_ttl(300));
        }

        let (overflow, current_size, expired, pruned) = cache.prune();

        assert!(overflow);
        assert_eq!(0, expired);
        assert!(pruned >= 75);
        assert!(cache.inner.current_size <= 25);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_put_then_expire_maintains_invariants() {
        let mut cache = Cache::new();
        insert_half_expired(&mut cache);

        assert_eq!(49, cache.inner.remove_expired());
        assert_eq!(51, cache.inner.current_size);
        assert_invariants(&cache);
    }

    #[test]
    fn cache_prune_expires_all() {
        let mut cache = Cache::with_desired_size(99);
        insert_half_expired(&mut cache);

        let (overflow, current_size, expired, pruned) = cache.prune();

        assert!(overflow);
        assert_eq!(49, expired);
        assert_eq!(0, pruned);
        assert_eq!(cache.inner.current_size, current_size);
        assert_invariants(&cache);
    }

    /// Check the bookkeeping that every cache operation must keep
    /// consistent:
    /// - sizes agree with the stored records;
    /// - each record is filed under its own record type;
    /// - each partition's `next_expiry` is its earliest expiry;
    /// - both priority queues mirror the partitions exactly.
    fn assert_invariants(cache: &Cache) {
        let inner = &cache.inner;

        let total_size: usize = inner.partitions.values().map(|p| p.size).sum();
        assert_eq!(inner.current_size, total_size);

        assert_eq!(inner.partitions.len(), inner.access_priority.len());
        assert_eq!(inner.partitions.len(), inner.expiry_priority.len());

        let mut expected_access = PriorityQueue::new();
        let mut expected_expiry = PriorityQueue::new();

        for (name, partition) in &inner.partitions {
            let record_count: usize = partition.records.values().map(Vec::len).sum();
            assert_eq!(partition.size, record_count);

            for (rtype, tuples) in &partition.records {
                for (rtype_with_data, _) in tuples {
                    assert_eq!(*rtype, rtype_with_data.rtype());
                }
            }

            let earliest = partition
                .records
                .values()
                .flatten()
                .map(|(_, expires)| *expires)
                .min();
            assert_eq!(Some(partition.next_expiry), earliest);

            expected_access.push(name.clone(), Reverse(partition.last_read));
            expected_expiry.push(name.clone(), Reverse(partition.next_expiry));
        }

        assert_eq!(inner.access_priority, expected_access);
        assert_eq!(inner.expiry_priority, expected_expiry);
    }
}
728
#[cfg(test)]
#[allow(clippy::missing_panics_doc)]
pub mod test_util {
    use super::*;

    /// Assert that `response` is exactly one record matching `original` as
    /// the cache would return it:
    /// - same name and data;
    /// - class `IN`;
    /// - a TTL no greater than the original's, since the cache reports the
    ///   time remaining.
    pub fn assert_cache_response(original: &ResourceRecord, response: &[ResourceRecord]) {
        assert_eq!(1, response.len());
        let ResourceRecord {
            name,
            rtype_with_data,
            rclass,
            ttl,
        } = &response[0];

        assert_eq!(original.name, *name);
        assert_eq!(original.rtype_with_data, *rtype_with_data);
        assert_eq!(RecordClass::IN, *rclass);
        assert!(original.ttl >= *ttl);
    }
}
746}