Skip to content

Commit

Permalink
Add receive() to Transport trait
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrit committed Feb 6, 2025
1 parent 7a95b94 commit 2b07482
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 20 deletions.
19 changes: 15 additions & 4 deletions crates/mcp-core/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@ pub trait Transport: Send + Sync {

/// Sends a message
async fn send(&mut self, message: JSONRPCMessage) -> Result<(), Self::Error>;
}

/// Represents a transport that can receive messages
#[async_trait::async_trait]
pub trait ReceiveTransport: Transport {
/// Receives a message
async fn receive(&mut self) -> Result<JSONRPCMessage, Self::Error>;
}

#[cfg(test)]
pub mod test_utils {
use mcp_types::JSONRPCNotification;

use super::*;
use std::sync::{Arc, Mutex};

Expand Down Expand Up @@ -78,5 +77,17 @@ pub mod test_utils {
self.sent_messages.lock().unwrap().push(message);
Ok(())
}

async fn receive(&mut self) -> Result<JSONRPCMessage, Self::Error> {
if *self.should_fail.lock().unwrap() {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "Mock error"));
}

Ok(JSONRPCMessage::Notification(JSONRPCNotification {
jsonrpc: "2.0".to_string(),
method: "test".to_string(),
params: None,
}))
}
}
}
5 changes: 1 addition & 4 deletions crates/mcp-transport-sse/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::future::Future;
use std::pin::Pin;

use futures::TryStreamExt;
use mcp_core::transport::{ReceiveTransport, Transport};
use mcp_core::transport::Transport;
use mcp_types::JSONRPCMessage;
use reqwest::Client;
use reqwest_websocket::{Message, RequestBuilderExt};
Expand Down Expand Up @@ -111,10 +111,7 @@ impl Transport for SSEClientTransport {

Ok(())
}
}

#[async_trait::async_trait]
impl ReceiveTransport for SSEClientTransport {
async fn receive(&mut self) -> Result<JSONRPCMessage, Self::Error> {
if !self.started {
return Err(SSETransportError::NotStarted);
Expand Down
5 changes: 1 addition & 4 deletions crates/mcp-transport-sse/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::Arc;

use axum::{extract::State, response::IntoResponse, routing::post, Json, Router};
use futures::SinkExt;
use mcp_core::transport::{ReceiveTransport, Transport};
use mcp_core::transport::Transport;
use mcp_types::JSONRPCMessage;
use reqwest::Client;
use reqwest_websocket::{Message, RequestBuilderExt};
Expand Down Expand Up @@ -112,10 +112,7 @@ impl Transport for SSEServerTransport {

Ok(())
}
}

#[async_trait::async_trait]
impl ReceiveTransport for SSEServerTransport {
async fn receive(&mut self) -> Result<JSONRPCMessage, Self::Error> {
if !self.started {
return Err(SSETransportError::NotStarted);
Expand Down
5 changes: 1 addition & 4 deletions crates/mcp-transport-stdio/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
process::Child,
};

use mcp_core::transport::{ReceiveTransport, Transport};
use mcp_core::transport::Transport;
use mcp_types::JSONRPCMessage;
use tracing::info;

Expand Down Expand Up @@ -232,10 +232,7 @@ impl Transport for StdioClientTransport {

Ok(())
}
}

#[async_trait::async_trait]
impl ReceiveTransport for StdioClientTransport {
async fn receive(&mut self) -> Result<JSONRPCMessage, Self::Error> {
if !self.started {
return Err(std::io::Error::new(
Expand Down
5 changes: 1 addition & 4 deletions crates/mcp-transport-stdio/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![cfg(feature = "server")]
use std::io::{BufRead, Write};

use mcp_core::transport::{ReceiveTransport, Transport};
use mcp_core::transport::Transport;
use mcp_types::JSONRPCMessage;

pub struct StdioServerTransport {
Expand Down Expand Up @@ -60,10 +60,7 @@ impl Transport for StdioServerTransport {

Ok(())
}
}

#[async_trait::async_trait]
impl ReceiveTransport for StdioServerTransport {
async fn receive(&mut self) -> Result<JSONRPCMessage, Self::Error> {
if !self.started {
return Err(std::io::Error::new(
Expand Down

0 comments on commit 2b07482

Please sign in to comment.