Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add strict mode to validate protocol strings #3638

Merged
merged 14 commits into from
Apr 15, 2024
6 changes: 5 additions & 1 deletion src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ where

if opts.prom_store.enable {
builder = builder
.with_prom_handler(self.instance.clone(), opts.prom_store.with_metric_engine)
.with_prom_handler(
self.instance.clone(),
opts.prom_store.with_metric_engine,
opts.http.strict_mode,
)
.with_prometheus_handler(self.instance.clone());
}

Expand Down
21 changes: 18 additions & 3 deletions src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ pub struct HttpOptions {
pub disable_dashboard: bool,

pub body_limit: ReadableSize,

pub strict_mode: bool,
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
}

impl Default for HttpOptions {
Expand All @@ -136,6 +138,7 @@ impl Default for HttpOptions {
timeout: Duration::from_secs(30),
disable_dashboard: false,
body_limit: DEFAULT_BODY_LIMIT,
strict_mode: false,
}
}
}
Expand Down Expand Up @@ -502,11 +505,12 @@ impl HttpServerBuilder {
self,
handler: PromStoreProtocolHandlerRef,
prom_store_with_metric_engine: bool,
strict_mode: bool,
) -> Self {
Self {
router: self.router.nest(
&format!("/{HTTP_API_VERSION}/prometheus"),
HttpServer::route_prom(handler, prom_store_with_metric_engine),
HttpServer::route_prom(handler, prom_store_with_metric_engine, strict_mode),
),
..self
}
Expand Down Expand Up @@ -698,15 +702,26 @@ impl HttpServer {
fn route_prom<S>(
prom_handler: PromStoreProtocolHandlerRef,
prom_store_with_metric_engine: bool,
strict_mode: bool,
) -> Router<S> {
let mut router = Router::new().route("/read", routing::post(prom_store::remote_read));
if prom_store_with_metric_engine {
if prom_store_with_metric_engine && strict_mode {
etolbakov marked this conversation as resolved.
Show resolved Hide resolved
router = router.route("/write", routing::post(prom_store::remote_write));
} else {
} else if prom_store_with_metric_engine && !strict_mode {
router = router.route(
"/write",
routing::post(prom_store::remote_write_without_strict_mode),
);
} else if !prom_store_with_metric_engine && strict_mode {
router = router.route(
"/write",
routing::post(prom_store::route_write_without_metric_engine),
);
} else {
router = router.route(
"/write",
routing::post(prom_store::route_write_without_metric_engine_and_strict_mode),
);
etolbakov marked this conversation as resolved.
Show resolved Hide resolved
}
router.with_state(prom_handler)
}
Expand Down
73 changes: 68 additions & 5 deletions src/servers/src/http/prom_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,34 @@ pub async fn route_write_without_metric_engine(
.with_label_values(&[db.as_str()])
.start_timer();

let (request, samples) = decode_remote_write_request(body).await?;
let (request, samples) = decode_remote_write_request(body, true).await?;
// reject if physical table is specified when metric engine is disabled
if params.physical_table.is_some() {
return UnexpectedPhysicalTableSnafu {}.fail();
}

let output = handler.write(request, query_ctx, false).await?;
crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES.inc_by(samples as u64);
Ok((
StatusCode::NO_CONTENT,
write_cost_header_map(output.meta.cost),
))
}

/// Same with [remote_write] but won't store data to metric engine.
#[axum_macros::debug_handler]
pub async fn route_write_without_metric_engine_and_strict_mode(
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
State(handler): State<PromStoreProtocolHandlerRef>,
Query(params): Query<DatabaseQuery>,
Extension(query_ctx): Extension<QueryContextRef>,
RawBody(body): RawBody,
) -> Result<impl IntoResponse> {
let db = params.db.clone().unwrap_or_default();
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

let (request, samples) = decode_remote_write_request(body, false).await?;
// reject if physical table is specified when metric engine is disabled
if params.physical_table.is_some() {
return UnexpectedPhysicalTableSnafu {}.fail();
Expand Down Expand Up @@ -105,7 +132,7 @@ pub async fn remote_write(
.with_label_values(&[db.as_str()])
.start_timer();

let (request, samples) = decode_remote_write_request_to_row_inserts(body).await?;
let (request, samples) = decode_remote_write_request_to_row_inserts(body, true).await?;

if let Some(physical_table) = params.physical_table {
let mut new_query_ctx = query_ctx.as_ref().clone();
Expand All @@ -121,6 +148,38 @@ pub async fn remote_write(
))
}

#[axum_macros::debug_handler]
#[tracing::instrument(
skip_all,
fields(protocol = "prometheus", request_type = "remote_write")
)]
pub async fn remote_write_without_strict_mode(
State(handler): State<PromStoreProtocolHandlerRef>,
Query(params): Query<DatabaseQuery>,
Extension(mut query_ctx): Extension<QueryContextRef>,
RawBody(body): RawBody,
) -> Result<impl IntoResponse> {
let db = params.db.clone().unwrap_or_default();
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
let (request, samples) = decode_remote_write_request_to_row_inserts(body, true).await?;

if let Some(physical_table) = params.physical_table {
let mut new_query_ctx = query_ctx.as_ref().clone();
new_query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table);
query_ctx = Arc::new(new_query_ctx);
}

let output = handler.write(request, query_ctx, false).await?;
crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES.inc_by(samples as u64);
Ok((
StatusCode::NO_CONTENT,
write_cost_header_map(output.meta.cost),
))
}

impl IntoResponse for PromStoreResponse {
fn into_response(self) -> axum::response::Response {
let mut header_map = HeaderMap::new();
Expand Down Expand Up @@ -163,6 +222,7 @@ pub async fn remote_read(

async fn decode_remote_write_request_to_row_inserts(
body: Body,
strict_mode: bool,
) -> Result<(RowInsertRequests, usize)> {
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_DECODE_ELAPSED.start_timer();
let body = hyper::body::to_bytes(body)
Expand All @@ -173,12 +233,15 @@ async fn decode_remote_write_request_to_row_inserts(

let mut request = PROM_WRITE_REQUEST_POOL.pull(PromWriteRequest::default);
request
.merge(buf)
.merge(buf, strict_mode)
.context(error::DecodePromRemoteRequestSnafu)?;
Ok(request.as_row_insert_requests())
}

async fn decode_remote_write_request(body: Body) -> Result<(RowInsertRequests, usize)> {
async fn decode_remote_write_request(
body: Body,
strict_mode: bool,
) -> Result<(RowInsertRequests, usize)> {
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_DECODE_ELAPSED.start_timer();
let body = hyper::body::to_bytes(body)
.await
Expand All @@ -188,7 +251,7 @@ async fn decode_remote_write_request(body: Body) -> Result<(RowInsertRequests, u

let mut request = PromWriteRequest::default();
request
.merge(buf)
.merge(buf, strict_mode)
.context(error::DecodePromRemoteRequestSnafu)?;
Ok(request.as_row_insert_requests())
}
Expand Down
37 changes: 30 additions & 7 deletions src/servers/src/prom_row_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use api::v1::{
use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use prost::DecodeError;

use crate::proto::PromLabel;
use crate::repeated_field::Clear;
Expand Down Expand Up @@ -118,13 +119,31 @@ impl TableBuilder {
}

/// Adds a set of labels and samples to table builder.
pub(crate) fn add_labels_and_samples(&mut self, labels: &[PromLabel], samples: &[Sample]) {
pub(crate) fn add_labels_and_samples(
&mut self,
labels: &[PromLabel],
samples: &[Sample],
strict_mode: bool,
) -> Result<(), DecodeError> {
let mut row = vec![Value { value_data: None }; self.col_indexes.len()];

for PromLabel { name, value } in labels {
// safety: we expect all labels are UTF-8 encoded strings.
let tag_name = unsafe { String::from_utf8_unchecked(name.to_vec()) };
let tag_value = unsafe { String::from_utf8_unchecked(value.to_vec()) };
let tag_name;
let tag_value;
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
if strict_mode {
etolbakov marked this conversation as resolved.
Show resolved Hide resolved
tag_name = match String::from_utf8(name.to_vec()) {
Ok(s) => s,
Err(_) => return Err(DecodeError::new("invalid utf-8")),
};
tag_value = match String::from_utf8(value.to_vec()) {
Ok(s) => s,
Err(_) => return Err(DecodeError::new("invalid utf-8")),
};
} else {
tag_name = unsafe { String::from_utf8_unchecked(name.to_vec()) };
tag_value = unsafe { String::from_utf8_unchecked(value.to_vec()) };
}

let tag_value = Some(ValueData::StringValue(tag_value));
let tag_num = self.col_indexes.len();

Expand Down Expand Up @@ -153,7 +172,7 @@ impl TableBuilder {
row[0].value_data = Some(ValueData::TimestampMillisecondValue(sample.timestamp));
row[1].value_data = Some(ValueData::F64Value(sample.value));
self.rows.push(Row { values: row });
return;
return Ok(());
}
for sample in samples {
row[0].value_data = Some(ValueData::TimestampMillisecondValue(sample.timestamp));
Expand All @@ -162,6 +181,8 @@ impl TableBuilder {
values: row.clone(),
});
}

Ok(())
}

/// Converts [TableBuilder] to [RowInsertRequest] and clears buffered data.
Expand Down Expand Up @@ -194,7 +215,7 @@ mod tests {
#[test]
fn test_table_builder() {
let mut builder = TableBuilder::default();
builder.add_labels_and_samples(
let _ = builder.add_labels_and_samples(
&[
PromLabel {
name: Bytes::from("tag0"),
Expand All @@ -209,9 +230,10 @@ mod tests {
value: 0.0,
timestamp: 0,
}],
true,
etolbakov marked this conversation as resolved.
Show resolved Hide resolved
);

builder.add_labels_and_samples(
let _ = builder.add_labels_and_samples(
&[
PromLabel {
name: Bytes::from("tag0"),
Expand All @@ -226,6 +248,7 @@ mod tests {
value: 0.1,
timestamp: 1,
}],
true
);

let request = builder.as_row_insert_request("test".to_string());
Expand Down
29 changes: 21 additions & 8 deletions src/servers/src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ impl PromTimeSeries {
tag: u32,
wire_type: WireType,
buf: &mut Bytes,
strict_mode: bool,
) -> Result<(), DecodeError> {
const STRUCT_NAME: &str = "PromTimeSeries";
match tag {
Expand All @@ -175,8 +176,14 @@ impl PromTimeSeries {
return Err(DecodeError::new("delimited length exceeded"));
}
if label.name.deref() == METRIC_NAME_LABEL_BYTES {
// safety: we expect all labels are UTF-8 encoded strings.
let table_name = unsafe { String::from_utf8_unchecked(label.value.to_vec()) };
let table_name = if strict_mode {
match String::from_utf8(label.value.to_vec()) {
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
Ok(s) => s,
Err(_) => return Err(DecodeError::new("invalid utf-8")),
}
} else {
unsafe { String::from_utf8_unchecked(label.value.to_vec()) }
};
self.table_name = table_name;
self.labels.truncate(self.labels.len() - 1); // remove last label
}
Expand All @@ -198,15 +205,19 @@ impl PromTimeSeries {
}
}

fn add_to_table_data(&mut self, table_builders: &mut TablesBuilder) {
fn add_to_table_data(&mut self, table_builders: &mut TablesBuilder, strict_mode: bool) {
let label_num = self.labels.len();
let row_num = self.samples.len();
let table_data = table_builders.get_or_create_table_builder(
std::mem::take(&mut self.table_name),
label_num,
row_num,
);
table_data.add_labels_and_samples(self.labels.as_slice(), self.samples.as_slice());
let _ = table_data.add_labels_and_samples(
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
self.labels.as_slice(),
self.samples.as_slice(),
strict_mode,
);
self.labels.clear();
self.samples.clear();
}
Expand All @@ -230,7 +241,7 @@ impl PromWriteRequest {
}

// todo(hl): maybe use &[u8] can reduce the overhead introduced with Bytes.
pub fn merge(&mut self, mut buf: Bytes) -> Result<(), DecodeError> {
pub fn merge(&mut self, mut buf: Bytes, strict_mode: bool) -> Result<(), DecodeError> {
const STRUCT_NAME: &str = "PromWriteRequest";
while buf.has_remaining() {
let (tag, wire_type) = decode_key(&mut buf)?;
Expand All @@ -250,12 +261,14 @@ impl PromWriteRequest {
let limit = remaining - len as usize;
while buf.remaining() > limit {
let (tag, wire_type) = decode_key(&mut buf)?;
self.series.merge_field(tag, wire_type, &mut buf)?;
self.series
.merge_field(tag, wire_type, &mut buf, strict_mode)?;
}
if buf.remaining() != limit {
return Err(DecodeError::new("delimited length exceeded"));
}
self.series.add_to_table_data(&mut self.table_data);
self.series
.add_to_table_data(&mut self.table_data, strict_mode);
}
3u32 => {
// todo(hl): metadata are skipped.
Expand Down Expand Up @@ -303,7 +316,7 @@ mod tests {
expected_rows: &RowInsertRequests,
) {
prom_write_request.clear();
prom_write_request.merge(data.clone()).unwrap();
prom_write_request.merge(data.clone(), true).unwrap();
let (prom_rows, samples) = prom_write_request.as_row_insert_requests();

assert_eq!(expected_samples, samples);
Expand Down
2 changes: 1 addition & 1 deletion src/servers/tests/http/prom_store_test.rs
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
let instance = Arc::new(DummyInstance { tx });
let server = HttpServerBuilder::new(http_opts)
.with_sql_handler(instance.clone(), None)
.with_prom_handler(instance, true)
.with_prom_handler(instance, true, true)
.build();
server.build(server.make_app())
}
Expand Down
2 changes: 1 addition & 1 deletion tests-integration/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ pub async fn setup_test_prom_app_with_frontend(
ServerSqlQueryHandlerAdapter::arc(frontend_ref.clone()),
Some(frontend_ref.clone()),
)
.with_prom_handler(frontend_ref.clone(), true)
.with_prom_handler(frontend_ref.clone(), true, true)
.with_prometheus_handler(frontend_ref)
.with_greptime_config_options(instance.mix_options.datanode.to_toml_string())
.build();
Expand Down
Loading