Skip to content

Commit

Permalink
add: tls_connection and fix tcp connection send function
Browse files Browse the repository at this point in the history
  • Loading branch information
FranciscaOrtegaG committed Aug 27, 2024
1 parent 2c1175b commit 6b0f988
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 69 deletions.
1 change: 1 addition & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod client_connection;
pub mod tcp_connection;
pub mod tls_connection;
pub mod udp_connection;
pub mod client_error;

Expand Down
120 changes: 51 additions & 69 deletions src/client/tcp_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,82 +52,64 @@ impl ClientConnection for ClientTCPConnection {

/// creates socket tcp, sends query and receive response
async fn send(self, dns_query: DnsMessage) -> Result<Vec<u8>, ClientError> {
// async fn send(self, dns_query: DnsMessage) -> Result<(Vec<u8>, IpAddr), ClientError> {

// async fn send(self, dns_query: DnsMessage) -> Result<(Vec<u8>, IpAddr), ClientError> {

let conn_timeout: Duration = self.get_timeout();
let bytes: Vec<u8> = dns_query.to_bytes();
let server_addr:SocketAddr = SocketAddr::new(self.get_server_addr(), 53);

// let mut stream: TcpStream = TcpStream::connect_timeout(&server_addr,timeout)?;
let conn_task = TcpStream::connect(&server_addr);
let mut stream: TcpStream = match timeout(conn_timeout, conn_task).await {
Ok(stream_result) => stream_result?,
Err(_) => return Err(ClientError::Io(IoError::new(ErrorKind::TimedOut, format!("Error: timeout"))).into()),
};

let bytes: Vec<u8> = dns_query.to_bytes();

// Add len of message len
let msg_length: u16 = bytes.len() as u16;
let tcp_bytes_length: [u8; 2] = [(msg_length >> 8) as u8, msg_length as u8];
let full_msg: Vec<u8> = [&tcp_bytes_length, bytes.as_slice()].concat();

// stream.set_read_timeout(Some(timeout))?; //-> Se hace con tokio

// stream.write(&full_msg)?;
stream.write(&full_msg).await?;

// Read response
let mut msg_size_response: [u8; 2] = [0; 2];

// Add len of message len
let msg_length: u16 = bytes.len() as u16;
let tcp_bytes_length: [u8; 2] = [(msg_length >> 8) as u8, msg_length as u8];
let full_msg: Vec<u8> = [&tcp_bytes_length, bytes.as_slice()].concat();
stream.read_exact(&mut msg_size_response).await?;

//get domain name
let domain_name = dns_query.get_question().get_qname().get_name();
let dns_name = DnsNameRef::try_from_ascii_str(&domain_name);
if dns_name.is_err() {
return Err(ClientError::Io(IoError::new(ErrorKind::InvalidInput, format!("Error: invalid domain name"))).into());
}

let root_store = rustls::RootCertStore::from_iter(
webpki_roots::TLS_SERVER_ROOTS
.iter()
.cloned(),
);
let config = rustls::ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();
let rc_config = Arc::new(config);


let dns_name = domain_name;
let server_name =ServerName::try_from(dns_name).expect("invalid DNS name");
let connector = rustls::ClientConnection::new(rc_config, server_name).unwrap();
let conn_timeout: Duration = self.get_timeout();
let server_addr:SocketAddr = SocketAddr::new(self.get_server_addr(), 853);

// let mut stream: TcpStream = TcpStream::connect_timeout(&server_addr,timeout)?;
let conn_task = TcpStream::connect(&server_addr);
let mut stream: TcpStream = TcpStream::connect(server_addr).await?;
// stream.set_read_timeout(Some(timeout))?; //-> Se hace con tokio

// stream.write(&full_msg)?;
stream.write(&full_msg).await?;
let tcp_msg_len: u16 = (msg_size_response[0] as u16) << 8 | msg_size_response[1] as u16;
let mut vec_msg: Vec<u8> = Vec::new();
let ip = self.get_server_addr();
let mut additionals = dns_query.get_additional();
let mut ar = ARdata::new();
ar.set_address(ip);
let a_rdata = Rdata::A(ar);
let rr = ResourceRecord::new(a_rdata);
additionals.push(rr);


// Read response
let mut msg_size_response: [u8; 2] = [0; 2];

stream.read_exact(&mut msg_size_response).await?;
while vec_msg.len() < tcp_msg_len as usize {
let mut msg = [0; 512];
let read_task = stream.read(&mut msg);
let number_of_bytes_msg_result = match timeout(conn_timeout, read_task).await {
Ok(n) => n,
Err(_) => return Err(ClientError::Io(IoError::new(ErrorKind::TimedOut, format!("Error: timeout"))).into()),
};

let tcp_msg_len: u16 = (msg_size_response[0] as u16) << 8 | msg_size_response[1] as u16;
let mut vec_msg: Vec<u8> = Vec::new();
let ip = self.get_server_addr();
let mut additionals = dns_query.get_additional();
let mut ar = ARdata::new();
ar.set_address(ip);
let a_rdata = Rdata::A(ar);
let rr = ResourceRecord::new(a_rdata);
additionals.push(rr);

let number_of_bytes_msg = match number_of_bytes_msg_result {
Ok(n) if n > 0 => n,
_ => return Err(IoError::new(ErrorKind::Other, format!("Error: no data received "))).map_err(Into::into),

};

while vec_msg.len() < tcp_msg_len as usize {
let mut msg = [0; 512];
let read_task = stream.read(&mut msg);
let number_of_bytes_msg_result = match timeout(conn_timeout, read_task).await {
Ok(n) => n,
Err(_) => return Err(ClientError::Io(IoError::new(ErrorKind::TimedOut, format!("Error: timeout"))).into()),
};

let number_of_bytes_msg = match number_of_bytes_msg_result {
Ok(n) if n > 0 => n,
_ => return Err(IoError::new(ErrorKind::Other, format!("Error: no data received "))).map_err(Into::into),

};

vec_msg.extend_from_slice(&msg[..number_of_bytes_msg]);
vec_msg.extend_from_slice(&msg[..number_of_bytes_msg]);
}

return Ok(vec_msg);
}

return Ok(vec_msg);
}
}

//Getters
Expand Down
Loading

0 comments on commit 6b0f988

Please sign in to comment.