diff --git a/zvt/Cargo.toml b/zvt/Cargo.toml index 1a605ec..bfb70fe 100644 --- a/zvt/Cargo.toml +++ b/zvt/Cargo.toml @@ -22,7 +22,7 @@ zvt_builder = { version = "0.1.0", path = "../zvt_builder" } log = "0.4.19" env_logger = "0.10.0" tokio-stream = "0.1.14" -tokio = { version = "1.29.1", features = ["net", "io-util", "rt-multi-thread", "macros"] } +tokio = { version = "1.29.1", features = ["net", "io-util", "rt-multi-thread", "macros", "fs"] } async-stream = "0.3.5" serde = { version = "1.0.185", features = ["derive"] } serde_json = "1.0.105" diff --git a/zvt/src/bin/status/main.rs b/zvt/src/bin/status/main.rs index ee8bff7..0afb6b3 100644 --- a/zvt/src/bin/status/main.rs +++ b/zvt/src/bin/status/main.rs @@ -51,13 +51,76 @@ fn init_logger() { .init(); } +use std::fs::File; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub struct LoggingTcpStream { + stream: TcpStream, + file: File, +} + +impl LoggingTcpStream { + pub async fn new(stream: TcpStream, file_path: &str) -> io::Result { + let file = File::create(file_path)?; + Ok(LoggingTcpStream { stream, file }) + } +} + +impl AsyncRead for LoggingTcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let self_mut = self.get_mut(); + let read = buf.filled().len(); + let poll = Pin::new(&mut self_mut.stream).poll_read(cx, buf); + if let Poll::Ready(Ok(_)) = poll { + let buf = &buf.filled()[read..]; + self_mut.file.write_all(buf)?; + } + poll + } +} + +impl AsyncWrite for LoggingTcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let self_mut = self.get_mut(); + let poll = Pin::new(&mut self_mut.stream).poll_write(cx, buf); + if let Poll::Ready(Ok(size)) = poll { + // If write is successful, write to file + let data = &buf[..size]; + self_mut.file.write_all(data)?; + } + poll + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_shutdown(cx) + } +} + #[tokio::main] async fn main() -> std::io::Result<()> { init_logger(); let args = Args::parse(); info!("Using the args {:?}", args); - let mut socket = TcpStream::connect(args.ip).await?; + let mut socket = LoggingTcpStream { + stream: TcpStream::connect(args.ip).await?, + file: File::create("/tmp/dump.txt")?, + }; let request = packets::Registration { password: args.password, diff --git a/zvt/src/feig/sequences.rs b/zvt/src/feig/sequences.rs index 232dd19..198ca69 100644 --- a/zvt/src/feig/sequences.rs +++ b/zvt/src/feig/sequences.rs @@ -1,5 +1,5 @@ -use crate::sequences::{read_packet_async, write_with_ack_async, Sequence}; -use crate::{packets, ZvtEnum, ZvtParser, ZvtSerializer}; +use crate::sequences::{read_packet_async, write_packet_async, write_with_ack_async, Sequence}; +use crate::{packets, ZvtEnum, ZvtParser }; use anyhow::Result; use async_stream::try_stream; use std::boxed::Box; @@ -146,14 +146,12 @@ impl WriteFile { match response { WriteFileResponse::CompletionData(_) => { - src.write_all(&packets::Ack {}.zvt_serialize()).await?; - + write_packet_async(&mut src, &packets::Ack {}).await?; yield response; break; } WriteFileResponse::Abort(_) => { - src.write_all(&packets::Ack {}.zvt_serialize()).await?; - + write_packet_async(&mut src, &packets::Ack {}).await?; yield response; break; } @@ -195,7 +193,7 @@ impl WriteFile { }), }), }; - src.write_all(&packet.zvt_serialize()).await?; + write_packet_async(&mut src, &packet).await?; yield response; } diff --git a/zvt/src/sequences.rs b/zvt/src/sequences.rs index 01670a7..8d24627 100644 --- a/zvt/src/sequences.rs +++ b/zvt/src/sequences.rs @@ -7,7 +7,7 @@ use log::debug; use std::boxed::Box; use std::marker::Unpin; use std::pin::Pin; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; pub async fn read_packet_async(src: &mut Pin<&mut impl AsyncReadExt>) -> Result> { let mut buf = vec![0; 3]; @@ -31,6 +31,18 @@ pub async fn read_packet_async(src: &mut Pin<&mut impl AsyncReadExt>) -> Result< Ok(buf.to_vec()) } +pub async fn write_packet_async( + drain: &mut Pin<&mut impl AsyncWriteExt>, + p: &T, +) -> io::Result<()> +where + T: ZvtSerializer + Sync + Send, + encoding::Default: encoding::Encoding, +{ + let bytes = p.zvt_serialize(); + drain.write_all(&bytes).await +} + #[derive(ZvtEnum)] enum Ack { Ack(packets::Ack), @@ -46,14 +58,29 @@ where { // We declare the bytes as a separate variable to help the compiler to // figure out that we can send stuff between threads. - let bytes = p.zvt_serialize(); - src.write_all(&bytes).await?; + write_packet_async(src, p).await?; let bytes = read_packet_async(src).await?; let _ = Ack::zvt_parse(&bytes)?; Ok(()) } +struct DataSource + where T: AsyncReadExt + AsyncWriteExt + Unpin { + s: T, + } + +// TODO(hrapp): I wanted to wrap the functions "write_packet_async" and "read_packet_async" into +// a struct that can take any `T` and has the more type "read_packet" and "write_packet". However I +// completely failed to do so because of pin!(). How do I do that?!? +impl DataSource + where T: AsyncReadExt + AsyncWriteExt + Unpin { + pub async fn read_packet_async(self: Pin<&mut Self>) -> Result> { + read_packet_async(self.s).await? + } + } + + /// The trait for converting a sequence into a stream. /// /// What is written below? The [Self::Input] type must be a command as defined @@ -91,7 +118,7 @@ where let bytes = read_packet_async(&mut src).await?; let packet = Self::Output::zvt_parse(&bytes)?; // Write the response. - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async::(&mut src, &packets::Ack {}).await?; yield packet; }; Box::pin(s) @@ -149,7 +176,7 @@ impl Sequence for ReadCard { let bytes = read_packet_async(&mut src).await?; let packet = ReadCardResponse::zvt_parse(&bytes)?; // Write the response. - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { ReadCardResponse::StatusInformation(_) | ReadCardResponse::Abort(_) => { @@ -206,7 +233,7 @@ impl Sequence for Initialization { let response = InitializationResponse::zvt_parse(&bytes)?; // Every message requires an Ack. - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match response { InitializationResponse::CompletionData(_) @@ -313,7 +340,7 @@ impl Sequence for Diagnosis { let response = DiagnosisResponse::zvt_parse(&bytes)?; // Every message requires an Ack. - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match response { DiagnosisResponse::CompletionData(_) @@ -380,7 +407,7 @@ impl Sequence for EndOfDay { let packet = EndOfDayResponse::zvt_parse(&bytes)?; // Write the response. - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { EndOfDayResponse::CompletionData(_) | EndOfDayResponse::Abort(_) => { yield packet; @@ -445,7 +472,7 @@ impl Sequence for Reservation { loop { let bytes = read_packet_async(&mut src).await?; let packet = AuthorizationResponse::zvt_parse(&bytes)?; - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { AuthorizationResponse::CompletionData(_) | AuthorizationResponse::Abort(_) => { yield packet; @@ -515,7 +542,7 @@ impl Sequence for PartialReversal { loop { let bytes = read_packet_async(&mut src).await?; let packet = PartialReversalResponse::zvt_parse(&bytes)?; - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { PartialReversalResponse::CompletionData(_) | PartialReversalResponse::PartialReversalAbort(_) => { @@ -555,7 +582,7 @@ impl Sequence for PreAuthReversal { loop { let bytes = read_packet_async(&mut src).await?; let packet = PartialReversalResponse::zvt_parse(&bytes)?; - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { PartialReversalResponse::CompletionData(_) | PartialReversalResponse::PartialReversalAbort(_) => { @@ -606,7 +633,7 @@ impl Sequence for PrintSystemConfiguration { loop { let bytes = read_packet_async(&mut src).await?; let packet = PrintSystemConfigurationResponse::zvt_parse(&bytes)?; - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { PrintSystemConfigurationResponse::CompletionData(_) => { yield packet; @@ -677,7 +704,7 @@ impl Sequence for StatusEnquiry { loop { let bytes = read_packet_async(&mut src).await?; let packet = StatusEnquiryResponse::zvt_parse(&bytes)?; - src.write_all(&packets::Ack {}.zvt_serialize()).await?; + write_packet_async(&mut src, &packets::Ack {}).await?; match packet { StatusEnquiryResponse::CompletionData(_) => { yield packet;