diff --git a/src/async_resolver.rs b/src/async_resolver.rs index 14d764e9..863a5410 100644 --- a/src/async_resolver.rs +++ b/src/async_resolver.rs @@ -241,7 +241,7 @@ impl AsyncResolver { // return the error, it should go to the next part of the code }; if let Some(cache_lookup) = cache.clone().get(query.clone()) { - let new_lookup_response = LookupResponse::new(cache_lookup.clone()); + let new_lookup_response = LookupResponse::new(cache_lookup.clone(), cache_lookup.to_bytes()); return Ok(new_lookup_response); } @@ -436,7 +436,8 @@ impl AsyncResolver { Err(_) => Err(ClientError::TemporaryError("no DNS message found")), }; - let dns_response = lookup_response.unwrap().to_dns_msg(); + let lookup_response = lookup_response.expect("error!"); + let dns_response = lookup_response.to_dns_msg(); let key_bytes = self.config.get_key(); let shared_key_name = self.config.get_key_name(); @@ -455,7 +456,7 @@ impl AsyncResolver { ); match rcode { - Rcode::NOERROR => Ok(LookupResponse::new(dns_response)), + Rcode::NOERROR => Ok(LookupResponse::new(dns_response, lookup_response.get_bytes())), Rcode::FORMERR => Err(ClientError::FormatError("The name server was unable to interpret the query."))?, Rcode::SERVFAIL => Err(ClientError::ServerFailure("The name server was unable to process this query due to a problem with the name server."))?, Rcode::NXDOMAIN => Err(ClientError::NameError("The domain name referenced in the query does not exist."))?, @@ -909,7 +910,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1138,7 +1139,7 @@ mod async_resolver_test { header.set_qr(true); header.set_rcode(Rcode::FORMERR); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1185,7 +1186,7 @@ mod async_resolver_test { header.set_qr(true); header.set_rcode(Rcode::SERVFAIL); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1231,7 +1232,7 @@ mod async_resolver_test { header.set_qr(true); header.set_rcode(Rcode::NXDOMAIN); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1278,7 +1279,7 @@ mod async_resolver_test { header.set_qr(true); header.set_rcode(Rcode::NOTIMP); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1325,7 +1326,7 @@ mod async_resolver_test { header.set_qr(true); header.set_rcode(Rcode::REFUSED); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1372,7 +1373,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1411,7 +1412,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_lookup = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_lookup { @@ -1450,7 +1451,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1489,7 +1490,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1528,7 +1529,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1567,7 +1568,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1606,7 +1607,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1645,7 +1646,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1684,7 +1685,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1723,7 +1724,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1762,7 +1763,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1801,7 +1802,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1840,7 +1841,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1879,7 +1880,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { @@ -1918,7 +1919,7 @@ mod async_resolver_test { let mut header = dns_response.get_header(); header.set_qr(true); dns_response.set_header(header); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response)); if let Ok(lookup_response) = result_vec_rr { diff --git a/src/async_resolver/lookup.rs b/src/async_resolver/lookup.rs index df3009e6..286ce579 100644 --- a/src/async_resolver/lookup.rs +++ b/src/async_resolver/lookup.rs @@ -26,7 +26,7 @@ pub struct LookupStrategy { /// Resolver configuration. config: ResolverConfig, /// Reference to the response of the query. - response_msg: Arc>>, + response_msg: Arc), ResolverError>>>, } impl LookupStrategy { @@ -115,7 +115,7 @@ impl LookupStrategy { // appropriate response to its caller. pub fn received_appropriate_response(&self) -> bool { let response_arc = self.response_msg.lock().unwrap(); - if let Ok(dns_msg) = response_arc.as_ref() { + if let Ok((dns_msg, bytes)) = response_arc.as_ref() { match dns_msg.get_header().get_rcode().into() { Rcode::SERVFAIL => return false, Rcode::NOTIMP => return false, @@ -191,7 +191,7 @@ impl LookupStrategy { ) -> Result { let response_arc= self.response_msg.clone(); let protocol = self.config.get_protocol(); - let mut dns_msg_result: Result; + let mut dns_msg_result: Result<(DnsMessage, Vec), ResolverError>; { // Guard reference to modify the response let mut response_guard = response_arc.lock().unwrap(); // TODO: add error handling @@ -205,12 +205,16 @@ impl LookupStrategy { .await .unwrap_or_else( |_| {Err(ResolverError::Message("Execute Strategy Timeout Error".into()))} - ); + ); *response_guard = dns_msg_result.clone(); + //*response_guard = dns_msg_result.clone(); } if self.received_appropriate_response() { return dns_msg_result.and_then( - |dns_msg| Ok(LookupResponse::new(dns_msg)) + |dns_msg| { + let (dns_msg, bytes) = dns_msg; + Ok(LookupResponse::new(dns_msg, bytes)) + } ) } if let ConnectionProtocol::UDP = protocol { @@ -231,7 +235,10 @@ impl LookupStrategy { *response_guard = dns_msg_result.clone(); } dns_msg_result.and_then( - |dns_msg| Ok(LookupResponse::new(dns_msg)) + |dns_msg| { + let (dns_msg, bytes) = dns_msg; + Ok(LookupResponse::new(dns_msg, bytes)) + } ) } } @@ -247,7 +254,7 @@ async fn send_query_by_protocol( query: &DnsMessage, protocol: ConnectionProtocol, server_info: &ServerInfo, -) -> Result { +) -> Result<(DnsMessage, Vec), ResolverError> { let query_id = query.get_query_id(); let dns_query = query.clone(); let dns_msg_result; @@ -257,16 +264,17 @@ async fn send_query_by_protocol( udp_connection.set_timeout(timeout); let response_result = udp_connection.send(dns_query).await; dns_msg_result = parse_response(response_result, query_id); + dns_msg_result } ConnectionProtocol::TCP => { let mut tcp_connection = server_info.get_tcp_connection().clone(); tcp_connection.set_timeout(timeout); let response_result = tcp_connection.send(dns_query).await; dns_msg_result = parse_response(response_result, query_id); + dns_msg_result } - _ => {dns_msg_result = Err(ResolverError::Message("Invalid Protocol".into()))}, // TODO: specific add error handling - }; - dns_msg_result + _ => Err(ResolverError::Message("Invalid Protocol".into())), // TODO: specific add error handling + } } /// Parse the received response datagram to a `DnsMessage`. @@ -288,17 +296,18 @@ async fn send_query_by_protocol( /// excessively long TTL, say greater than 1 week, either discard /// the whole response, or limit all TTLs in the response to 1 /// week. -fn parse_response(response_result: Result, ClientError>, query_id:u16) -> Result { - let dns_msg = response_result.map_err(Into::into) - .and_then(|response_message| { - DnsMessage::from_bytes(&response_message) - .map_err(|_| ResolverError::Parse("The name server was unable to interpret the query.".to_string())) - })?; +fn parse_response(response_result: Result, ClientError>, query_id:u16) -> Result<(DnsMessage, Vec), ResolverError> { + let response_msg = response_result.map_err(Into::::into)?; + + let dns_msg = DnsMessage::from_bytes(&response_msg). + map_err(|_| ResolverError::Parse("The name server was unable to interpret the query.".to_string()))?; + let header = dns_msg.get_header(); // check Header - header.format_check() - .map_err(|e| ResolverError::Parse(format!("Error formated Header: {}", e)))?; + header + .format_check() + .map_err(|e| ResolverError::Parse(format!("Error formated Header: {}", e)))?; // Check ID if dns_msg.get_query_id() != query_id { @@ -306,7 +315,7 @@ fn parse_response(response_result: Result, ClientError>, query_id:u16) - } if header.get_qr() { - return Ok(dns_msg); + return Ok((dns_msg, response_msg)); } Err(ResolverError::Parse("Message is a query. A response was expected.".to_string())) } @@ -555,7 +564,7 @@ mod async_resolver_test { let response_dns_msg = parse_response(response_result,query_id); println!("[###############] {:?}",response_dns_msg); assert!(response_dns_msg.is_ok()); - if let Ok(dns_msg) = response_dns_msg { + if let Ok((dns_msg,_)) = response_dns_msg { assert_eq!(dns_msg.get_header().get_qr(), true); // response (1) assert_eq!(dns_msg.get_header().get_ancount(), 1); assert_eq!(dns_msg.get_header().get_rcode(), Rcode::NOERROR); diff --git a/src/async_resolver/lookup_response.rs b/src/async_resolver/lookup_response.rs index f1ba54ec..47ec8497 100644 --- a/src/async_resolver/lookup_response.rs +++ b/src/async_resolver/lookup_response.rs @@ -12,12 +12,17 @@ use std::fmt; pub struct LookupResponse { // The DNS message response. dns_msg_response: DnsMessage, + bytes: Vec } impl LookupResponse { /// Create a new `LookupResponse` instance. - pub fn new(dns_msg_response: DnsMessage) -> LookupResponse { - LookupResponse { dns_msg_response } + pub fn new(dns_msg_response: DnsMessage, bytes: Vec) -> LookupResponse { + LookupResponse { dns_msg_response, bytes: bytes } + } + + pub fn get_bytes(&self) -> Vec { + self.bytes.clone() } /// Convert the response to a byte vector. @@ -82,7 +87,7 @@ mod lookup_response_tests { #[test] fn new_lookup_response() { let dns_response = DnsMessage::new(); - let lookup_response = LookupResponse::new(dns_response); + let lookup_response = LookupResponse::new(dns_response, vec![]); assert_eq!(lookup_response.to_string(), ""); } @@ -106,7 +111,7 @@ mod lookup_response_tests { dns_query_message.set_answer(answer); - let lookup_response = LookupResponse::new(dns_query_message); + let lookup_response = LookupResponse::new(dns_query_message, vec![]); println!("{}", lookup_response.to_string()); assert_eq!( @@ -207,7 +212,7 @@ mod lookup_response_tests { dns_msg.set_answer(answer); - let lookup_response = LookupResponse::new(dns_msg); + let lookup_response = LookupResponse::new(dns_msg, vec![]); let dns_from_lookup = lookup_response.to_dns_msg(); assert_eq!(dns_from_lookup.get_header().get_id(), 0b0010010010010101); assert_eq!(dns_from_lookup.get_header().get_qr(), true); @@ -262,7 +267,7 @@ mod lookup_response_tests { dns_msg.set_question(question); dns_msg.set_answer(answer); - let lookup_response = LookupResponse::new(dns_msg); + let lookup_response = LookupResponse::new(dns_msg, vec![]); let vec_of_rr = lookup_response.to_vec_of_rr(); assert_eq!(vec_of_rr[0].get_name().get_name(), "dcc.cl"); }