Skip to content

Commit

Permalink
refactor DNS cache implementation to use CacheKey enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Litr0 committed Jun 11, 2024
1 parent edef97f commit 4de52f3
Showing 1 changed file with 44 additions and 126 deletions.
170 changes: 44 additions & 126 deletions src/async_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::net::IpAddr;
use std::vec;
use std::sync::{Arc, Mutex};
use crate::client::client_error::ClientError;
use crate::dns_cache::DnsCache;
use crate::resolver_cache::ResolverCache;
use crate::domain_name::DomainName;
use crate::message::rcode::Rcode;
use crate::message::{self, DnsMessage};
Expand Down Expand Up @@ -39,12 +39,8 @@ use self::lookup_response::LookupResponse;
/// `lookup_ip` method.
#[derive(Clone)]
pub struct AsyncResolver {
/// Cache for the answer section of the DNS messages.
cache_answer: Arc<Mutex<DnsCache>>,
/// Cache for the additional section of the DNS messages.
cache_additional: Arc<Mutex<DnsCache>>,
/// Cache for the authority section of the DNS messages.
cache_authority: Arc<Mutex<DnsCache>>,
/// Cache for the resolver
cache: Arc<Mutex<ResolverCache>>,
/// Configuration for the resolver.
config: ResolverConfig ,
}
Expand All @@ -65,9 +61,7 @@ impl AsyncResolver {
/// ```
pub fn new(config: ResolverConfig)-> Self {
let async_resolver = AsyncResolver {
cache_answer: Arc::new(Mutex::new(DnsCache::new(None))),
cache_additional: Arc::new(Mutex::new(DnsCache::new(None))),
cache_authority: Arc::new(Mutex::new(DnsCache::new(None))),
cache: Arc::new(Mutex::new(ResolverCache::new(None))),
config: config,
};
async_resolver
Expand Down Expand Up @@ -224,33 +218,15 @@ impl AsyncResolver {
// Cache lookup
// Search in cache only if its available
if self.config.is_cache_enabled() {
let lock_result = self.cache_answer.lock();
let lock_result = self.cache.lock();
let cache = match lock_result {
Ok(val) => val,
Err(_) => Err(ClientError::Message("Error getting cache"))?, // FIXME: it shouldn't
// return the error, it shoul go to the next part of the code
// return the error, it should go to the next part of the code
};
if let Some(cache_lookup) = cache.clone().get(domain_name.clone(), qtype, qclass) {
let mut new_query = query.clone();

// Get RR from cache
for rr_cache_value in cache_lookup.iter() {
let rr = rr_cache_value.get_resource_record();

// Get negative answer
if u16::from(qtype) != u16::from(rr.get_rtype()) {
let additionals: Vec<ResourceRecord> = vec![rr];
new_query.add_additionals(additionals);
let mut new_header = new_query.get_header();
new_header.set_rcode(Rcode::NXDOMAIN); // TODO: is here where the other problem originates?
new_query.set_header(new_header);
}
else { //FIXME: change to alg RFC 1034-1035
let answer: Vec<ResourceRecord> = vec![rr];
new_query.set_answer(answer);
}
}
let new_lookup_response = LookupResponse::new(new_query);
if let Some(cache_lookup) = cache.clone().get(query.clone()) {

let new_lookup_response = LookupResponse::new(cache_lookup.clone());

return Ok(new_lookup_response)
}
Expand Down Expand Up @@ -335,59 +311,13 @@ impl AsyncResolver {
let truncated = response.get_header().get_tc();
let rcode = response.get_header().get_rcode();
{
let mut cache_answer = self.cache_answer.lock().unwrap();
cache_answer.timeout_cache();
if !truncated {
// TODO: RFC 1035: 7.4. Using the cache
response.get_answer()
.iter()
.for_each(|rr| {
if rr.get_ttl() > 0 {
cache_answer.add(rr.get_name(),
rr.clone(),
Some(response.get_question().get_qtype()),
response.get_question().get_qclass(),
Some(rcode));
}
});

}
let mut cache_additional = self.cache_additional.lock().unwrap();
cache_additional.timeout_cache();
let mut cache = self.cache.lock().unwrap();
cache.timeout();
if !truncated {
response.get_additional()
.iter()
.for_each(|rr| {
if rr.get_ttl() > 0 {
let rtype = rr.get_rtype();
// Do not cache OPT records
if rtype != Rtype::OPT{
cache_additional.add(rr.get_name(),
rr.clone(),
Some(response.get_question().get_qtype()),
response.get_question().get_qclass(),
Some(rcode));
}
}
});
}
let mut cache_authority = self.cache_authority.lock().unwrap();
cache_authority.timeout_cache();
if !truncated {
response.get_authority()
.iter()
.for_each(|rr| {
if rr.get_ttl() > 0 {
cache_authority.add(rr.get_name(),
rr.clone(),
Some(response.get_question().get_qtype()),
response.get_question().get_qclass(),
Some(rcode));
}
});
}
cache.add(response.clone());
}
self.save_negative_answers(response);
}
}

/// Stores the data of negative answers in the cache.
Expand Down Expand Up @@ -424,19 +354,19 @@ impl AsyncResolver {
let qname = response.get_question().get_qname();
let qtype = response.get_question().get_qtype();
let qclass = response.get_question().get_qclass();
let rcode = response.get_header().get_rcode();
let additionals = response.get_additional();
let answer = response.get_answer();
let aa = response.get_header().get_aa();

// If not existence RR for query, add SOA to cache
let mut cache = self.cache_answer.lock().unwrap(); // FIXME: que la función entregue result
let mut cache = self.cache.lock().unwrap(); // FIXME: que la función entregue result
if additionals.len() > 0 && answer.len() == 0 && aa == true{
additionals.iter()
.for_each(|rr| {
if rr.get_rtype() == Rtype::SOA {
cache.add_negative_answer(qname.clone(),qtype , qclass, rr.clone());
for additional in additionals {
if additional.get_rtype() == Rtype::SOA {
cache.add_additional(qname.clone(), additional, Some(qtype), qclass, Some(rcode));
}
});
}
}

}
Expand Down Expand Up @@ -480,20 +410,8 @@ impl AsyncResolver {
// Getters
impl AsyncResolver {
// Gets the cache from the struct
pub fn get_cache_answer(&self) -> DnsCache {
let cache = self.cache_answer.lock().unwrap(); // FIXME: ver que hacer con el error
return cache.clone();
}

// Gets the cache_additional from the struct
pub fn get_cache_additional(&self) -> DnsCache {
let cache = self.cache_additional.lock().unwrap();
return cache.clone();
}

// Gets the cache_authority from the struct
pub fn get_cache_authority(&self) -> DnsCache {
let cache = self.cache_authority.lock().unwrap();
pub fn get_cache(&self) -> ResolverCache {
let cache = self.cache.lock().unwrap(); // FIXME: ver que hacer con el error
return cache.clone();
}
}
Expand Down Expand Up @@ -920,13 +838,13 @@ mod async_resolver_test {
#[tokio::test]
async fn inner_lookup_cache_available() {
let resolver = AsyncResolver::new(ResolverConfig::default());
resolver.cache_answer.lock().unwrap().set_max_size(NonZeroUsize::new(1).unwrap());
resolver.cache.lock().unwrap().set_max_size(NonZeroUsize::new(1).unwrap());

let domain_name = DomainName::new_from_string("example.com".to_string());
let a_rdata = ARdata::new_from_addr(IpAddr::from_str("93.184.216.34").unwrap());
let a_rdata = Rdata::A(a_rdata);
let resource_record = ResourceRecord::new(a_rdata);
resolver.cache_answer.lock().unwrap().add(domain_name, resource_record, Some(Qtype::A), Qclass::IN, None);
resolver.cache.lock().unwrap().add_answer(domain_name, resource_record, Some(Qtype::A), Qclass::IN, None);

let domain_name = DomainName::new_from_string("example.com".to_string());
let response = resolver.inner_lookup(domain_name, Qtype::A, Qclass::IN).await;
Expand All @@ -946,14 +864,14 @@ mod async_resolver_test {

let resolver = AsyncResolver::new(config);
{
let mut cache = resolver.cache_answer.lock().unwrap();
let mut cache = resolver.cache.lock().unwrap();
cache.set_max_size(NonZeroUsize::new(1).unwrap());

let domain_name = DomainName::new_from_string("example.com".to_string());
let a_rdata = ARdata::new_from_addr(IpAddr::from_str("93.184.216.34").unwrap());
let a_rdata = Rdata::A(a_rdata);
let resource_record = ResourceRecord::new(a_rdata);
cache.add(domain_name, resource_record, Some(Qtype::A), Qclass::IN, None);
cache.add_answer(domain_name, resource_record, Some(Qtype::A), Qclass::IN, None);
}

let domain_name = DomainName::new_from_string("example.com".to_string());
Expand All @@ -970,11 +888,11 @@ mod async_resolver_test {
#[tokio::test]
async fn cache_data() {
let mut resolver = AsyncResolver::new(ResolverConfig::default());
resolver.cache_answer.lock().unwrap().set_max_size(NonZeroUsize::new(1).unwrap());
assert_eq!(resolver.cache_answer.lock().unwrap().is_empty(), true);
resolver.cache.lock().unwrap().set_max_size(NonZeroUsize::new(1).unwrap());
assert_eq!(resolver.cache.lock().unwrap().is_empty(), true);

let _response = resolver.lookup("example.com", "UDP", "A","IN").await;
assert_eq!(resolver.cache_answer.lock().unwrap().is_cached(CacheKey::Primary(Qtype::A, Qclass::IN, DomainName::new_from_str("example.com"))), true);
assert_eq!(resolver.cache.lock().unwrap().is_cached(CacheKey::Primary(Qtype::A, Qclass::IN, DomainName::new_from_str("example.com"))), true);
// TODO: Test special cases from RFC
}

Expand Down Expand Up @@ -1913,7 +1831,7 @@ mod async_resolver_test {
fn not_store_data_in_cache_if_truncated() {
let resolver = AsyncResolver::new(ResolverConfig::default());

resolver.cache_answer.lock().unwrap().set_max_size(NonZeroUsize::new(10).unwrap());
resolver.cache.lock().unwrap().set_max_size(NonZeroUsize::new(10).unwrap());


let domain_name = DomainName::new_from_string("example.com".to_string());
Expand All @@ -1932,14 +1850,14 @@ mod async_resolver_test {
dns_response.set_header(truncated_header);

resolver.store_data_cache(dns_response);

assert_eq!(resolver.get_cache_answer().get_cache().len(), 0);
assert_eq!(resolver.cache.lock().unwrap().get_cache_answer().get_cache().len(), 0);
}

#[test]
fn not_store_cero_ttl_data_in_cache() {
let resolver = AsyncResolver::new(ResolverConfig::default());
resolver.cache_answer.lock().unwrap().set_max_size(NonZeroUsize::new(10).unwrap());
resolver.cache.lock().unwrap().set_max_size(NonZeroUsize::new(10).unwrap());

let domain_name = DomainName::new_from_string("example.com".to_string());

Expand Down Expand Up @@ -1975,16 +1893,16 @@ mod async_resolver_test {

dns_response.set_answer(answer);
assert_eq!(dns_response.get_answer().len(), 3);
assert_eq!(resolver.get_cache_answer().get_cache().len(), 0);
assert_eq!(resolver.cache.lock().unwrap().get_cache_answer().get_cache().len(), 0);

resolver.store_data_cache(dns_response);
assert_eq!(resolver.get_cache_answer().get_cache().len(), 2);
assert_eq!(resolver.cache.lock().unwrap().get_cache_answer().get_cache().len(), 2);
}

#[test]
fn save_cache_negative_answer(){
let resolver = AsyncResolver::new(ResolverConfig::default());
resolver.cache_answer.lock().unwrap().set_max_size(NonZeroUsize::new(1).unwrap());
resolver.cache.lock().unwrap().set_max_size(NonZeroUsize::new(1).unwrap());

let domain_name = DomainName::new_from_string("banana.exaple".to_string());
let mname = DomainName::new_from_string("a.root-servers.net.".to_string());
Expand Down Expand Up @@ -2031,16 +1949,16 @@ mod async_resolver_test {
let qtype_search = Qtype::A;
assert_eq!(dns_response.get_answer().len(), 0);
assert_eq!(dns_response.get_additional().len(), 1);
assert_eq!(resolver.get_cache_answer().get_cache().len(), 1);
assert!(resolver.get_cache_answer().get(dns_response.get_question().get_qname().clone(), qtype_search, Qclass::IN).is_some())
assert_eq!(resolver.cache.lock().unwrap().get_cache_answer().get_cache().len(), 1);
// assert!(resolver.cache.lock().unwrap().get_cache_answer().get(dns_response.get_question().get_qname().clone(), qtype_search, Qclass::IN).is_some())

}

#[ignore = "Optional, not implemented"]
/* #[ignore = "Optional, not implemented"]
#[tokio::test]
async fn inner_lookup_negative_answer_in_cache(){
let resolver = AsyncResolver::new(ResolverConfig::default());
let mut cache = resolver.get_cache_answer();
let mut cache = resolver.cache.lock().unwrap().get_cache_answer();
let qtype = Qtype::A;
cache.set_max_size(NonZeroUsize::new(9).unwrap());
Expand Down Expand Up @@ -2069,18 +1987,18 @@ mod async_resolver_test {
rr.set_name(domain_name.clone());
// Add negative answer to cache
let mut cache = resolver.get_cache_answer();
let mut cache = resolver.cache.lock().unwrap().get_cache_answer();
cache.set_max_size(NonZeroUsize::new(9).unwrap());
cache.add_negative_answer(domain_name.clone(),qtype ,Qclass::IN, rr.clone());
let mut cache_guard = resolver.cache_answer.lock().unwrap();
let mut cache_guard = resolver.cache.lock().unwrap().get_cache_answer();
*cache_guard = cache;
assert_eq!(resolver.get_cache_answer().get_cache().len(), 1);
assert_eq!(resolver.cache.lock().unwrap().get_cache_answer().get_cache().len(), 1);
let qclass = Qclass::IN;
let response = resolver.inner_lookup(domain_name,qtype,qclass).await.unwrap();
assert_eq!(resolver.get_cache_answer().get_cache().len(), 1);
assert_eq!(resolver.cache.lock().unwrap().get_cache_answer().get_cache().len(), 1);
assert_eq!(response.to_dns_msg().get_answer().len(), 0);
assert_eq!(response
.to_dns_msg()
Expand All @@ -2090,7 +2008,7 @@ mod async_resolver_test {
.to_dns_msg()
.get_header()
.get_rcode(), Rcode::NXDOMAIN);
}
} */

// TODO: Finish tests, it shoudl verify that we can send several asynchroneous queries concurrently
#[tokio::test]
Expand Down

0 comments on commit 4de52f3

Please sign in to comment.