Skip to content

Commit

Permalink
Client initialize, ping compiling
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrit committed Feb 12, 2025
1 parent b330852 commit 7a2bc88
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 329 deletions.
64 changes: 53 additions & 11 deletions crates/mcp-core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use super::{
transport::{Transport, TransportError},
};
use mcp_types::{
ClientCapabilities, Implementation, InitializeRequestParams, JSONRPCNotification,
JSONRPCRequest, ListToolsResult, RequestId, ServerCapabilities,
ClientCapabilities, Implementation, InitializeRequestParams, InitializeResult,
InitializedNotificationParams, ListToolsRequestParams, ListToolsResult, PingRequestParams,
PingRequestParamsMeta, ServerCapabilities, LATEST_PROTOCOL_VERSION,
};

pub mod handlers;
Expand Down Expand Up @@ -34,31 +35,72 @@ impl<T: Transport> Client<T> {
capabilities: ClientCapabilities::default(),
client_info: Implementation {
name: "mcp-core".into(),
version: "0.1.0".into(),
version: env!("CARGO_PKG_VERSION").into(),
},
server_info: None,
server_capabilities: None,
}
}

pub async fn connect(&mut self) -> Result<(), TransportError> {
pub async fn connect(&mut self) -> Result<(), ClientError> {
self.protocol.connect().await.map_err(Into::into)
}

pub async fn initialize(&mut self) -> Result<(), ClientError> {
let params = InitializeRequestParams {
client_info: self.client_info.clone(),
capabilities: self.capabilities.clone(),
protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
};

let result = self
.protocol
.send_request::<InitializeResult>("initialize", serde_json::to_value(params).unwrap())
.await?;

self.server_info = Some(result.server_info);
self.server_capabilities = Some(result.capabilities);

let params = InitializedNotificationParams {
meta: Default::default(),
};

self.protocol
.send_notification(
"notification/initialized",
serde_json::to_value(params).unwrap(),
)
.await?;

Ok(())
}

pub async fn ping(&mut self) -> Result<(), ClientError> {
let params = PingRequestParams {
meta: Some(PingRequestParamsMeta {
progress_token: None,
}),
};

self.protocol
.send_request("ping".into(), serde_json::Value::Null)
.send_request::<()>("ping", serde_json::to_value(params).unwrap())
.await?;

Ok(())
}

pub async fn list_tools(&mut self) -> Result<ListToolsResult, ClientError> {
let result = self
.protocol
.send_request("tools/list".into(), serde_json::Value::Null)
.await?;
pub async fn list_tools(
&mut self,
cursor: Option<String>,
) -> Result<ListToolsResult, ClientError> {
let params = ListToolsRequestParams { cursor };

Ok(serde_json::from_value(result)?)
self.protocol
.send_request::<ListToolsResult>(
"tools/list".into(),
serde_json::to_value(params).unwrap(),
)
.await
.map_err(Into::into)
}
}
114 changes: 39 additions & 75 deletions crates/mcp-core/src/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use super::transport::{Transport, TransportError};
use async_trait::async_trait;
use mcp_types::{
JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, MCPError,
MCPResult, RequestId,
};
use mcp_types::{v2024_11_05::convert::assert_v2024_11_05_type, RequestId};
use serde::de::DeserializeOwned;
use serde_json::json;
use std::{
collections::HashMap,
any::TypeId,
sync::atomic::{AtomicI64, Ordering},
};
use tokio::sync::oneshot;
Expand All @@ -20,105 +18,71 @@ pub enum ProtocolError {
RequestTimedOut,
#[error("request failed: {0}")]
RequestFailed(#[from] TransportError),
#[error("invalid result: expected type {0:?}, got {1:?}")]
InvalidResult(TypeId, serde_json::Map<String, serde_json::Value>),
}

pub struct Protocol<T: Transport> {
transport: T,
request_handlers: HashMap<&'static str, Box<dyn RequestHandler>>,
notification_handlers: HashMap<&'static str, Box<dyn NotificationHandler>>,
next_id: AtomicI64,
}

impl<T: Transport> Protocol<T> {
pub fn new(transport: T) -> Self {
Self {
transport,
request_handlers: HashMap::new(),
notification_handlers: HashMap::new(),
next_id: AtomicI64::new(1),
}
}

/// Main message processing loop
pub async fn run(mut self) -> Result<(), ProtocolError> {
while let Some(message) = self.transport.recv().await {
match message? {
JSONRPCMessage::Request(req) => self.handle_request(req).await,
JSONRPCMessage::Notification(notif) => todo!(),
JSONRPCMessage::Response(resp) => todo!(),
JSONRPCMessage::Error(err) => todo!(),
}
}
Ok(())
pub async fn connect(&mut self) -> Result<(), ProtocolError> {
self.transport.start().await.map_err(Into::into)
}

/// Send a request with method and params, handling JSON-RPC details internally
pub async fn send_request(
pub async fn send_request<R: DeserializeOwned + 'static>(
&mut self,
method: String,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
) -> Result<R, ProtocolError> {
let id = RequestId::Integer(self.next_id.fetch_add(1, Ordering::SeqCst));

// TODO: Should type conversion be handled differently instead of dealing with serde_json in Protocol?
let params = match params {
serde_json::Value::Null => None,
_ => Some(serde_json::from_value(params)?),
};

let request = JSONRPCRequest {
id: id.clone(),
method,
params,
jsonrpc: "2.0".into(),
};
let request = json!({
"id": id,
"jsonrpc": "2.0",
"method": method,
"params": params,
});

let (sender, receiver) = oneshot::channel();

// Send the request
self.transport.send_request(request, sender).await?;
let response = receiver.await?;

Ok(serde_json::to_value(response.result).unwrap())
}
// Wait for the response
let result = receiver.await?.result.meta;

async fn handle_request(&mut self, req: JSONRPCRequest) {
if let Some(handler) = self.request_handlers.get(req.method.as_str()) {
// Dispatch to handler
let result = handler
.handle(serde_json::to_value(req.params).unwrap())
.await;
let response = JSONRPCResponse {
id: req.id,
result: MCPResult {
meta: serde_json::from_value(result).unwrap(),
},
jsonrpc: "2.0".into(),
};
self.transport.send_response(response).await.unwrap();
} else {
// Send method not found error
let error = JSONRPCError {
error: MCPError {
code: -32601,
data: None,
message: "Method not found".into(),
},
id: req.id,
jsonrpc: "2.0".into(),
};
self.transport.send_error(error).await.unwrap();
// Validate and deserialize the result
match assert_v2024_11_05_type::<R>(serde_json::to_value(result.clone()).unwrap()) {
Some(result) => Ok(result),
None => Err(ProtocolError::InvalidResult(TypeId::of::<R>(), result)),
}
}
}

/// Trait for handling incoming requests
#[async_trait]
trait RequestHandler: Send + Sync {
async fn handle(&self, params: serde_json::Value) -> serde_json::Value;
}
/// Send a notification with method and params, handling JSON-RPC details internally
pub async fn send_notification(
&mut self,
method: &str,
params: serde_json::Value,
) -> Result<(), ProtocolError> {
let notification = json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
});

/// Trait for handling notifications
#[async_trait]
trait NotificationHandler: Send + Sync {
async fn handle(&self, params: serde_json::Value);
self.transport.send_notification(notification).await?;

Ok(())
}
}
29 changes: 7 additions & 22 deletions crates/mcp-core/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
//! decoding messages, as well as transmitting/receiving them.
use async_trait::async_trait;
use mcp_types::{
ClientCapabilities, Implementation, InitializeRequestParams, InitializeResult, JSONRPCError,
JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, RequestId,
LATEST_PROTOCOL_VERSION,
};
use mcp_types::{JSONRPCError, JSONRPCMessage, JSONRPCResponse};
use std::error::Error;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::oneshot;

/// Core transport error type
#[derive(Debug, thiserror::Error)]
Expand All @@ -26,17 +22,20 @@ pub enum TransportError {

#[async_trait]
pub trait Transport {
/// Start communication
async fn start(&mut self) -> Result<(), TransportError>;

/// Send a JSON-RPC request and wait for response
async fn send_request(
&mut self,
request: JSONRPCRequest,
request: serde_json::Value,
sender: oneshot::Sender<JSONRPCResponse>,
) -> Result<(), TransportError>;

/// Send a JSON-RPC notification (fire-and-forget)
async fn send_notification(
&mut self,
notification: JSONRPCNotification,
notification: serde_json::Value,
) -> Result<(), TransportError>;

/// Send a JSON-RPC response to a request
Expand All @@ -51,17 +50,3 @@ pub trait Transport {
/// Close the transport connection
async fn close(self) -> Result<(), TransportError>;
}

/// Extension trait for common transport operations
#[async_trait]
pub trait TransportExt: Transport {
async fn initialize(
&mut self,
client_info: Implementation,
capabilities: ClientCapabilities,
) -> Result<InitializeResult, TransportError> {
todo!()
}
}

impl<T: Transport + ?Sized> TransportExt for T {}
2 changes: 1 addition & 1 deletion crates/mcp-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ default = ["2024_11_05"]
2024_11_05 = []

[dependencies]
jsonschema = "0.28.3"
jsonschema = "0.29.0"
schemars = "0.8.21"
serde = { workspace = true }
serde_json = { workspace = true }
Expand Down
Loading

0 comments on commit 7a2bc88

Please sign in to comment.