Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Try encapsulating writes #5

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Failed attempt.
  • Loading branch information
SirVer committed Dec 6, 2023
commit 5253bf9b82a9cc72bb8e6fc3cd586a47b9fdb93d
2 changes: 1 addition & 1 deletion zvt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ zvt_builder = { version = "0.1.0", path = "../zvt_builder" }
log = "0.4.19"
env_logger = "0.10.0"
tokio-stream = "0.1.14"
tokio = { version = "1.29.1", features = ["net", "io-util", "rt-multi-thread", "macros"] }
tokio = { version = "1.29.1", features = ["net", "io-util", "rt-multi-thread", "macros", "fs"] }
async-stream = "0.3.5"
serde = { version = "1.0.185", features = ["derive"] }
serde_json = "1.0.105"
Expand Down
65 changes: 64 additions & 1 deletion zvt/src/bin/status/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,76 @@ fn init_logger() {
.init();
}

use std::fs::File;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub struct LoggingTcpStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the whole idea was to support not only tcp but also other ways of communication just as the python code did. Hardcoding tcp here is a regression IMHO

stream: TcpStream,
file: File,
}

impl LoggingTcpStream {
pub async fn new(stream: TcpStream, file_path: &str) -> io::Result<Self> {
let file = File::create(file_path)?;
Ok(LoggingTcpStream { stream, file })
}
}

impl AsyncRead for LoggingTcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let self_mut = self.get_mut();
let read = buf.filled().len();
let poll = Pin::new(&mut self_mut.stream).poll_read(cx, buf);
if let Poll::Ready(Ok(_)) = poll {
let buf = &buf.filled()[read..];
self_mut.file.write_all(buf)?;
}
poll
}
}

impl AsyncWrite for LoggingTcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let self_mut = self.get_mut();
let poll = Pin::new(&mut self_mut.stream).poll_write(cx, buf);
if let Poll::Ready(Ok(size)) = poll {
// If write is successful, write to file
let data = &buf[..size];
self_mut.file.write_all(data)?;
}
poll
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().stream).poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
}
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
init_logger();
let args = Args::parse();

info!("Using the args {:?}", args);
let mut socket = TcpStream::connect(args.ip).await?;
let mut socket = LoggingTcpStream {
stream: TcpStream::connect(args.ip).await?,
file: File::create("/tmp/dump.txt")?,
};

let request = packets::Registration {
password: args.password,
Expand Down
12 changes: 5 additions & 7 deletions zvt/src/feig/sequences.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::sequences::{read_packet_async, write_with_ack_async, Sequence};
use crate::{packets, ZvtEnum, ZvtParser, ZvtSerializer};
use crate::sequences::{read_packet_async, write_packet_async, write_with_ack_async, Sequence};
use crate::{packets, ZvtEnum, ZvtParser };
use anyhow::Result;
use async_stream::try_stream;
use std::boxed::Box;
Expand Down Expand Up @@ -146,14 +146,12 @@ impl WriteFile {

match response {
WriteFileResponse::CompletionData(_) => {
src.write_all(&packets::Ack {}.zvt_serialize()).await?;

write_packet_async(&mut src, &packets::Ack {}).await?;
yield response;
break;
}
WriteFileResponse::Abort(_) => {
src.write_all(&packets::Ack {}.zvt_serialize()).await?;

write_packet_async(&mut src, &packets::Ack {}).await?;
yield response;
break;
}
Expand Down Expand Up @@ -195,7 +193,7 @@ impl WriteFile {
}),
}),
};
src.write_all(&packet.zvt_serialize()).await?;
write_packet_async(&mut src, &packet).await?;

yield response;
}
Expand Down
53 changes: 40 additions & 13 deletions zvt/src/sequences.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use log::debug;
use std::boxed::Box;
use std::marker::Unpin;
use std::pin::Pin;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};

pub async fn read_packet_async(src: &mut Pin<&mut impl AsyncReadExt>) -> Result<Vec<u8>> {
let mut buf = vec![0; 3];
Expand All @@ -31,6 +31,18 @@ pub async fn read_packet_async(src: &mut Pin<&mut impl AsyncReadExt>) -> Result<
Ok(buf.to_vec())
}

