diff --git a/config/config.md b/config/config.md index d3353930b163..b2be8de9018a 100644 --- a/config/config.md +++ b/config/config.md @@ -18,6 +18,7 @@ | `init_regions_parallelism` | Integer | `16` | Parallelism of initializing regions. | | `max_concurrent_queries` | Integer | `0` | The maximum current queries allowed to be executed. Zero means unlimited. | | `enable_telemetry` | Bool | `true` | Enable telemetry to collect anonymous usage data. Enabled by default. | +| `max_in_flight_write_bytes` | String | Unset | The maximum in-flight write bytes. | | `runtime` | -- | -- | The runtime options. | | `runtime.global_rt_size` | Integer | `8` | The number of threads to execute the runtime for global read operations. | | `runtime.compact_rt_size` | Integer | `4` | The number of threads to execute the runtime for global write operations. | @@ -195,6 +196,7 @@ | Key | Type | Default | Descriptions | | --- | -----| ------- | ----------- | | `default_timezone` | String | Unset | The default timezone of the server. | +| `max_in_flight_write_bytes` | String | Unset | The maximum in-flight write bytes. | | `runtime` | -- | -- | The runtime options. | | `runtime.global_rt_size` | Integer | `8` | The number of threads to execute the runtime for global read operations. | | `runtime.compact_rt_size` | Integer | `4` | The number of threads to execute the runtime for global write operations. | diff --git a/config/frontend.example.toml b/config/frontend.example.toml index 1fb372a6d12e..b8e6c5cd8b9e 100644 --- a/config/frontend.example.toml +++ b/config/frontend.example.toml @@ -2,6 +2,10 @@ ## @toml2docs:none-default default_timezone = "UTC" +## The maximum in-flight write bytes. +## @toml2docs:none-default +#+ max_in_flight_write_bytes = "500MB" + ## The runtime options. #+ [runtime] ## The number of threads to execute the runtime for global read operations. diff --git a/config/standalone.example.toml b/config/standalone.example.toml index b73246d37f0a..77445f8883bf 100644 --- a/config/standalone.example.toml +++ b/config/standalone.example.toml @@ -18,6 +18,10 @@ max_concurrent_queries = 0 ## Enable telemetry to collect anonymous usage data. Enabled by default. #+ enable_telemetry = true +## The maximum in-flight write bytes. +## @toml2docs:none-default +#+ max_in_flight_write_bytes = "500MB" + ## The runtime options. #+ [runtime] ## The number of threads to execute the runtime for global read operations. diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index 8490e14147b2..e3675a7db7c1 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -22,6 +22,7 @@ use catalog::information_schema::InformationExtension; use catalog::kvbackend::KvBackendCatalogManager; use clap::Parser; use client::api::v1::meta::RegionRole; +use common_base::readable_size::ReadableSize; use common_base::Plugins; use common_catalog::consts::{MIN_USER_FLOW_ID, MIN_USER_TABLE_ID}; use common_config::{metadata_store_dir, Configurable, KvBackendConfig}; @@ -152,6 +153,7 @@ pub struct StandaloneOptions { pub tracing: TracingOptions, pub init_regions_in_background: bool, pub init_regions_parallelism: usize, + pub max_in_flight_write_bytes: Option, } impl Default for StandaloneOptions { @@ -181,6 +183,7 @@ impl Default for StandaloneOptions { tracing: TracingOptions::default(), init_regions_in_background: false, init_regions_parallelism: 16, + max_in_flight_write_bytes: None, } } } @@ -218,6 +221,7 @@ impl StandaloneOptions { user_provider: cloned_opts.user_provider, // Handle the export metrics task run by standalone to frontend for execution export_metrics: cloned_opts.export_metrics, + max_in_flight_write_bytes: cloned_opts.max_in_flight_write_bytes, ..Default::default() } } diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index c6e7218a389e..5275d7f4a2e3 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -321,6 +321,12 @@ pub enum Error { location: Location, source: BoxedError, }, + + #[snafu(display("In-flight write bytes exceeded the maximum limit"))] + InFlightWriteBytesExceeded { + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -392,6 +398,8 @@ impl ErrorExt for Error { Error::StartScriptManager { source, .. } => source.status_code(), Error::TableOperation { source, .. } => source.status_code(), + + Error::InFlightWriteBytesExceeded { .. } => StatusCode::RateLimited, } } diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs index 55f2dae3c386..a424c1c8095b 100644 --- a/src/frontend/src/frontend.rs +++ b/src/frontend/src/frontend.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_base::readable_size::ReadableSize; use common_config::config::Configurable; use common_options::datanode::DatanodeClientOptions; use common_telemetry::logging::{LoggingOptions, TracingOptions}; @@ -46,6 +47,7 @@ pub struct FrontendOptions { pub user_provider: Option, pub export_metrics: ExportMetricsOption, pub tracing: TracingOptions, + pub max_in_flight_write_bytes: Option, } impl Default for FrontendOptions { @@ -68,6 +70,7 @@ impl Default for FrontendOptions { user_provider: None, export_metrics: ExportMetricsOption::default(), tracing: TracingOptions::default(), + max_in_flight_write_bytes: None, } } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index b22bde96e0ff..aef8cd41492e 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -86,6 +86,7 @@ use crate::error::{ }; use crate::frontend::FrontendOptions; use crate::heartbeat::HeartbeatTask; +use crate::limiter::LimiterRef; use crate::script::ScriptExecutor; #[async_trait] @@ -124,6 +125,7 @@ pub struct Instance { export_metrics_task: Option, table_metadata_manager: TableMetadataManagerRef, stats: StatementStatistics, + limiter: Option, } impl Instance { diff --git a/src/frontend/src/instance/builder.rs b/src/frontend/src/instance/builder.rs index f24141d8ba2b..eaed2437c489 100644 --- a/src/frontend/src/instance/builder.rs +++ b/src/frontend/src/instance/builder.rs @@ -43,6 +43,7 @@ use crate::frontend::FrontendOptions; use crate::heartbeat::HeartbeatTask; use crate::instance::region_query::FrontendRegionQueryHandler; use crate::instance::Instance; +use crate::limiter::Limiter; use crate::script::ScriptExecutor; /// The frontend [`Instance`] builder. @@ -196,6 +197,14 @@ impl FrontendBuilder { plugins.insert::(statement_executor.clone()); + // Create the limiter if the max_in_flight_write_bytes is set. + let limiter = self + .options + .max_in_flight_write_bytes + .map(|max_in_flight_write_bytes| { + Arc::new(Limiter::new(max_in_flight_write_bytes.as_bytes())) + }); + Ok(Instance { options: self.options, catalog_manager: self.catalog_manager, @@ -211,6 +220,7 @@ impl FrontendBuilder { export_metrics_task: None, table_metadata_manager: Arc::new(TableMetadataManager::new(kv_backend)), stats: self.stats, + limiter, }) } } diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index ad225bf30b4e..903a18f97607 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -29,8 +29,8 @@ use snafu::{ensure, OptionExt, ResultExt}; use table::table_name::TableName; use crate::error::{ - Error, IncompleteGrpcRequestSnafu, NotSupportedSnafu, PermissionSnafu, Result, - TableOperationSnafu, + Error, InFlightWriteBytesExceededSnafu, IncompleteGrpcRequestSnafu, NotSupportedSnafu, + PermissionSnafu, Result, TableOperationSnafu, }; use crate::instance::{attach_timer, Instance}; use crate::metrics::{GRPC_HANDLE_PROMQL_ELAPSED, GRPC_HANDLE_SQL_ELAPSED}; @@ -50,6 +50,16 @@ impl GrpcQueryHandler for Instance { .check_permission(ctx.current_user(), PermissionReq::GrpcRequest(&request)) .context(PermissionSnafu)?; + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_request(&request); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + let output = match request { Request::Inserts(requests) => self.handle_inserts(requests, ctx.clone()).await?, Request::RowInserts(requests) => self.handle_row_inserts(requests, ctx.clone()).await?, diff --git a/src/frontend/src/instance/influxdb.rs b/src/frontend/src/instance/influxdb.rs index c337e4174615..864c88e89e14 100644 --- a/src/frontend/src/instance/influxdb.rs +++ b/src/frontend/src/instance/influxdb.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; use client::Output; use common_error::ext::BoxedError; -use servers::error::{AuthSnafu, Error}; +use servers::error::{AuthSnafu, Error, InFlightWriteBytesExceededSnafu}; use servers::influxdb::InfluxdbRequest; use servers::interceptor::{LineProtocolInterceptor, LineProtocolInterceptorRef}; use servers::query_handler::InfluxdbLineProtocolHandler; @@ -46,6 +46,16 @@ impl InfluxdbLineProtocolHandler for Instance { .post_lines_conversion(requests, ctx.clone()) .await?; + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&requests); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + self.handle_influx_row_inserts(requests, ctx) .await .map_err(BoxedError::new) diff --git a/src/frontend/src/instance/log_handler.rs b/src/frontend/src/instance/log_handler.rs index 2da2d6717d3b..671caf1de77c 100644 --- a/src/frontend/src/instance/log_handler.rs +++ b/src/frontend/src/instance/log_handler.rs @@ -22,7 +22,8 @@ use common_error::ext::BoxedError; use pipeline::pipeline_operator::PipelineOperator; use pipeline::{GreptimeTransformer, Pipeline, PipelineInfo, PipelineVersion}; use servers::error::{ - AuthSnafu, Error as ServerError, ExecuteGrpcRequestSnafu, PipelineSnafu, Result as ServerResult, + AuthSnafu, Error as ServerError, ExecuteGrpcRequestSnafu, InFlightWriteBytesExceededSnafu, + PipelineSnafu, Result as ServerResult, }; use servers::interceptor::{LogIngestInterceptor, LogIngestInterceptorRef}; use servers::query_handler::PipelineHandler; @@ -110,6 +111,16 @@ impl Instance { log: RowInsertRequests, ctx: QueryContextRef, ) -> ServerResult { + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&log); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + self.inserter .handle_log_inserts(log, ctx, self.statement_executor.as_ref()) .await diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index 946c3b9ff7f5..6baf7a440ef2 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -17,7 +17,7 @@ use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; use common_error::ext::BoxedError; use common_telemetry::tracing; use servers::error as server_error; -use servers::error::AuthSnafu; +use servers::error::{AuthSnafu, InFlightWriteBytesExceededSnafu}; use servers::opentsdb::codec::DataPoint; use servers::opentsdb::data_point_to_grpc_row_insert_requests; use servers::query_handler::OpentsdbProtocolHandler; @@ -41,6 +41,17 @@ impl OpentsdbProtocolHandler for Instance { .context(AuthSnafu)?; let (requests, _) = data_point_to_grpc_row_insert_requests(data_points)?; + + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&requests); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + let output = self .handle_row_inserts(requests, ctx) .await diff --git a/src/frontend/src/instance/otlp.rs b/src/frontend/src/instance/otlp.rs index f28179d40d59..989c6c4348fc 100644 --- a/src/frontend/src/instance/otlp.rs +++ b/src/frontend/src/instance/otlp.rs @@ -21,7 +21,7 @@ use opentelemetry_proto::tonic::collector::logs::v1::ExportLogsServiceRequest; use opentelemetry_proto::tonic::collector::metrics::v1::ExportMetricsServiceRequest; use opentelemetry_proto::tonic::collector::trace::v1::ExportTraceServiceRequest; use pipeline::PipelineWay; -use servers::error::{self, AuthSnafu, Result as ServerResult}; +use servers::error::{self, AuthSnafu, InFlightWriteBytesExceededSnafu, Result as ServerResult}; use servers::interceptor::{OpenTelemetryProtocolInterceptor, OpenTelemetryProtocolInterceptorRef}; use servers::otlp; use servers::query_handler::OpenTelemetryProtocolHandler; @@ -53,6 +53,16 @@ impl OpenTelemetryProtocolHandler for Instance { let (requests, rows) = otlp::metrics::to_grpc_insert_requests(request)?; OTLP_METRICS_ROWS.inc_by(rows as u64); + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&requests); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + self.handle_row_inserts(requests, ctx) .await .map_err(BoxedError::new) @@ -83,6 +93,16 @@ impl OpenTelemetryProtocolHandler for Instance { OTLP_TRACES_ROWS.inc_by(rows as u64); + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&requests); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + self.handle_log_inserts(requests, ctx) .await .map_err(BoxedError::new) @@ -109,6 +129,17 @@ impl OpenTelemetryProtocolHandler for Instance { interceptor_ref.pre_execute(ctx.clone())?; let (requests, rows) = otlp::logs::to_grpc_insert_requests(request, pipeline, table_name)?; + + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&requests); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + self.handle_log_inserts(requests, ctx) .await .inspect(|_| OTLP_LOGS_ROWS.inc_by(rows as u64)) diff --git a/src/frontend/src/instance/prom_store.rs b/src/frontend/src/instance/prom_store.rs index 8f1098b058f1..9b1a06487c12 100644 --- a/src/frontend/src/instance/prom_store.rs +++ b/src/frontend/src/instance/prom_store.rs @@ -30,7 +30,7 @@ use common_telemetry::{debug, tracing}; use operator::insert::InserterRef; use operator::statement::StatementExecutor; use prost::Message; -use servers::error::{self, AuthSnafu, Result as ServerResult}; +use servers::error::{self, AuthSnafu, InFlightWriteBytesExceededSnafu, Result as ServerResult}; use servers::http::header::{collect_plan_metrics, CONTENT_ENCODING_SNAPPY, CONTENT_TYPE_PROTOBUF}; use servers::http::prom_store::PHYSICAL_TABLE_PARAM; use servers::interceptor::{PromStoreProtocolInterceptor, PromStoreProtocolInterceptorRef}; @@ -175,6 +175,16 @@ impl PromStoreProtocolHandler for Instance { .get::>(); interceptor_ref.pre_write(&request, ctx.clone())?; + let _guard = if let Some(limiter) = &self.limiter { + let result = limiter.limit_row_inserts(&request); + if result.is_none() { + return InFlightWriteBytesExceededSnafu.fail(); + } + result + } else { + None + }; + let output = if with_metric_engine { let physical_table = ctx .extension(PHYSICAL_TABLE_PARAM) diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index de800b0b41c6..e887172797bd 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -18,6 +18,7 @@ pub mod error; pub mod frontend; pub mod heartbeat; pub mod instance; +pub(crate) mod limiter; pub(crate) mod metrics; mod script; pub mod server; diff --git a/src/frontend/src/limiter.rs b/src/frontend/src/limiter.rs new file mode 100644 index 000000000000..62c3dbb7c664 --- /dev/null +++ b/src/frontend/src/limiter.rs @@ -0,0 +1,291 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +use api::v1::column::Values; +use api::v1::greptime_request::Request; +use api::v1::value::ValueData; +use api::v1::{Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequests}; +use common_telemetry::{debug, warn}; + +pub(crate) type LimiterRef = Arc; + +/// A frontend request limiter that controls the total size of in-flight write requests. +pub(crate) struct Limiter { + // The maximum number of bytes that can be in flight. + max_in_flight_write_bytes: u64, + + // The current in-flight write bytes. + in_flight_write_bytes: Arc, +} + +/// A counter for the in-flight write bytes. +pub(crate) struct InFlightWriteBytesCounter { + // The current in-flight write bytes. + in_flight_write_bytes: Arc, + + // The write bytes that are being processed. + processing_write_bytes: u64, +} + +impl InFlightWriteBytesCounter { + /// Creates a new InFlightWriteBytesCounter. It will decrease the in-flight write bytes when dropped. + pub fn new(in_flight_write_bytes: Arc, processing_write_bytes: u64) -> Self { + debug!( + "processing write bytes: {}, current in-flight write bytes: {}", + processing_write_bytes, + in_flight_write_bytes.load(Ordering::Relaxed) + ); + Self { + in_flight_write_bytes, + processing_write_bytes, + } + } +} + +impl Drop for InFlightWriteBytesCounter { + // When the request is finished, the in-flight write bytes should be decreased. + fn drop(&mut self) { + self.in_flight_write_bytes + .fetch_sub(self.processing_write_bytes, Ordering::Relaxed); + } +} + +impl Limiter { + pub fn new(max_in_flight_write_bytes: u64) -> Self { + Self { + max_in_flight_write_bytes, + in_flight_write_bytes: Arc::new(AtomicU64::new(0)), + } + } + + pub fn limit_request(&self, request: &Request) -> Option { + let size = match request { + Request::Inserts(requests) => self.insert_requests_data_size(requests), + Request::RowInserts(requests) => self.rows_insert_requests_data_size(requests), + _ => 0, + }; + self.limit_in_flight_write_bytes(size as u64) + } + + pub fn limit_row_inserts( + &self, + requests: &RowInsertRequests, + ) -> Option { + let size = self.rows_insert_requests_data_size(requests); + self.limit_in_flight_write_bytes(size as u64) + } + + /// Returns None if the in-flight write bytes exceed the maximum limit. + /// Otherwise, returns Some(InFlightWriteBytesCounter) and the in-flight write bytes will be increased. + pub fn limit_in_flight_write_bytes(&self, bytes: u64) -> Option { + let result = self.in_flight_write_bytes.fetch_update( + Ordering::Relaxed, + Ordering::Relaxed, + |current| { + if current + bytes > self.max_in_flight_write_bytes { + warn!( + "in-flight write bytes exceed the maximum limit {}, request with {} bytes will be limited", + self.max_in_flight_write_bytes, + bytes + ); + return None; + } + Some(current + bytes) + }, + ); + + match result { + // Update the in-flight write bytes successfully. + Ok(_) => Some(InFlightWriteBytesCounter::new( + self.in_flight_write_bytes.clone(), + bytes, + )), + // It means the in-flight write bytes exceed the maximum limit. + Err(_) => None, + } + } + + /// Returns the current in-flight write bytes. + #[allow(dead_code)] + pub fn in_flight_write_bytes(&self) -> u64 { + self.in_flight_write_bytes.load(Ordering::Relaxed) + } + + fn insert_requests_data_size(&self, request: &InsertRequests) -> usize { + let mut size: usize = 0; + for insert in &request.inserts { + for column in &insert.columns { + if let Some(values) = &column.values { + size += self.size_of_column_values(values); + } + } + } + size + } + + fn rows_insert_requests_data_size(&self, request: &RowInsertRequests) -> usize { + let mut size: usize = 0; + for insert in &request.inserts { + if let Some(rows) = &insert.rows { + for row in &rows.rows { + for value in &row.values { + if let Some(value) = &value.value_data { + size += self.size_of_value_data(value); + } + } + } + } + } + size + } + + fn size_of_column_values(&self, values: &Values) -> usize { + let mut size: usize = 0; + size += values.i8_values.len() * size_of::(); + size += values.i16_values.len() * size_of::(); + size += values.i32_values.len() * size_of::(); + size += values.i64_values.len() * size_of::(); + size += values.u8_values.len() * size_of::(); + size += values.u16_values.len() * size_of::(); + size += values.u32_values.len() * size_of::(); + size += values.u64_values.len() * size_of::(); + size += values.f32_values.len() * size_of::(); + size += values.f64_values.len() * size_of::(); + size += values.bool_values.len() * size_of::(); + size += values + .binary_values + .iter() + .map(|v| v.len() * size_of::()) + .sum::(); + size += values.string_values.iter().map(|v| v.len()).sum::(); + size += values.date_values.len() * size_of::(); + size += values.datetime_values.len() * size_of::(); + size += values.timestamp_second_values.len() * size_of::(); + size += values.timestamp_millisecond_values.len() * size_of::(); + size += values.timestamp_microsecond_values.len() * size_of::(); + size += values.timestamp_nanosecond_values.len() * size_of::(); + size += values.time_second_values.len() * size_of::(); + size += values.time_millisecond_values.len() * size_of::(); + size += values.time_microsecond_values.len() * size_of::(); + size += values.time_nanosecond_values.len() * size_of::(); + size += values.interval_year_month_values.len() * size_of::(); + size += values.interval_day_time_values.len() * size_of::(); + size += values.interval_month_day_nano_values.len() * size_of::(); + size += values.decimal128_values.len() * size_of::(); + size + } + + fn size_of_value_data(&self, value: &ValueData) -> usize { + match value { + ValueData::I8Value(_) => size_of::(), + ValueData::I16Value(_) => size_of::(), + ValueData::I32Value(_) => size_of::(), + ValueData::I64Value(_) => size_of::(), + ValueData::U8Value(_) => size_of::(), + ValueData::U16Value(_) => size_of::(), + ValueData::U32Value(_) => size_of::(), + ValueData::U64Value(_) => size_of::(), + ValueData::F32Value(_) => size_of::(), + ValueData::F64Value(_) => size_of::(), + ValueData::BoolValue(_) => size_of::(), + ValueData::BinaryValue(v) => v.len() * size_of::(), + ValueData::StringValue(v) => v.len(), + ValueData::DateValue(_) => size_of::(), + ValueData::DatetimeValue(_) => size_of::(), + ValueData::TimestampSecondValue(_) => size_of::(), + ValueData::TimestampMillisecondValue(_) => size_of::(), + ValueData::TimestampMicrosecondValue(_) => size_of::(), + ValueData::TimestampNanosecondValue(_) => size_of::(), + ValueData::TimeSecondValue(_) => size_of::(), + ValueData::TimeMillisecondValue(_) => size_of::(), + ValueData::TimeMicrosecondValue(_) => size_of::(), + ValueData::TimeNanosecondValue(_) => size_of::(), + ValueData::IntervalYearMonthValue(_) => size_of::(), + ValueData::IntervalDayTimeValue(_) => size_of::(), + ValueData::IntervalMonthDayNanoValue(_) => size_of::(), + ValueData::Decimal128Value(_) => size_of::(), + } + } +} + +#[cfg(test)] +mod tests { + use api::v1::column::Values; + use api::v1::greptime_request::Request; + use api::v1::{Column, InsertRequest}; + + use super::*; + + fn generate_request(size: usize) -> Request { + let i8_values = vec![0; size]; + Request::Inserts(InsertRequests { + inserts: vec![InsertRequest { + columns: vec![Column { + values: Some(Values { + i8_values, + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }], + }) + } + + #[tokio::test] + async fn test_limiter() { + let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024)); + let tasks_count = 10; + let request_data_size = 100; + let mut handles = vec![]; + + // Generate multiple requests to test the limiter. + for _ in 0..tasks_count { + let limiter = limiter_ref.clone(); + let handle = tokio::spawn(async move { + let result = limiter.limit_request(&generate_request(request_data_size)); + assert!(result.is_some()); + }); + handles.push(handle); + } + + // Wait for all threads to complete. + for handle in handles { + handle.await.unwrap(); + } + } + + #[test] + fn test_in_flight_write_bytes() { + let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024)); + let req1 = generate_request(100); + let result1 = limiter_ref.limit_request(&req1); + assert!(result1.is_some()); + assert_eq!(limiter_ref.in_flight_write_bytes(), 100); + + let req2 = generate_request(200); + let result2 = limiter_ref.limit_request(&req2); + assert!(result2.is_some()); + assert_eq!(limiter_ref.in_flight_write_bytes(), 300); + + drop(result1.unwrap()); + assert_eq!(limiter_ref.in_flight_write_bytes(), 200); + + drop(result2.unwrap()); + assert_eq!(limiter_ref.in_flight_write_bytes(), 0); + } +} diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 071de93683cc..c1c331c33744 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -589,6 +589,12 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + + #[snafu(display("In-flight write bytes exceeded the maximum limit"))] + InFlightWriteBytesExceeded { + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -706,6 +712,8 @@ impl ErrorExt for Error { ToJson { .. } => StatusCode::Internal, ConvertSqlValue { source, .. } => source.status_code(), + + InFlightWriteBytesExceeded { .. } => StatusCode::RateLimited, } }