Skip to content

Commit

Permalink
refactor: use axum instead of hyper (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayeletstarkware authored Apr 3, 2024
1 parent 4408e4b commit 2e9a1d0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 74 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ as_conversions = "deny"

[workspace.dependencies]
assert_matches = "1.5.0"
axum = "0.6.12"
hyper = "0.13.9"
rstest = "0.17.0"
serde = { version = "1.0.193", features = ["derive"] }
Expand Down
1 change: 1 addition & 0 deletions crates/gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ workspace = true

[dependencies]
assert_matches.workspace = true
axum.workspace = true
hyper.workspace = true
serde.workspace = true
serde_json.workspace = true
Expand Down
65 changes: 19 additions & 46 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
use crate::errors::{GatewayConfigError, GatewayError};
use hyper::body::to_bytes;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use starknet_api::external_transaction::ExternalTransaction;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::str::FromStr;

#[cfg(test)]
#[path = "gateway_test.rs"]
pub mod gateway_test;

const NOT_FOUND_RESPONSE: &str = "Not found.";
type RequestBody = Request<Body>;
type ResponseBody = Response<Body>;
pub type GatewayResult = Result<(), GatewayError>;

pub struct Gateway {
Expand All @@ -22,14 +18,21 @@ pub struct Gateway {

impl Gateway {
pub async fn build_server(&self) -> GatewayResult {
// Parses the bind address from GatewayConfig, returning an error for invalid addresses.
let addr = SocketAddr::from_str(&self.gateway_config.bind_address).map_err(|_| {
GatewayConfigError::InvalidServerBindAddress(self.gateway_config.bind_address.clone())
})?;

let make_service =
make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(handle_request)) });
// Sets up the router with the specified routes for the server.
let app = Router::new()
.route("/is_alive", get(is_alive))
.route("/add_transaction", post(add_transaction));

Server::bind(&addr).serve(make_service).await?;
// Create a server that runs forever.
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();

Ok(())
}
Expand All @@ -39,44 +42,14 @@ pub struct GatewayConfig {
pub bind_address: String,
}

async fn handle_request(request: RequestBody) -> Result<Response<Body>, GatewayError> {
let (parts, body) = request.into_parts();
let response = match (parts.method, parts.uri.path()) {
(Method::GET, "/is_alive") => is_alive(),
(Method::POST, "/add_transaction") => add_transaction(body).await,
_ => response(StatusCode::NOT_FOUND, NOT_FOUND_RESPONSE.to_string()),
};
response
}

fn is_alive() -> Result<ResponseBody, GatewayError> {
async fn is_alive() -> impl IntoResponse {
unimplemented!("Future handling should be implemented here.");
}

// TODO(Ayelet): Consider using axum instead of Hyper.
async fn add_transaction(body: Body) -> Result<Response<Body>, GatewayError> {
let bytes = to_bytes(body).await?;
let deserialized_transaction = serde_json::from_slice::<ExternalTransaction>(&bytes)
.map_err(|_| GatewayError::InvalidTransactionFormat);

match deserialized_transaction {
Ok(transaction) => {
let tx_type = match transaction {
ExternalTransaction::Declare(_) => "DECLARE",
ExternalTransaction::DeployAccount(_) => "DEPLOY_ACCOUNT",
ExternalTransaction::Invoke(_) => "INVOKE",
};
response(StatusCode::OK, tx_type.to_string())
}
Err(_) => response(
StatusCode::BAD_REQUEST,
"Invalid transaction format.".to_string(),
),
async fn add_transaction(Json(transaction_json): Json<ExternalTransaction>) -> impl IntoResponse {
match transaction_json {
ExternalTransaction::Declare(_) => "DECLARE",
ExternalTransaction::DeployAccount(_) => "DEPLOY_ACCOUNT",
ExternalTransaction::Invoke(_) => "INVOKE",
}
}

fn response(status: StatusCode, body_content: String) -> Result<Response<Body>, GatewayError> {
Ok(Response::builder()
.status(status)
.body(Body::from(body_content))?)
}
41 changes: 13 additions & 28 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@
use crate::gateway::add_transaction;
use crate::gateway::handle_request;
use hyper::{body, Body, Request};
use axum::{body::HttpBody, response::IntoResponse};
use rstest::rstest;
use std::fs;

#[tokio::test]
async fn test_invalid_request() {
// Create a sample GET request for an invalid path
let request = Request::get("/some_invalid_path")
.body(Body::empty())
.unwrap();
let response = handle_request(request).await.unwrap();

assert_eq!(response.status(), 404);
assert_eq!(
String::from_utf8_lossy(&body::to_bytes(response.into_body()).await.unwrap()),
"Not found."
);
}
use starknet_api::external_transaction::ExternalTransaction;
use std::fs::File;
use std::io::BufReader;

// TODO(Ayelet): Replace the use of the JSON files with generated instances, then serialize these
// into JSON for testing.
Expand All @@ -30,14 +16,13 @@ async fn test_invalid_request() {
#[case("./src/json_files_for_testing/invoke_v3.json", "INVOKE")]
#[tokio::test]
async fn test_add_transaction(#[case] json_file_path: &str, #[case] expected_response: &str) {
let json_str = fs::read_to_string(json_file_path).expect("Failed to read JSON file");
let body = Body::from(json_str);
let response = add_transaction(body)
.await
.expect("Failed to process transaction");
let bytes = body::to_bytes(response.into_body())
.await
.expect("Failed to read response body");
let body_str = String::from_utf8(bytes.to_vec()).expect("Response body is not valid UTF-8");
assert_eq!(body_str, expected_response);
let file = File::open(json_file_path).unwrap();
let reader = BufReader::new(file);
let transaction: ExternalTransaction = serde_json::from_reader(reader).unwrap();
let response = add_transaction(transaction.into()).await.into_response();
let response_bytes = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(
&String::from_utf8(response_bytes.to_vec()).unwrap(),
expected_response
);
}

0 comments on commit 2e9a1d0

Please sign in to comment.