Skip to content

Commit

Permalink
Merge pull request #17 from niclabs/fix_bug_get_cache
Browse files Browse the repository at this point in the history
Fix bug get cache
  • Loading branch information
joalopez1206 authored Sep 5, 2024
2 parents 80d3bb3 + 14447f1 commit 13ca849
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
55 changes: 36 additions & 19 deletions src/dns_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl DnsCache {
}

/// Given a domain_name, gets an element from cache
pub fn get(&mut self, domain_name: DomainName, rrtype: Rrtype, rclass: Rclass) -> Option<Vec<RRStoredData>> {
pub fn get_primary(&mut self, domain_name: DomainName, rrtype: Rrtype, rclass: Rclass) -> Option<Vec<RRStoredData>> {
let mut cache = self.get_cache();

let rr_cache_vec = cache.get(&CacheKey::Primary(rrtype, rclass, domain_name)).cloned();
Expand All @@ -130,6 +130,23 @@ impl DnsCache {
rr_cache_vec
}

pub fn get_secondary(&mut self, domain_name: DomainName, rclass: Rclass)-> Option<Vec<RRStoredData>> {
let mut cache = self.get_cache();

let rr_cache_vec = cache.get(&&CacheKey::Secondary(rclass, domain_name)).cloned();

self.set_cache(cache);

rr_cache_vec
}

pub fn get(&mut self, domain_name: DomainName, rrtype: Option<Rrtype>, rclass: Rclass) -> Option<Vec<RRStoredData>> {
if rrtype != None {
return self.get_primary(domain_name, rrtype.unwrap(), rclass)
}
self.get_secondary(domain_name, rclass)
}

/// Removes the resource records from a domain name and type which were the oldest used
pub fn remove_oldest_used(&mut self) {
let mut cache = self.get_cache();
Expand All @@ -143,7 +160,7 @@ impl DnsCache {
pub fn get_response_time(
&mut self,
domain_name: DomainName,
rrtype: Rrtype,
rrtype: Option<Rrtype>,
rclass: Rclass,
ip_address: IpAddr,
) -> u32 {
Expand Down Expand Up @@ -347,7 +364,7 @@ mod dns_cache_test {

cache.add(domain_name.clone(), resource_record.clone(), Some(Rrtype::A), Rclass::IN, None);

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN).unwrap();
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN).unwrap();

let first_rr_cache = rr_cache_vec.first().unwrap();

Expand Down Expand Up @@ -382,7 +399,7 @@ mod dns_cache_test {

cache.add(domain_name.clone(), resource_record.clone(), Some(Rrtype::A), Rclass::IN, None);

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN).unwrap();
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN).unwrap();

assert_eq!(rr_cache_vec.len(), 2);
}
Expand Down Expand Up @@ -411,9 +428,9 @@ mod dns_cache_test {

cache.add(domain_name.clone(), resource_record_2.clone(), Some(Rrtype::AAAA), Rclass::IN, None);

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN).unwrap();
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN).unwrap();

let rr_cache_vec_2 = cache.get(domain_name.clone(), Rrtype::AAAA, Rclass::IN).unwrap();
let rr_cache_vec_2 = cache.get(domain_name.clone(), Some(Rrtype::AAAA), Rclass::IN).unwrap();

assert_eq!(rr_cache_vec.len(), 1);
assert_eq!(rr_cache_vec_2.len(), 1);
Expand Down Expand Up @@ -443,7 +460,7 @@ mod dns_cache_test {

cache.add(domain_name.clone(), resource_record.clone(), Some(Rrtype::A), Rclass::IN, None);

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN).unwrap();
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN).unwrap();

assert_eq!(rr_cache_vec.len(), 1);
}
Expand All @@ -464,7 +481,7 @@ mod dns_cache_test {

cache.remove(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec.is_none());
}
Expand All @@ -483,7 +500,7 @@ mod dns_cache_test {

cache.add(domain_name.clone(), resource_record.clone(), Some(Rrtype::A), Rclass::IN, None);

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN).unwrap();
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN).unwrap();

let first_rr_cache = rr_cache_vec.first().unwrap();

Expand All @@ -508,7 +525,7 @@ mod dns_cache_test {
let mut cache = DnsCache::new(NonZeroUsize::new(10));
let domain_name = DomainName::new_from_str("example.com");

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec.is_none());
}
Expand Down Expand Up @@ -547,21 +564,21 @@ mod dns_cache_test {
cache.add(domain_name_2.clone(), resource_record_2.clone(), Some(Rrtype::A), Rclass::IN, None);
cache.add(domain_name_3.clone(), resource_record_3.clone(), Some(Rrtype::A), Rclass::IN, None);

let _rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN);
let _rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

let _rr_cache_vec_2 = cache.get(domain_name_2.clone(), Rrtype::A, Rclass::IN);
let _rr_cache_vec_2 = cache.get(domain_name_2.clone(), Some(Rrtype::A), Rclass::IN);

cache.remove_oldest_used();

