diff --git a/Cargo.toml b/Cargo.toml index a3beb9d92..9423d146c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ as_conversions = "deny" [workspace.dependencies] assert_matches = "1.5.0" +axum = "0.6.12" hyper = "0.13.9" rstest = "0.17.0" serde = { version = "1.0.193", features = ["derive"] } diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index 7a4aabbf0..fc57f5cba 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] assert_matches.workspace = true +axum.workspace = true hyper.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 1233dd9b1..95233336d 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -1,9 +1,8 @@ use crate::errors::{GatewayConfigError, GatewayError}; -use hyper::body::to_bytes; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::{Json, Router}; use starknet_api::external_transaction::ExternalTransaction; -use std::convert::Infallible; use std::net::SocketAddr; use std::str::FromStr; @@ -11,9 +10,6 @@ use std::str::FromStr; #[path = "gateway_test.rs"] pub mod gateway_test; -const NOT_FOUND_RESPONSE: &str = "Not found."; -type RequestBody = Request; -type ResponseBody = Response; pub type GatewayResult = Result<(), GatewayError>; pub struct Gateway { @@ -22,14 +18,21 @@ pub struct Gateway { impl Gateway { pub async fn build_server(&self) -> GatewayResult { + // Parses the bind address from GatewayConfig, returning an error for invalid addresses. let addr = SocketAddr::from_str(&self.gateway_config.bind_address).map_err(|_| { GatewayConfigError::InvalidServerBindAddress(self.gateway_config.bind_address.clone()) })?; - let make_service = - make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(handle_request)) }); + // Sets up the router with the specified routes for the server. + let app = Router::new() + .route("/is_alive", get(is_alive)) + .route("/add_transaction", post(add_transaction)); - Server::bind(&addr).serve(make_service).await?; + // Attempts to bind and run the server. If it fails, returns an error. + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); Ok(()) } @@ -39,44 +42,14 @@ pub struct GatewayConfig { pub bind_address: String, } -async fn handle_request(request: RequestBody) -> Result, GatewayError> { - let (parts, body) = request.into_parts(); - let response = match (parts.method, parts.uri.path()) { - (Method::GET, "/is_alive") => is_alive(), - (Method::POST, "/add_transaction") => add_transaction(body).await, - _ => response(StatusCode::NOT_FOUND, NOT_FOUND_RESPONSE.to_string()), - }; - response -} - -fn is_alive() -> Result { +async fn is_alive() -> impl IntoResponse { unimplemented!("Future handling should be implemented here."); } -// TODO(Ayelet): Consider using axum instead of Hyper. -async fn add_transaction(body: Body) -> Result, GatewayError> { - let bytes = to_bytes(body).await?; - let deserialized_transaction = serde_json::from_slice::(&bytes) - .map_err(|_| GatewayError::InvalidTransactionFormat); - - match deserialized_transaction { - Ok(transaction) => { - let tx_type = match transaction { - ExternalTransaction::Declare(_) => "DECLARE", - ExternalTransaction::DeployAccount(_) => "DEPLOY_ACCOUNT", - ExternalTransaction::Invoke(_) => "INVOKE", - }; - response(StatusCode::OK, tx_type.to_string()) - } - Err(_) => response( - StatusCode::BAD_REQUEST, - "Invalid transaction format.".to_string(), - ), +async fn add_transaction(Json(transaction_json): Json) -> impl IntoResponse { + match transaction_json { + ExternalTransaction::Declare(_) => "DECLARE", + ExternalTransaction::DeployAccount(_) => "DEPLOY_ACCOUNT", + ExternalTransaction::Invoke(_) => "INVOKE", } } - -fn response(status: StatusCode, body_content: String) -> Result, GatewayError> { - Ok(Response::builder() - .status(status) - .body(Body::from(body_content))?) -} diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 42895abd5..2e374c1ed 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -1,23 +1,9 @@ use crate::gateway::add_transaction; -use crate::gateway::handle_request; -use hyper::{body, Body, Request}; +use axum::{body::HttpBody, response::IntoResponse}; use rstest::rstest; -use std::fs; - -#[tokio::test] -async fn test_invalid_request() { - // Create a sample GET request for an invalid path - let request = Request::get("/some_invalid_path") - .body(Body::empty()) - .unwrap(); - let response = handle_request(request).await.unwrap(); - - assert_eq!(response.status(), 404); - assert_eq!( - String::from_utf8_lossy(&body::to_bytes(response.into_body()).await.unwrap()), - "Not found." - ); -} +use starknet_api::external_transaction::ExternalTransaction; +use std::fs::File; +use std::io::BufReader; // TODO(Ayelet): Replace the use of the JSON files with generated instances, then serialize these // into JSON for testing. @@ -30,14 +16,13 @@ async fn test_invalid_request() { #[case("./src/json_files_for_testing/invoke_v3.json", "INVOKE")] #[tokio::test] async fn test_add_transaction(#[case] json_file_path: &str, #[case] expected_response: &str) { - let json_str = fs::read_to_string(json_file_path).expect("Failed to read JSON file"); - let body = Body::from(json_str); - let response = add_transaction(body) - .await - .expect("Failed to process transaction"); - let bytes = body::to_bytes(response.into_body()) - .await - .expect("Failed to read response body"); - let body_str = String::from_utf8(bytes.to_vec()).expect("Response body is not valid UTF-8"); - assert_eq!(body_str, expected_response); + let file = File::open(json_file_path).unwrap(); + let reader = BufReader::new(file); + let transaction: ExternalTransaction = serde_json::from_reader(reader).unwrap(); + let response = add_transaction(transaction.into()).await.into_response(); + let response_bytes = response.into_body().collect().await.unwrap().to_bytes(); + assert_eq!( + &String::from_utf8(response_bytes.to_vec()).unwrap(), + expected_response + ); }