Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix frame sending order #3

Merged
merged 2 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# CHANGELOG

## 0.2.0
chore: add frame trace log
fix: fix frame sending order

## 0.1.1

chore: change package description
Expand Down
12 changes: 11 additions & 1 deletion src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
pub type Sid = u32;

/// Frame commands of smux protocal.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub enum Cmd {
/// Stream open.
Sync,
Expand Down Expand Up @@ -66,6 +66,16 @@ pub struct Frame {
pub data: Option<Vec<u8>>,
}

impl std::fmt::Debug for Frame {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Frame")
.field("sid", &self.sid)
.field("cmd", &self.cmd)
.field("len", &self.length)
.finish()
}
}

impl Frame {
pub fn new(ver: u8, cmd: Cmd, sid: Sid) -> Self {
Self {
Expand Down
52 changes: 32 additions & 20 deletions src/read_frame_grouper.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::frame::{Cmd, Frame, Sid};
use crate::session_inner::ReadRequest;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Mutex};

// Consume reading frames and split them into:
// - Sync frames
Expand All @@ -18,11 +18,11 @@ pub(crate) struct ReadFrameGrouper {
pub sync_tx: mpsc::Sender<Frame>,

// Session could also operate `sid_tx_map` and `sid_rx_map`.
pub sid_tx_map: Arc<DashMap<Sid, mpsc::Sender<Frame>>>,
pub sid_tx_map: Arc<Mutex<HashMap<Sid, Arc<Mutex<mpsc::Sender<Frame>>>>>>,

// `sid_rx_map` is shared with session.
// Items of `sid_rx_map` will be taken away by the session when accepting new streams.
pub sid_rx_map: Arc<DashMap<Sid, mpsc::Receiver<Frame>>>,
pub sid_rx_map: Arc<Mutex<HashMap<Sid, mpsc::Receiver<Frame>>>>,

pub sid_frame_buffer_size: usize,
}
Expand Down Expand Up @@ -57,11 +57,18 @@ impl ReadFrameGrouper {

async fn handle_sync(&mut self, read_req: ReadRequest) {
let sid = read_req.frame.sid;
if !self.sid_tx_map.contains_key(&sid) {
let (tx, rx) = mpsc::channel(self.sid_frame_buffer_size);
self.sid_tx_map.insert(sid, tx);
self.sid_rx_map.insert(sid, rx);
}
{
let contained = { self.sid_tx_map.lock().await.contains_key(&sid) };
if !contained {
let (tx, rx) = mpsc::channel(self.sid_frame_buffer_size);
self
.sid_tx_map
.lock()
.await
.insert(sid, Arc::new(Mutex::new(tx)));
self.sid_rx_map.lock().await.insert(sid, rx);
}
};
let send_sync_tx_res = self.sync_tx.send(read_req.frame).await;
if send_sync_tx_res.is_err() {
// session closed
Expand All @@ -71,13 +78,18 @@ impl ReadFrameGrouper {

async fn handle_fin_push(&mut self, read_req: ReadRequest) {
let sid = read_req.frame.sid;
if !self.sid_tx_map.contains_key(&sid) {
let contained = { self.sid_tx_map.lock().await.contains_key(&sid) };
if !contained {
// unexpected, ignore the frame
log::warn!("[grouper] receive unexecpted frame, sid: {}", sid,);
return;
}
let tx = self.sid_tx_map.get(&sid).unwrap();
let _ = tx.send(read_req.frame).await;
let tx = {
let lock = self.sid_tx_map.lock().await;
let tx = lock.get(&sid).unwrap();
tx.clone()
};
let _ = tx.lock().await.send(read_req.frame).await;
}
}

Expand All @@ -86,21 +98,22 @@ mod test {
use crate::frame::{Cmd, Frame};
use crate::read_frame_grouper::ReadFrameGrouper;
use crate::session_inner::ReadRequest;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Mutex};

#[tokio::test]
async fn test_grouper() {
let (new_frame_tx, new_frame_rx) = mpsc::channel(1024);
let (sync_tx, mut sync_rx) = mpsc::channel(1024);
let sid_rx_map = Arc::new(DashMap::new());
let sid_rx_map = HashMap::new();
let sid_rx_map = Arc::new(Mutex::new(sid_rx_map));
let mut grouper = ReadFrameGrouper {
new_frame_rx,
sid_frame_buffer_size: 1024,
sync_tx,
sid_rx_map: sid_rx_map.clone(),
sid_tx_map: Arc::new(DashMap::new()),
sid_tx_map: Arc::new(Mutex::new(HashMap::new())),
};

// Should create correspond sid_tx_map when receive sync frames.
Expand All @@ -121,10 +134,9 @@ mod test {

let frame = sync_rx.recv().await.unwrap();
assert!(matches!(frame.cmd, Cmd::Sync));
let item = sid_rx_map.remove(&sid);
let item = { sid_rx_map.lock().await.remove(&sid) };
assert!(item.is_some());
let (id, mut item_frame_rx) = item.unwrap();
assert_eq!(id, sid);
let mut item_frame_rx = item.unwrap();
let frame = item_frame_rx.recv().await.unwrap();
assert_eq!(frame.sid, sid);

Expand All @@ -141,7 +153,7 @@ mod test {
frame.with_data(vec![0; 10]);
new_frame_tx.send(ReadRequest { frame }).await.unwrap();
item_frame_rx.recv().await.unwrap();
assert!(sid_rx_map.remove(&sid).is_none());
assert!(sid_rx_map.lock().await.remove(&sid).is_none());

// Cloud new_frame_rx should
drop(new_frame_tx);
Expand Down
82 changes: 44 additions & 38 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::frame::{Cmd, Frame, Sid};
use crate::read_frame_grouper::ReadFrameGrouper;
use crate::session_inner::{SessionInner, WriteRequest};
use crate::stream::Stream;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{mpsc, oneshot, Mutex};
Expand All @@ -24,9 +24,9 @@ pub struct Session {
go_away: bool,

sync_rx: mpsc::Receiver<Frame>,
sid_tx_map: Arc<DashMap<Sid, mpsc::Sender<Frame>>>,
sid_rx_map: Arc<DashMap<Sid, mpsc::Receiver<Frame>>>,
sid_close_tx_map: Arc<DashMap<Sid, oneshot::Sender<()>>>,
sid_tx_map: Arc<Mutex<HashMap<Sid, Arc<Mutex<mpsc::Sender<Frame>>>>>>,
sid_rx_map: Arc<Mutex<HashMap<Sid, mpsc::Receiver<Frame>>>>,
sid_close_tx_map: Arc<Mutex<HashMap<Sid, oneshot::Sender<()>>>>,
sid_drop_tx: mpsc::Sender<Sid>,

inner_err: Arc<Mutex<Option<TokioSmuxError>>>,
Expand All @@ -35,16 +35,18 @@ pub struct Session {
impl Drop for Session {
fn drop(&mut self) {
// close all streams
let mut keys: Vec<Sid> = vec![];
for kv in self.sid_close_tx_map.iter() {
let sid = kv.key();
keys.push(*sid);
}
for id in keys {
let item = self.sid_close_tx_map.remove(&id);
let (_, tx) = item.unwrap();
let _ = tx.send(());
}
let sid_close_tx_map = self.sid_close_tx_map.clone();
tokio::spawn(async move {
let mut keys: Vec<Sid> = vec![];
for (kv, _) in sid_close_tx_map.lock().await.iter() {
keys.push(*kv);
}
for id in keys {
let item = sid_close_tx_map.lock().await.remove(&id);
let tx = item.unwrap();
let _ = tx.send(());
}
});
}
}

Expand Down Expand Up @@ -91,8 +93,8 @@ impl Session {

// init ReadFrameGrouper
let (sync_tx, sync_rx) = mpsc::channel(MAX_IN_QUEUE_SYNC_FRAMES);
let sid_tx_map = Arc::new(DashMap::new());
let sid_rx_map = Arc::new(DashMap::new());
let sid_tx_map = Arc::new(Mutex::new(HashMap::new()));
let sid_rx_map = Arc::new(Mutex::new(HashMap::new()));
let mut spliter = ReadFrameGrouper {
new_frame_rx,
sync_tx,
Expand All @@ -119,7 +121,7 @@ impl Session {

sid_tx_map,
sid_rx_map,
sid_close_tx_map: Arc::new(DashMap::new()),
sid_close_tx_map: Arc::new(Mutex::new(HashMap::new())),
sid_drop_tx,

sync_rx,
Expand Down Expand Up @@ -181,17 +183,22 @@ impl Session {

// Update sid_tx_map and sid_rx_map when open_stream.
{
if !self.sid_tx_map.contains_key(&sid) {
let contained = { self.sid_tx_map.lock().await.contains_key(&sid) };
if !contained {
let (tx, rx) = mpsc::channel(self.config.stream_reading_frame_channel_capacity);
self.sid_tx_map.insert(sid, tx);
self.sid_rx_map.insert(sid, rx);
self
.sid_tx_map
.lock()
.await
.insert(sid, Arc::new(Mutex::new(tx)));
self.sid_rx_map.lock().await.insert(sid, rx);
}
}
let stream = self.new_stream(sid);
let stream = self.new_stream(sid).await;
if stream.is_err() {
// not likely
self.sid_tx_map.remove(&sid);
self.sid_rx_map.remove(&sid);
self.sid_tx_map.lock().await.remove(&sid);
self.sid_rx_map.lock().await.remove(&sid);
}
Ok(stream.unwrap())
}
Expand All @@ -209,25 +216,24 @@ impl Session {
let frame = frame.unwrap();
let sid = frame.sid;

let stream = self.new_stream(sid)?;
let stream = self.new_stream(sid).await?;

Ok(stream)
}

fn new_stream(&mut self, sid: Sid) -> Result<Stream> {
async fn new_stream(&mut self, sid: Sid) -> Result<Stream> {
let frame_rx = {
let rx = self.sid_rx_map.remove(&sid);
let rx = self.sid_rx_map.lock().await.remove(&sid);
if rx.is_none() {
return Err(TokioSmuxError::Default {
msg: "unexpected empty sid in sid_rx_map".to_string(),
});
}
let (_id, rx) = rx.unwrap();
rx
rx.unwrap()
};

let (close_tx, close_rx) = oneshot::channel();
self.sid_close_tx_map.insert(sid, close_tx);
self.sid_close_tx_map.lock().await.insert(sid, close_tx);

let mut stream = Stream::new(sid, frame_rx, self.write_tx.clone(), close_rx);
stream.with_drop_tx(Some(self.sid_drop_tx.clone()));
Expand Down Expand Up @@ -255,9 +261,9 @@ impl Session {
struct SessionCleaner {
sid_drop_rx: mpsc::Receiver<Sid>,

sid_tx_map: Arc<DashMap<Sid, mpsc::Sender<Frame>>>,
sid_rx_map: Arc<DashMap<Sid, mpsc::Receiver<Frame>>>,
sid_close_tx_map: Arc<DashMap<Sid, oneshot::Sender<()>>>,
sid_tx_map: Arc<Mutex<HashMap<Sid, Arc<Mutex<mpsc::Sender<Frame>>>>>>,
sid_rx_map: Arc<Mutex<HashMap<Sid, mpsc::Receiver<Frame>>>>,
sid_close_tx_map: Arc<Mutex<HashMap<Sid, oneshot::Sender<()>>>>,
}

impl SessionCleaner {
Expand All @@ -278,9 +284,9 @@ impl SessionCleaner {
}

let sid = sid.unwrap();
self.sid_tx_map.remove(&sid);
self.sid_rx_map.remove(&sid);
self.sid_close_tx_map.remove(&sid);
self.sid_tx_map.lock().await.remove(&sid);
self.sid_rx_map.lock().await.remove(&sid);
self.sid_close_tx_map.lock().await.remove(&sid);
}
}
}
Expand Down Expand Up @@ -531,9 +537,9 @@ pub mod test {

// clean up after a while
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
assert!(client.sid_tx_map.get(&sid).is_none());
assert!(client.sid_rx_map.get(&sid).is_none());
assert!(client.sid_close_tx_map.get(&sid).is_none());
assert!(client.sid_tx_map.lock().await.get(&sid).is_none());
assert!(client.sid_rx_map.lock().await.get(&sid).is_none());
assert!(client.sid_close_tx_map.lock().await.get(&sid).is_none());

// the remote should also receive the fin
let data = write_rx.recv().await.unwrap();
Expand Down
17 changes: 10 additions & 7 deletions src/session_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::future;
use crate::error::Result;
use crate::frame::HEADER_SIZE;
use crate::frame::{Cmd, Frame};
use log;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use tokio::time;
Expand Down Expand Up @@ -115,7 +116,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {
self.read_finished = true;
continue;
}
self.handle_read_data(&data[0..size])?;
self.handle_read_data(&data[0..size]).await?;
}
// write
req = self.write_rx.recv() => {
Expand All @@ -137,6 +138,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {

async fn handle_keep_alive_interval_tick(&mut self) -> Result<()> {
let frame = Frame::new_v1(Cmd::Nop, 0);
log::trace!("send frame: {:?}", frame);
let buf = frame.get_buf()?;
self.conn.write_all(&buf).await?;

Expand All @@ -152,14 +154,16 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {
return Ok(());
}

log::trace!("send frame: {:?}", req.frame);

let finish_tx = req.finish_tx.take().unwrap();
// ignore stream closed error
let _ = finish_tx.send(());

Ok(())
}

fn handle_read_data(&mut self, data: &[u8]) -> Result<()> {
async fn handle_read_data(&mut self, data: &[u8]) -> Result<()> {
if data.len() == 0 {
// Remote write side closed, no more data.
return Ok(());
Expand All @@ -179,6 +183,9 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {
break;
}
let mut frame = frame.unwrap();

log::trace!("receive frame: {:?}", frame);

let frame_length = frame.length;
// check if all data ready
if (frame_length as u32 + HEADER_SIZE as u32) > (self.read_buf.len() as u32) {
Expand All @@ -196,11 +203,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {

// output frame
let recv_tx = self.recv_tx.clone();
tokio::spawn(async move {
// Will block if the tx capability is empty.
// is_err() means the session is closed, therefore ignore the error.
let _ = recv_tx.send(read_req).await;
});
let _ = recv_tx.send(read_req).await;

// continue
}
Expand Down
Loading
Loading