diff --git a/Cargo.toml b/Cargo.toml
index a3beb9d92..9423d146c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,7 @@ as_conversions = "deny"
[workspace.dependencies]
assert_matches = "1.5.0"
+axum = "0.6.12"
hyper = "0.13.9"
rstest = "0.17.0"
serde = { version = "1.0.193", features = ["derive"] }
diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml
index 7a4aabbf0..fc57f5cba 100644
--- a/crates/gateway/Cargo.toml
+++ b/crates/gateway/Cargo.toml
@@ -10,6 +10,7 @@ workspace = true
[dependencies]
assert_matches.workspace = true
+axum.workspace = true
hyper.workspace = true
serde.workspace = true
serde_json.workspace = true
diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs
index 1233dd9b1..95233336d 100644
--- a/crates/gateway/src/gateway.rs
+++ b/crates/gateway/src/gateway.rs
@@ -1,9 +1,8 @@
use crate::errors::{GatewayConfigError, GatewayError};
-use hyper::body::to_bytes;
-use hyper::service::{make_service_fn, service_fn};
-use hyper::{Body, Method, Request, Response, Server, StatusCode};
+use axum::response::IntoResponse;
+use axum::routing::{get, post};
+use axum::{Json, Router};
use starknet_api::external_transaction::ExternalTransaction;
-use std::convert::Infallible;
use std::net::SocketAddr;
use std::str::FromStr;
@@ -11,9 +10,6 @@ use std::str::FromStr;
#[path = "gateway_test.rs"]
pub mod gateway_test;
-const NOT_FOUND_RESPONSE: &str = "Not found.";
-type RequestBody = Request
;
-type ResponseBody = Response;
pub type GatewayResult = Result<(), GatewayError>;
pub struct Gateway {
@@ -22,14 +18,21 @@ pub struct Gateway {
impl Gateway {
pub async fn build_server(&self) -> GatewayResult {
+ // Parses the bind address from GatewayConfig, returning an error for invalid addresses.
let addr = SocketAddr::from_str(&self.gateway_config.bind_address).map_err(|_| {
GatewayConfigError::InvalidServerBindAddress(self.gateway_config.bind_address.clone())
})?;
- let make_service =
- make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(handle_request)) });
+ // Sets up the router with the specified routes for the server.
+ let app = Router::new()
+ .route("/is_alive", get(is_alive))
+ .route("/add_transaction", post(add_transaction));
- Server::bind(&addr).serve(make_service).await?;
+ // Attempts to bind and run the server. If it fails, returns an error.
+ axum::Server::bind(&addr)
+ .serve(app.into_make_service())
+ .await
+ .unwrap();
Ok(())
}
@@ -39,44 +42,14 @@ pub struct GatewayConfig {
pub bind_address: String,
}
-async fn handle_request(request: RequestBody) -> Result, GatewayError> {
- let (parts, body) = request.into_parts();
- let response = match (parts.method, parts.uri.path()) {
- (Method::GET, "/is_alive") => is_alive(),
- (Method::POST, "/add_transaction") => add_transaction(body).await,
- _ => response(StatusCode::NOT_FOUND, NOT_FOUND_RESPONSE.to_string()),
- };
- response
-}
-
-fn is_alive() -> Result {
+async fn is_alive() -> impl IntoResponse {
unimplemented!("Future handling should be implemented here.");
}
-// TODO(Ayelet): Consider using axum instead of Hyper.
-async fn add_transaction(body: Body) -> Result, GatewayError> {
- let bytes = to_bytes(body).await?;
- let deserialized_transaction = serde_json::from_slice::(&bytes)
- .map_err(|_| GatewayError::InvalidTransactionFormat);
-
- match deserialized_transaction {
- Ok(transaction) => {
- let tx_type = match transaction {
- ExternalTransaction::Declare(_) => "DECLARE",
- ExternalTransaction::DeployAccount(_) => "DEPLOY_ACCOUNT",
- ExternalTransaction::Invoke(_) => "INVOKE",
- };
- response(StatusCode::OK, tx_type.to_string())
- }
- Err(_) => response(
- StatusCode::BAD_REQUEST,
- "Invalid transaction format.".to_string(),
- ),
+async fn add_transaction(Json(transaction_json): Json) -> impl IntoResponse {
+ match transaction_json {
+ ExternalTransaction::Declare(_) => "DECLARE",
+ ExternalTransaction::DeployAccount(_) => "DEPLOY_ACCOUNT",
+ ExternalTransaction::Invoke(_) => "INVOKE",
}
}
-
-fn response(status: StatusCode, body_content: String) -> Result, GatewayError> {
- Ok(Response::builder()
- .status(status)
- .body(Body::from(body_content))?)
-}
diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs
index 42895abd5..2e374c1ed 100644
--- a/crates/gateway/src/gateway_test.rs
+++ b/crates/gateway/src/gateway_test.rs
@@ -1,23 +1,9 @@
use crate::gateway::add_transaction;
-use crate::gateway::handle_request;
-use hyper::{body, Body, Request};
+use axum::{body::HttpBody, response::IntoResponse};
use rstest::rstest;
-use std::fs;
-
-#[tokio::test]
-async fn test_invalid_request() {
- // Create a sample GET request for an invalid path
- let request = Request::get("/some_invalid_path")
- .body(Body::empty())
- .unwrap();
- let response = handle_request(request).await.unwrap();
-
- assert_eq!(response.status(), 404);
- assert_eq!(
- String::from_utf8_lossy(&body::to_bytes(response.into_body()).await.unwrap()),
- "Not found."
- );
-}
+use starknet_api::external_transaction::ExternalTransaction;
+use std::fs::File;
+use std::io::BufReader;
// TODO(Ayelet): Replace the use of the JSON files with generated instances, then serialize these
// into JSON for testing.
@@ -30,14 +16,13 @@ async fn test_invalid_request() {
#[case("./src/json_files_for_testing/invoke_v3.json", "INVOKE")]
#[tokio::test]
async fn test_add_transaction(#[case] json_file_path: &str, #[case] expected_response: &str) {
- let json_str = fs::read_to_string(json_file_path).expect("Failed to read JSON file");
- let body = Body::from(json_str);
- let response = add_transaction(body)
- .await
- .expect("Failed to process transaction");
- let bytes = body::to_bytes(response.into_body())
- .await
- .expect("Failed to read response body");
- let body_str = String::from_utf8(bytes.to_vec()).expect("Response body is not valid UTF-8");
- assert_eq!(body_str, expected_response);
+ let file = File::open(json_file_path).unwrap();
+ let reader = BufReader::new(file);
+ let transaction: ExternalTransaction = serde_json::from_reader(reader).unwrap();
+ let response = add_transaction(transaction.into()).await.into_response();
+ let response_bytes = response.into_body().collect().await.unwrap().to_bytes();
+ assert_eq!(
+ &String::from_utf8(response_bytes.to_vec()).unwrap(),
+ expected_response
+ );
}