let rr_cache_vec = cache.get(domain_name_3.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec = cache.get(domain_name_3.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec.is_none());

let rr_cache_vec_2 = cache.get(domain_name_2.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec_2 = cache.get(domain_name_2.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec_2.is_some());

let rr_cache_vec_3 = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec_3 = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec_3.is_some());
}
Expand Down Expand Up @@ -590,7 +607,7 @@ mod dns_cache_test {

cache.set_cache(lru_cache);

let response_time_obtained = cache.get_response_time(domain_name.clone(), Rrtype::A, Rclass::IN, ip_address);
let response_time_obtained = cache.get_response_time(domain_name.clone(), Some(Rrtype::A), Rclass::IN, ip_address);

assert_eq!(response_time_obtained, response_time);
}
Expand All @@ -613,7 +630,7 @@ mod dns_cache_test {

cache.update_response_time(domain_name.clone(), Rrtype::A, Rclass::IN, new_response_time, ip_address);

let response_time_obtained = cache.get_response_time(domain_name.clone(), Rrtype::A, Rclass::IN, ip_address);
let response_time_obtained = cache.get_response_time(domain_name.clone(), Some(Rrtype::A), Rclass::IN, ip_address);

assert_eq!(response_time_obtained, new_response_time);
}
Expand Down Expand Up @@ -718,11 +735,11 @@ mod dns_cache_test {

assert!(!cache.is_empty());

let rr_cache_vec = cache.get(domain_name.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec = cache.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec.is_none());

let rr_cache_vec_2 = cache.get(domain_name_2.clone(), Rrtype::A, Rclass::IN);
let rr_cache_vec_2 = cache.get(domain_name_2.clone(), Some(Rrtype::A), Rclass::IN);

assert!(rr_cache_vec_2.is_some());
}
Expand Down
26 changes: 13 additions & 13 deletions src/resolver_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl ResolverCache {
qtype: Rrtype,
qclass: Rclass,
) -> Option<Vec<ResourceRecord>> {
let rr_stored_data = self.cache_answer.get(domain_name, qtype, qclass);
let rr_stored_data = self.cache_answer.get(domain_name, Some(qtype), qclass);

if let Some(rr_stored_data) = rr_stored_data {
let mut rr_vec = Vec::new();
Expand All @@ -199,7 +199,7 @@ impl ResolverCache {
qtype: Rrtype,
qclass: Rclass,
) -> Option<Vec<ResourceRecord>> {
let rr_stored_data = self.cache_authority.get(domain_name, qtype, qclass);
let rr_stored_data = self.cache_authority.get(domain_name, Some(qtype), qclass);

if let Some(rr_stored_data) = rr_stored_data {
let mut rr_vec = Vec::new();
Expand All @@ -219,7 +219,7 @@ impl ResolverCache {
qtype: Rrtype,
qclass: Rclass,
) -> Option<Vec<ResourceRecord>> {
let rr_stored_data = self.cache_additional.get(domain_name, qtype, qclass);
let rr_stored_data = self.cache_additional.get(domain_name, Some(qtype), qclass);

if let Some(rr_stored_data) = rr_stored_data {
let mut rr_vec = Vec::new();
Expand All @@ -238,7 +238,7 @@ impl ResolverCache {
qtype: Rrtype,
qclass: Rclass,
) -> Option<Rcode> {
let rr_stored_data = self.cache_answer.get(domain_name, qtype, qclass);
let rr_stored_data = self.cache_answer.get(domain_name, Some(qtype), qclass);

if let Some(rr_stored_data) = rr_stored_data {
Some(rr_stored_data[0].get_rcode())
Expand Down Expand Up @@ -515,7 +515,7 @@ mod resolver_cache_test {

let rr = resolver_cache
.cache_answer
.get(domain_name.clone(), Rrtype::A, Rclass::IN)
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN)
.unwrap();

assert_eq!(rr[0].get_resource_record(), resource_record);
Expand Down Expand Up @@ -547,7 +547,7 @@ mod resolver_cache_test {

let rr = resolver_cache
.cache_authority
.get(domain_name.clone(), Rrtype::A, Rclass::IN)
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN)
.unwrap();

assert_eq!(rr[0].get_resource_record(), resource_record);
Expand Down Expand Up @@ -579,7 +579,7 @@ mod resolver_cache_test {

let rr = resolver_cache
.cache_additional
.get(domain_name.clone(), Rrtype::A, Rclass::IN)
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN)
.unwrap();

assert_eq!(rr[0].get_resource_record(), resource_record);
Expand Down Expand Up @@ -647,15 +647,15 @@ mod resolver_cache_test {

let rr_answer = resolver_cache
.cache_answer
.get(domain_name.clone(), Rrtype::A, Rclass::IN)
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN)
.unwrap();
let rr_authority = resolver_cache
.cache_authority
.get(domain_name.clone(), Rrtype::A, Rclass::IN)
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN)
.unwrap();
let rr_additional = resolver_cache
.cache_additional
.get(domain_name.clone(), Rrtype::A, Rclass::IN)
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN)
.unwrap();

assert_eq!(rr_answer[0].get_resource_record(), resource_record_1);
Expand Down Expand Up @@ -1320,15 +1320,15 @@ mod resolver_cache_test {

let rr_answer = resolver_cache
.cache_answer
.get(domain_name.clone(), Rrtype::A, Rclass::IN);
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);
let rr_authority =
resolver_cache
.cache_authority
.get(domain_name.clone(), Rrtype::A, Rclass::IN);
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);
let rr_additional =
resolver_cache
.cache_additional
.get(domain_name.clone(), Rrtype::A, Rclass::IN);
.get(domain_name.clone(), Some(Rrtype::A), Rclass::IN);

assert_eq!(rr_answer, None);
assert_eq!(rr_authority, None);
Expand Down

0 comments on commit 13ca849

Please sign in to comment.