|
1 | 1 | use std::fs::File;
|
2 | 2 | use std::os::unix::fs::FileExt;
|
3 |
| -use std::time::Duration; |
4 | 3 |
|
5 | 4 | use anyhow::anyhow;
|
6 | 5 | use bytes::Bytes;
|
@@ -38,93 +37,123 @@ use crate::utils::{
|
38 | 37 | encode_point, gen_ecdsa_keypair, gen_random, get_download_dir, hkdf_extract_expand,
|
39 | 38 | stream_read_exact, to_four_digit_string, DeviceType, RemoteDeviceInfo,
|
40 | 39 | };
|
41 |
| -use crate::{location_nearby_connections, sharing_nearby}; |
| 40 | +use crate::{location_nearby_connections, sharing_nearby, RqsEvent}; |
42 | 41 |
|
43 | 42 | type HmacSha256 = Hmac<Sha256>;
|
44 | 43 |
|
45 | 44 | const SANE_FRAME_LENGTH: i32 = 5 * 1024 * 1024;
|
46 |
| -const SANITY_DURATION: Duration = Duration::from_micros(10); |
47 | 45 |
|
48 | 46 | #[derive(Debug)]
|
49 | 47 | pub struct InboundRequest {
|
50 | 48 | socket: TcpStream,
|
51 | 49 | pub state: InnerState,
|
52 |
| - sender: Sender<ChannelMessage>, |
53 |
| - receiver: Receiver<ChannelMessage>, |
| 50 | + sender: Sender<RqsEvent>, |
| 51 | + receiver: Receiver<RqsEvent>, |
54 | 52 | }
|
55 | 53 |
|
56 | 54 | impl InboundRequest {
|
57 |
| - pub fn new(socket: TcpStream, id: String, sender: Sender<ChannelMessage>) -> Self { |
| 55 | + pub fn new(socket: TcpStream, id: String, sender: Sender<RqsEvent>) -> Self { |
| 56 | + // Create a receiver for the unified event channel |
58 | 57 | let receiver = sender.subscribe();
|
59 | 58 |
|
| 59 | + // We'll filter messages in the handle method instead of trying to create a custom filtered receiver |
60 | 60 | Self {
|
61 | 61 | socket,
|
62 | 62 | state: InnerState {
|
63 |
| - id, |
| 63 | + id: id.clone(), |
64 | 64 | server_seq: 0,
|
65 | 65 | client_seq: 0,
|
| 66 | + encryption_done: false, |
66 | 67 | state: State::Initial,
|
67 |
| - encryption_done: true, |
68 |
| - ..Default::default() |
| 68 | + remote_device_info: None, |
| 69 | + pin_code: None, |
| 70 | + transfer_metadata: None, |
| 71 | + transferred_files: Default::default(), |
| 72 | + cipher_commitment: None, |
| 73 | + private_key: None, |
| 74 | + public_key: None, |
| 75 | + server_init_data: None, |
| 76 | + client_init_msg_data: None, |
| 77 | + ukey_client_finish_msg_data: None, |
| 78 | + decrypt_key: None, |
| 79 | + recv_hmac_key: None, |
| 80 | + encrypt_key: None, |
| 81 | + send_hmac_key: None, |
| 82 | + text_payload: None, |
| 83 | + payload_buffers: Default::default(), |
69 | 84 | },
|
70 | 85 | sender,
|
71 | 86 | receiver,
|
72 | 87 | }
|
73 | 88 | }
|
74 | 89 |
|
75 | 90 | pub async fn handle(&mut self) -> Result<(), anyhow::Error> {
|
| 91 | + // Check for any pending messages from the frontend |
| 92 | + if let Ok(Some(channel_msg)) = self.wait_for_channel_message().await { |
| 93 | + if let Some(action) = channel_msg.action { |
| 94 | + match action { |
| 95 | + ChannelAction::AcceptTransfer => { |
| 96 | + return self.accept_transfer().await; |
| 97 | + } |
| 98 | + ChannelAction::RejectTransfer => { |
| 99 | + return self |
| 100 | + .reject_transfer(Some( |
| 101 | + sharing_nearby::connection_response_frame::Status::Reject, |
| 102 | + )) |
| 103 | + .await; |
| 104 | + } |
| 105 | + ChannelAction::CancelTransfer => { |
| 106 | + // Use an appropriate status value from the actual enum |
| 107 | + return self |
| 108 | + .reject_transfer(Some( |
| 109 | + sharing_nearby::connection_response_frame::Status::Reject, |
| 110 | + )) |
| 111 | + .await; |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + |
76 | 117 | // Buffer for the 4-byte length
|
77 | 118 | let mut length_buf = [0u8; 4];
|
78 | 119 |
|
79 | 120 | tokio::select! {
|
80 | 121 | i = self.receiver.recv() => {
|
81 |
| - match i { |
82 |
| - Ok(channel_msg) => { |
83 |
| - if channel_msg.direction == ChannelDirection::LibToFront { |
84 |
| - return Ok(()); |
85 |
| - } |
| 122 | + if let Ok(RqsEvent::Message(channel_msg)) = i { |
| 123 | + if channel_msg.direction == ChannelDirection::LibToFront { |
| 124 | + return Ok(()); |
| 125 | + } |
86 | 126 |
|
87 |
| - if channel_msg.id != self.state.id { |
88 |
| - return Ok(()); |
89 |
| - } |
| 127 | + if channel_msg.id != self.state.id { |
| 128 | + return Ok(()); |
| 129 | + } |
90 | 130 |
|
91 |
| - debug!("inbound: got: {:?}", channel_msg); |
92 |
| - match channel_msg.action { |
93 |
| - Some(ChannelAction::AcceptTransfer) => { |
| 131 | + debug!("inbound: got: {:?}", channel_msg); |
| 132 | + if let Some(action) = channel_msg.action { |
| 133 | + match action { |
| 134 | + ChannelAction::AcceptTransfer => { |
94 | 135 | self.accept_transfer().await?;
|
95 | 136 | },
|
96 |
| - Some(ChannelAction::RejectTransfer) => { |
| 137 | + ChannelAction::RejectTransfer => { |
97 | 138 | self.update_state(
|
98 | 139 | |e| {
|
99 | 140 | e.state = State::Rejected;
|
100 | 141 | },
|
101 | 142 | true,
|
102 | 143 | ).await;
|
103 |
| - |
104 |
| - self.reject_transfer(Some( |
105 |
| - sharing_nearby::connection_response_frame::Status::Reject |
106 |
| - )).await?; |
107 |
| - return Err(anyhow!(crate::errors::AppError::NotAnError)); |
108 | 144 | },
|
109 |
| - Some(ChannelAction::CancelTransfer) => { |
| 145 | + ChannelAction::CancelTransfer => { |
110 | 146 | self.update_state(
|
111 | 147 | |e| {
|
112 | 148 | e.state = State::Cancelled;
|
113 | 149 | },
|
114 | 150 | true,
|
115 | 151 | ).await;
|
116 |
| - self.disconnection().await?; |
117 |
| - return Err(anyhow!(crate::errors::AppError::NotAnError)); |
118 |
| - }, |
119 |
| - None => { |
120 |
| - trace!("inbound: nothing to do") |
121 |
| - }, |
| 152 | + } |
122 | 153 | }
|
123 | 154 | }
|
124 |
| - Err(e) => { |
125 |
| - error!("inbound: channel error: {}", e); |
126 |
| - } |
127 | 155 | }
|
| 156 | + return Ok(()); |
128 | 157 | },
|
129 | 158 | h = stream_read_exact(&mut self.socket, &mut length_buf) => {
|
130 | 159 | h?;
|
@@ -978,24 +1007,15 @@ impl InboundRequest {
|
978 | 1007 | }
|
979 | 1008 |
|
980 | 1009 | async fn disconnection(&mut self) -> Result<(), anyhow::Error> {
|
981 |
| - let frame = location_nearby_connections::OfflineFrame { |
982 |
| - version: Some(location_nearby_connections::offline_frame::Version::V1.into()), |
983 |
| - v1: Some(location_nearby_connections::V1Frame { |
984 |
| - r#type: Some( |
985 |
| - location_nearby_connections::v1_frame::FrameType::Disconnection.into(), |
986 |
| - ), |
987 |
| - disconnection: Some(location_nearby_connections::DisconnectionFrame { |
988 |
| - ..Default::default() |
989 |
| - }), |
990 |
| - ..Default::default() |
991 |
| - }), |
992 |
| - }; |
| 1010 | + self.update_state( |
| 1011 | + |s| { |
| 1012 | + s.state = State::Disconnected; |
| 1013 | + }, |
| 1014 | + true, |
| 1015 | + ) |
| 1016 | + .await; |
993 | 1017 |
|
994 |
| - if self.state.encryption_done { |
995 |
| - self.encrypt_and_send(&frame).await |
996 |
| - } else { |
997 |
| - self.send_frame(frame.encode_to_vec()).await |
998 |
| - } |
| 1018 | + Ok(()) |
999 | 1019 | }
|
1000 | 1020 |
|
1001 | 1021 | async fn accept_transfer(&mut self) -> Result<(), anyhow::Error> {
|
@@ -1327,22 +1347,35 @@ impl InboundRequest {
|
1327 | 1347 | {
|
1328 | 1348 | f(&mut self.state);
|
1329 | 1349 |
|
1330 |
| - if !inform { |
1331 |
| - return; |
| 1350 | + if inform { |
| 1351 | + let _ = self.sender.send(RqsEvent::Message(ChannelMessage { |
| 1352 | + id: self.state.id.clone(), |
| 1353 | + direction: ChannelDirection::LibToFront, |
| 1354 | + state: Some(self.state.state.clone()), |
| 1355 | + meta: self.state.transfer_metadata.clone(), |
| 1356 | + ..Default::default() |
| 1357 | + })); |
1332 | 1358 | }
|
| 1359 | + } |
1333 | 1360 |
|
1334 |
| - trace!("Sending msg into the channel"); |
1335 |
| - let _ = self.sender.send(ChannelMessage { |
1336 |
| - id: self.state.id.clone(), |
1337 |
| - direction: ChannelDirection::LibToFront, |
1338 |
| - rtype: Some(crate::channel::TransferType::Inbound), |
1339 |
| - state: Some(self.state.state.clone()), |
1340 |
| - meta: self.state.transfer_metadata.clone(), |
1341 |
| - ..Default::default() |
1342 |
| - }); |
1343 |
| - // Add a small sleep timer to allow the Tokio runtime to have |
1344 |
| - // some spare time to process channel's message. Otherwise it |
1345 |
| - // get spammed by new requests. Currently set to 10 micro secs. |
1346 |
| - tokio::time::sleep(SANITY_DURATION).await; |
| 1361 | + // Add a helper method to check for relevant messages |
| 1362 | + async fn wait_for_channel_message(&mut self) -> Result<Option<ChannelMessage>, anyhow::Error> { |
| 1363 | + let mut timeout = tokio::time::interval(tokio::time::Duration::from_millis(100)); |
| 1364 | + |
| 1365 | + for _ in 0..50 { |
| 1366 | + // Try for 5 seconds |
| 1367 | + tokio::select! { |
| 1368 | + _ = timeout.tick() => {}, |
| 1369 | + result = self.receiver.recv() => { |
| 1370 | + if let Ok(RqsEvent::Message(msg)) = result { |
| 1371 | + if msg.direction == ChannelDirection::FrontToLib && msg.id == self.state.id { |
| 1372 | + return Ok(Some(msg)); |
| 1373 | + } |
| 1374 | + } |
| 1375 | + } |
| 1376 | + } |
| 1377 | + } |
| 1378 | + |
| 1379 | + Ok(None) |
1347 | 1380 | }
|
1348 | 1381 | }
|
0 commit comments