Skip to content

Commit

Permalink
feat: introduce DynamicTimeoutLayer (#5006)
Browse files Browse the repository at this point in the history
* feat: introduce `DynamicTimeoutLayer`

* test: add unit test

* chore: apply suggestions from CR

* feat: add timeout option for cli
  • Loading branch information
WenyXu authored Nov 18, 2024
1 parent 9289265 commit 7c135c0
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 9 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/cmd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ flow.workspace = true
frontend = { workspace = true, default-features = false }
futures.workspace = true
human-panic = "2.0"
humantime.workspace = true
lazy_static.workspace = true
meta-client.workspace = true
meta-srv.workspace = true
Expand Down
19 changes: 18 additions & 1 deletion src/cmd/src/cli/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Duration;

use base64::engine::general_purpose;
use base64::Engine;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use humantime::format_duration;
use serde_json::Value;
use servers::http::greptime_result_v1::GreptimedbV1Response;
use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
use servers::http::GreptimeQueryOutput;
use snafu::ResultExt;

Expand All @@ -26,10 +30,16 @@ pub(crate) struct DatabaseClient {
addr: String,
catalog: String,
auth_header: Option<String>,
timeout: Option<Duration>,
}

impl DatabaseClient {
pub fn new(addr: String, catalog: String, auth_basic: Option<String>) -> Self {
pub fn new(
addr: String,
catalog: String,
auth_basic: Option<String>,
timeout: Option<Duration>,
) -> Self {
let auth_header = if let Some(basic) = auth_basic {
let encoded = general_purpose::STANDARD.encode(basic);
Some(format!("basic {}", encoded))
Expand All @@ -41,6 +51,7 @@ impl DatabaseClient {
addr,
catalog,
auth_header,
timeout,
}
}

Expand All @@ -62,6 +73,12 @@ impl DatabaseClient {
if let Some(ref auth) = self.auth_header {
request = request.header("Authorization", auth);
}
if let Some(ref timeout) = self.timeout {
request = request.header(
GREPTIME_DB_HEADER_TIMEOUT,
format_duration(*timeout).to_string(),
);
}

let response = request.send().await.with_context(|_| HttpQuerySqlSnafu {
reason: format!("bad url: {}", url),
Expand Down
13 changes: 11 additions & 2 deletions src/cmd/src/cli/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use clap::{Parser, ValueEnum};
Expand Down Expand Up @@ -83,14 +84,22 @@ pub struct ExportCommand {
/// The basic authentication for connecting to the server
#[clap(long)]
auth_basic: Option<String>,

/// The timeout of invoking the database.
#[clap(long, value_parser = humantime::parse_duration)]
timeout: Option<Duration>,
}

impl ExportCommand {
pub async fn build(&self, guard: Vec<WorkerGuard>) -> Result<Instance> {
let (catalog, schema) = database::split_database(&self.database)?;

let database_client =
DatabaseClient::new(self.addr.clone(), catalog.clone(), self.auth_basic.clone());
let database_client = DatabaseClient::new(
self.addr.clone(),
catalog.clone(),
self.auth_basic.clone(),
self.timeout,
);

Ok(Instance::new(
Box::new(Export {
Expand Down
13 changes: 11 additions & 2 deletions src/cmd/src/cli/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use clap::{Parser, ValueEnum};
Expand Down Expand Up @@ -68,13 +69,21 @@ pub struct ImportCommand {
/// The basic authentication for connecting to the server
#[clap(long)]
auth_basic: Option<String>,

/// The timeout of invoking the database.
#[clap(long, value_parser = humantime::parse_duration)]
timeout: Option<Duration>,
}

impl ImportCommand {
pub async fn build(&self, guard: Vec<WorkerGuard>) -> Result<Instance> {
let (catalog, schema) = database::split_database(&self.database)?;
let database_client =
DatabaseClient::new(self.addr.clone(), catalog.clone(), self.auth_basic.clone());
let database_client = DatabaseClient::new(
self.addr.clone(),
catalog.clone(),
self.auth_basic.clone(),
self.timeout,
);

Ok(Instance::new(
Box::new(Import {
Expand Down
2 changes: 2 additions & 0 deletions src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ datafusion-expr.workspace = true
datatypes.workspace = true
derive_builder.workspace = true
futures = "0.3"
futures-util.workspace = true
hashbrown = "0.14"
headers = "0.3"
hostname = "0.3"
http = "0.2"
http-body = "0.4"
humantime.workspace = true
humantime-serde.workspace = true
hyper = { version = "0.14", features = ["full"] }
influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" }
Expand Down
22 changes: 18 additions & 4 deletions src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ use serde_json::Value;
use snafu::{ensure, ResultExt};
use tokio::sync::oneshot::{self, Sender};
use tokio::sync::Mutex;
use tower::timeout::TimeoutLayer;
use tower::ServiceBuilder;
use tower_http::decompression::RequestDecompressionLayer;
use tower_http::trace::TraceLayer;
Expand Down Expand Up @@ -101,6 +100,9 @@ pub mod greptime_result_v1;
pub mod influxdb_result_v1;
pub mod json_result;
pub mod table_result;
mod timeout;

pub(crate) use timeout::DynamicTimeoutLayer;

#[cfg(any(test, feature = "testing"))]
pub mod test_helpers;
Expand Down Expand Up @@ -704,7 +706,7 @@ impl HttpServer {

pub fn build(&self, router: Router) -> Router {
let timeout_layer = if self.options.timeout != Duration::default() {
Some(ServiceBuilder::new().layer(TimeoutLayer::new(self.options.timeout)))
Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
} else {
info!("HTTP server timeout is disabled");
None
Expand Down Expand Up @@ -997,10 +999,12 @@ mod test {
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{StringVector, UInt32Vector};
use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
use query::parser::PromQuery;
use query::query_engine::DescribeResult;
use session::context::QueryContextRef;
use tokio::sync::mpsc;
use tokio::time::Instant;

use super::*;
use crate::error::Error;
Expand Down Expand Up @@ -1062,8 +1066,8 @@ mod test {
}
}

fn timeout() -> TimeoutLayer {
TimeoutLayer::new(Duration::from_millis(10))
fn timeout() -> DynamicTimeoutLayer {
DynamicTimeoutLayer::new(Duration::from_millis(10))
}

async fn forever() {
Expand Down Expand Up @@ -1102,6 +1106,16 @@ mod test {
let client = TestClient::new(app);
let res = client.get("/test/timeout").send().await;
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);

let now = Instant::now();
let res = client
.get("/test/timeout")
.header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
.send()
.await;
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
let elapsed = now.elapsed();
assert!(elapsed > Duration::from_millis(15));
}

#[tokio::test]
Expand Down
1 change: 1 addition & 0 deletions src/servers/src/http/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub mod constants {

// LEGACY HEADERS - KEEP IT UNMODIFIED
pub const GREPTIME_DB_HEADER_FORMAT: &str = "x-greptime-format";
pub const GREPTIME_DB_HEADER_TIMEOUT: &str = "x-greptime-timeout";
pub const GREPTIME_DB_HEADER_EXECUTION_TIME: &str = "x-greptime-execution-time";
pub const GREPTIME_DB_HEADER_METRICS: &str = "x-greptime-metrics";
pub const GREPTIME_DB_HEADER_NAME: &str = "x-greptime-db-name";
Expand Down
144 changes: 144 additions & 0 deletions src/servers/src/http/timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use axum::body::Body;
use axum::http::Request;
use axum::response::Response;
use pin_project::pin_project;
use tokio::time::Sleep;
use tower::timeout::error::Elapsed;
use tower::{BoxError, Layer, Service};

use crate::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;

/// [`Timeout`] response future
///
/// [`Timeout`]: crate::timeout::Timeout
///
/// Modified from https://github.com/tower-rs/tower/blob/8b84b98d93a2493422a0ecddb6251f292a904cff/tower/src/timeout/future.rs
#[derive(Debug)]
#[pin_project]
pub struct ResponseFuture<T> {
#[pin]
response: T,
#[pin]
sleep: Sleep,
}

impl<T> ResponseFuture<T> {
pub(crate) fn new(response: T, sleep: Sleep) -> Self {
ResponseFuture { response, sleep }
}
}

impl<F, T, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<T, E>>,
E: Into<BoxError>,
{
type Output = Result<T, BoxError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

// First, try polling the future
match this.response.poll(cx) {
Poll::Ready(v) => return Poll::Ready(v.map_err(Into::into)),
Poll::Pending => {}
}

// Now check the sleep
match this.sleep.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(_) => Poll::Ready(Err(Elapsed::new().into())),
}
}
}

/// Applies a timeout to requests via the supplied inner service.
///
/// Modified from https://github.com/tower-rs/tower/blob/8b84b98d93a2493422a0ecddb6251f292a904cff/tower/src/timeout/layer.rs
#[derive(Debug, Clone)]
pub struct DynamicTimeoutLayer {
default_timeout: Duration,
}

impl DynamicTimeoutLayer {
/// Create a timeout from a duration
pub fn new(default_timeout: Duration) -> Self {
DynamicTimeoutLayer { default_timeout }
}
}

impl<S> Layer<S> for DynamicTimeoutLayer {
type Service = DynamicTimeout<S>;

fn layer(&self, service: S) -> Self::Service {
DynamicTimeout::new(service, self.default_timeout)
}
}

/// Modified from https://github.com/tower-rs/tower/blob/8b84b98d93a2493422a0ecddb6251f292a904cff/tower/src/timeout/mod.rs
#[derive(Clone)]
pub struct DynamicTimeout<S> {
inner: S,
default_timeout: Duration,
}

impl<S> DynamicTimeout<S> {
/// Create a new [`DynamicTimeout`] with the given timeout
pub fn new(inner: S, default_timeout: Duration) -> Self {
DynamicTimeout {
inner,
default_timeout,
}
}
}

impl<S> Service<Request<Body>> for DynamicTimeout<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Error: Into<BoxError>,
{
type Response = S::Response;
type Error = BoxError;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.inner.poll_ready(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
}
}

fn call(&mut self, request: Request<Body>) -> Self::Future {
let user_timeout = request
.headers()
.get(GREPTIME_DB_HEADER_TIMEOUT)
.and_then(|value| {
value
.to_str()
.ok()
.and_then(|value| humantime::parse_duration(value).ok())
});
let response = self.inner.call(request);
let sleep = tokio::time::sleep(user_timeout.unwrap_or(self.default_timeout));
ResponseFuture::new(response, sleep)
}
}

0 comments on commit 7c135c0

Please sign in to comment.