diff --git a/crates/driver/example.toml b/crates/driver/example.toml index f796761b08..926c84ae08 100644 --- a/crates/driver/example.toml +++ b/crates/driver/example.toml @@ -5,6 +5,7 @@ absolute-slippage = "40000000000000000" # Denominated in wei, optional relative-slippage = "0.1" # Percentage in the [0, 1] range account = "0x0000000000000000000000000000000000000000000000000000000000000001" # The private key of the solver merge-solutions = true # Multiple solutions proposed by the solver may be combined into one by the driver +response-size-limit-max-bytes = 30000000 [solver.request-headers] fake-header-one = "FAKE-HEADER-VALUE" # For instance an authorization token which must be provided on each request diff --git a/crates/driver/src/infra/config/file/load.rs b/crates/driver/src/infra/config/file/load.rs index 61e131abb6..ef338d14f4 100644 --- a/crates/driver/src/infra/config/file/load.rs +++ b/crates/driver/src/infra/config/file/load.rs @@ -93,6 +93,7 @@ pub async fn load(chain: eth::ChainId, path: &Path) -> infra::Config { s3: config.s3.map(Into::into), solver_native_token: config.manage_native_token.to_domain(), quote_tx_origin: config.quote_tx_origin.map(eth::Address), + response_size_limit_max_bytes: config.response_size_limit_max_bytes, } })) .await, diff --git a/crates/driver/src/infra/config/file/mod.rs b/crates/driver/src/infra/config/file/mod.rs index 9f1bd20b2c..b461883636 100644 --- a/crates/driver/src/infra/config/file/mod.rs +++ b/crates/driver/src/infra/config/file/mod.rs @@ -256,6 +256,10 @@ struct SolverConfig { /// Which `tx.origin` is required to make a quote simulation pass. #[serde(default)] quote_tx_origin: Option, + + /// Maximum HTTP response size the driver will accept in bytes. + #[serde(default = "default_response_size_limit_max_bytes")] + response_size_limit_max_bytes: usize, } #[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Serialize)] @@ -590,6 +594,10 @@ fn default_http_timeout() -> Duration { Duration::from_secs(10) } +fn default_response_size_limit_max_bytes() -> usize { + 30_000_000 +} + #[derive(Clone, Debug, Deserialize, Default)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub enum GasEstimatorType { diff --git a/crates/driver/src/infra/solver/mod.rs b/crates/driver/src/infra/solver/mod.rs index cb1fb126e0..ba7332dd65 100644 --- a/crates/driver/src/infra/solver/mod.rs +++ b/crates/driver/src/infra/solver/mod.rs @@ -29,8 +29,6 @@ use { pub mod dto; -const SOLVER_RESPONSE_MAX_BYTES: usize = 10_000_000; - // TODO At some point I should be checking that the names are unique, I don't // think I'm doing that. /// The solver name. The user can configure this to be anything that they like. @@ -124,6 +122,7 @@ pub struct Config { pub solver_native_token: ManageNativeToken, /// Which `tx.origin` is required to make quote verification pass. pub quote_tx_origin: Option, + pub response_size_limit_max_bytes: usize, } impl Solver { @@ -234,7 +233,7 @@ impl Solver { if let Some(id) = observe::request_id::get_task_local_storage() { req = req.header("X-REQUEST-ID", id); } - let res = util::http::send(SOLVER_RESPONSE_MAX_BYTES, req).await; + let res = util::http::send(self.config.response_size_limit_max_bytes, req).await; super::observe::solver_response(&url, res.as_deref()); let res = res?; let res: dto::Solutions = serde_json::from_str(&res) @@ -260,8 +259,9 @@ impl Solver { if let Some(id) = observe::request_id::get_task_local_storage() { req = req.header("X-REQUEST-ID", id); } + let response_size = self.config.response_size_limit_max_bytes; let future = async move { - if let Err(error) = util::http::send(SOLVER_RESPONSE_MAX_BYTES, req).await { + if let Err(error) = util::http::send(response_size, req).await { tracing::warn!(?error, "failed to notify solver"); } }; diff --git a/crates/driver/src/util/http.rs b/crates/driver/src/util/http.rs index 69e070e485..fed2b90ad0 100644 --- a/crates/driver/src/util/http.rs +++ b/crates/driver/src/util/http.rs @@ -4,10 +4,14 @@ pub async fn send(limit_bytes: usize, req: reqwest::RequestBuilder) -> Result limit_bytes { + data.extend_from_slice(&chunk); + if data.len() > limit_bytes { + tracing::trace!( + response = String::from_utf8_lossy(&data).as_ref(), + "response size exceeded" + ); return Err(Error::ResponseTooLarge { limit_bytes }); } - data.extend_from_slice(&chunk); } let body = String::from_utf8(data).map_err(Error::NotUtf8)?; if res.status().is_success() {