pub async fn write_packet_async<T>(
drain: &mut Pin<&mut impl AsyncWriteExt>,
p: &T,
) -> io::Result<()>
where
T: ZvtSerializer + Sync + Send,
encoding::Default: encoding::Encoding<T>,
{
let bytes = p.zvt_serialize();
drain.write_all(&bytes).await
}

#[derive(ZvtEnum)]
enum Ack {
Ack(packets::Ack),
Expand All @@ -46,14 +58,29 @@ where
{
// We declare the bytes as a separate variable to help the compiler to
// figure out that we can send stuff between threads.
let bytes = p.zvt_serialize();
src.write_all(&bytes).await?;
write_packet_async(src, p).await?;

let bytes = read_packet_async(src).await?;
let _ = Ack::zvt_parse(&bytes)?;
Ok(())
}

struct DataSource<T>
where T: AsyncReadExt + AsyncWriteExt + Unpin {
s: T,
}

// TODO(hrapp): I wanted to wrap the functions "write_packet_async" and "read_packet_async" into
// a struct that can take any `T` and has the more type "read_packet" and "write_packet". However I
// completely failed to do so because of pin!(). How do I do that?!?
impl<T> DataSource<T>
where T: AsyncReadExt + AsyncWriteExt + Unpin {
pub async fn read_packet_async(self: Pin<&mut Self>) -> Result<Vec<u8>> {
read_packet_async(self.s).await?
}
}


/// The trait for converting a sequence into a stream.
///
/// What is written below? The [Self::Input] type must be a command as defined
Expand Down Expand Up @@ -91,7 +118,7 @@ where
let bytes = read_packet_async(&mut src).await?;
let packet = Self::Output::zvt_parse(&bytes)?;
// Write the response.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async::<packets::Ack>(&mut src, &packets::Ack {}).await?;
yield packet;
};
Box::pin(s)
Expand Down Expand Up @@ -149,7 +176,7 @@ impl Sequence for ReadCard {
let bytes = read_packet_async(&mut src).await?;
let packet = ReadCardResponse::zvt_parse(&bytes)?;
// Write the response.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;

match packet {
ReadCardResponse::StatusInformation(_) | ReadCardResponse::Abort(_) => {
Expand Down Expand Up @@ -206,7 +233,7 @@ impl Sequence for Initialization {
let response = InitializationResponse::zvt_parse(&bytes)?;

// Every message requires an Ack.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;

match response {
InitializationResponse::CompletionData(_)
Expand Down Expand Up @@ -313,7 +340,7 @@ impl Sequence for Diagnosis {
let response = DiagnosisResponse::zvt_parse(&bytes)?;

// Every message requires an Ack.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;

match response {
DiagnosisResponse::CompletionData(_)
Expand Down Expand Up @@ -380,7 +407,7 @@ impl Sequence for EndOfDay {
let packet = EndOfDayResponse::zvt_parse(&bytes)?;

// Write the response.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
EndOfDayResponse::CompletionData(_) | EndOfDayResponse::Abort(_) => {
yield packet;
Expand Down Expand Up @@ -445,7 +472,7 @@ impl Sequence for Reservation {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = AuthorizationResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
AuthorizationResponse::CompletionData(_) | AuthorizationResponse::Abort(_) => {
yield packet;
Expand Down Expand Up @@ -515,7 +542,7 @@ impl Sequence for PartialReversal {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = PartialReversalResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
PartialReversalResponse::CompletionData(_)
| PartialReversalResponse::PartialReversalAbort(_) => {
Expand Down Expand Up @@ -555,7 +582,7 @@ impl Sequence for PreAuthReversal {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = PartialReversalResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
PartialReversalResponse::CompletionData(_)
| PartialReversalResponse::PartialReversalAbort(_) => {
Expand Down Expand Up @@ -606,7 +633,7 @@ impl Sequence for PrintSystemConfiguration {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = PrintSystemConfigurationResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
PrintSystemConfigurationResponse::CompletionData(_) => {
yield packet;
Expand Down Expand Up @@ -677,7 +704,7 @@ impl Sequence for StatusEnquiry {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = StatusEnquiryResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
StatusEnquiryResponse::CompletionData(_) => {
yield packet;
Expand Down