From 8786624515b85cdcf238ae88f92b7cfbb80abab0 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 19 Sep 2024 13:30:56 +0800 Subject: [PATCH] feat: improve support for postgres extended protocol (#4721) * feat: improve support for postgres extended protocol * fix: lint fix * fix: test code * fix: adopt upstream * refactor: remove dup code * refactor: avoid copy on error message --- src/query/src/datafusion.rs | 18 +- src/servers/src/postgres/fixtures.rs | 22 +- src/servers/src/postgres/handler.rs | 49 ++- src/servers/src/postgres/types.rs | 471 ++++++++++++++------------- 4 files changed, 313 insertions(+), 247 deletions(-) diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 03eadfde970d..6ed5844de09b 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -398,11 +398,19 @@ impl QueryEngine for DatafusionQueryEngine { query_ctx: QueryContextRef, ) -> Result { let ctx = self.engine_context(query_ctx); - let optimised_plan = self.optimize(&ctx, &plan)?; - Ok(DescribeResult { - schema: optimised_plan.schema()?, - logical_plan: optimised_plan, - }) + if let Ok(optimised_plan) = self.optimize(&ctx, &plan) { + Ok(DescribeResult { + schema: optimised_plan.schema()?, + logical_plan: optimised_plan, + }) + } else { + // Table's like those in information_schema cannot be optimized when + // it contains parameters. So we fallback to original plans. + Ok(DescribeResult { + schema: plan.schema()?, + logical_plan: plan, + }) + } } async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { diff --git a/src/servers/src/postgres/fixtures.rs b/src/servers/src/postgres/fixtures.rs index 5b02480da941..18c3661b9334 100644 --- a/src/servers/src/postgres/fixtures.rs +++ b/src/servers/src/postgres/fixtures.rs @@ -54,17 +54,19 @@ static SET_TRANSACTION_PATTERN: Lazy = static TRANSACTION_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(BEGIN|ROLLBACK|COMMIT);?").unwrap()); +/// Test if given query statement matches the patterns +pub(crate) fn matches(query: &str) -> bool { + TRANSACTION_PATTERN.captures(query).is_some() + || SHOW_PATTERN.captures(query).is_some() + || SET_TRANSACTION_PATTERN.is_match(query) +} + /// Process unsupported SQL and return fixed result as a compatibility solution -pub(crate) fn process<'a>( - query: &str, - _query_ctx: QueryContextRef, -) -> Option>>> { +pub(crate) fn process<'a>(query: &str, _query_ctx: QueryContextRef) -> Option>> { // Transaction directives: if let Some(tx) = TRANSACTION_PATTERN.captures(query) { let tx_tag = &tx[1]; - Some(Ok(vec![Response::Execution(Tag::new( - &tx_tag.to_uppercase(), - ))])) + Some(vec![Response::Execution(Tag::new(&tx_tag.to_uppercase()))]) } else if let Some(show_var) = SHOW_PATTERN.captures(query) { let show_var = show_var[1].to_lowercase(); if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) { @@ -81,12 +83,12 @@ pub(crate) fn process<'a>( vec![vec![value.to_string()]], )); - Some(Ok(vec![Response::Query(QueryResponse::new(schema, data))])) + Some(vec![Response::Query(QueryResponse::new(schema, data))]) } else { None } } else if SET_TRANSACTION_PATTERN.is_match(query) { - Some(Ok(vec![Response::Execution(Tag::new("SET"))])) + Some(vec![Response::Execution(Tag::new("SET"))]) } else { None } @@ -101,7 +103,6 @@ mod test { fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) { if let Response::Execution(tag) = process(q, query_context.clone()) .unwrap_or_else(|| panic!("fail to match {}", q)) - .expect("unexpected error") .remove(0) { assert_eq!(Tag::new(t), tag); @@ -113,7 +114,6 @@ mod test { fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> { if let Response::Query(resp) = process(q, query_context.clone()) .unwrap_or_else(|| panic!("fail to match {}", q)) - .expect("unexpected error") .remove(0) { resp diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 190684ed34fc..53d907d814db 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -59,8 +59,13 @@ impl SimpleQueryHandler for PostgresServerHandler { .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()]) .start_timer(); + if query.is_empty() { + // early return if query is empty + return Ok(vec![Response::EmptyQuery]); + } + if let Some(resps) = fixtures::process(query, query_ctx.clone()) { - resps + Ok(resps) } else { let outputs = self.query_handler.do_query(query, query_ctx.clone()).await; @@ -184,6 +189,16 @@ impl QueryParser for DefaultQueryParser { async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc(); let query_ctx = self.session.new_query_context(); + + // do not parse if query is empty or matches rules + if sql.is_empty() || fixtures::matches(sql) { + return Ok(SqlPlan { + query: sql.to_owned(), + plan: None, + schema: None, + }); + } + let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default()) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -193,6 +208,7 @@ impl QueryParser for DefaultQueryParser { )))) } else { let stmt = stmts.remove(0); + let describe_result = self .query_handler .do_describe(stmt, query_ctx) @@ -244,6 +260,16 @@ impl ExtendedQueryHandler for PostgresServerHandler { let sql_plan = &portal.statement.statement; + if sql_plan.query.is_empty() { + // early return if query is empty + return Ok(Response::EmptyQuery); + } + + if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) { + // if the statement matches our predefined rules, return it early + return Ok(resps.remove(0)); + } + let output = if let Some(plan) = &sql_plan.plan { let plan = plan .replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref()) @@ -297,6 +323,17 @@ impl ExtendedQueryHandler for PostgresServerHandler { .map(|fields| DescribeStatementResponse::new(param_types, fields)) .map_err(|e| PgWireError::ApiError(Box::new(e))) } else { + if let Some(mut resp) = + fixtures::process(&sql_plan.query, self.session.new_query_context()) + { + if let Response::Query(query_response) = resp.remove(0) { + return Ok(DescribeStatementResponse::new( + param_types, + (*query_response.row_schema()).clone(), + )); + } + } + Ok(DescribeStatementResponse::new(param_types, vec![])) } } @@ -317,6 +354,16 @@ impl ExtendedQueryHandler for PostgresServerHandler { .map(DescribePortalResponse::new) .map_err(|e| PgWireError::ApiError(Box::new(e))) } else { + if let Some(mut resp) = + fixtures::process(&sql_plan.query, self.session.new_query_context()) + { + if let Response::Query(query_response) = resp.remove(0) { + return Ok(DescribePortalResponse::new( + (*query_response.row_schema()).clone(), + )); + } + } + Ok(DescribePortalResponse::new(vec![])) } } diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 2bec6c2999f5..9f9d94905e4a 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -239,14 +239,14 @@ pub(super) fn parameter_to_string(portal: &Portal, idx: usize) -> PgWir .unwrap_or_else(|| "".to_owned())), _ => Err(invalid_parameter_error( "unsupported_parameter_type", - Some(¶m_type.to_string()), + Some(param_type.to_string()), )), } } -pub(super) fn invalid_parameter_error(msg: &str, detail: Option<&str>) -> PgWireError { +pub(super) fn invalid_parameter_error(msg: &str, detail: Option) -> PgWireError { let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string()); - error_info.detail = detail.map(|s| s.to_owned()); + error_info.detail = detail; PgWireError::UserError(Box::new(error_info)) } @@ -279,303 +279,314 @@ pub(super) fn parameters_to_scalar_values( .get_param_types() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - // ensure parameter count consistent for: client parameter types, server - // parameter types and parameter count - if param_types.len() != param_count { - return Err(invalid_parameter_error( - "invalid_parameter_count", - Some(&format!( - "Expected: {}, found: {}", - param_types.len(), - param_count - )), - )); - } - for idx in 0..param_count { - let server_type = - if let Some(Some(server_infer_type)) = param_types.get(&format!("${}", idx + 1)) { - server_infer_type - } else { - // at the moment we require type information inferenced by - // server so here we return error if the type is unknown from - // server-side. - // - // It might be possible to parse the parameter just using client - // specified type, we will implement that if there is a case. - return Err(invalid_parameter_error("unknown_parameter_type", None)); - }; + let server_type = param_types + .get(&format!("${}", idx + 1)) + .and_then(|t| t.as_ref()); let client_type = if let Some(client_given_type) = client_param_types.get(idx) { client_given_type.clone() + } else if let Some(server_provided_type) = &server_type { + type_gt_to_pg(server_provided_type).map_err(|e| PgWireError::ApiError(Box::new(e)))? } else { - type_gt_to_pg(server_type).map_err(|e| PgWireError::ApiError(Box::new(e)))? + return Err(invalid_parameter_error( + "unknown_parameter_type", + Some(format!( + "Cannot get parameter type information for parameter {}", + idx + )), + )); }; let value = match &client_type { &Type::VARCHAR | &Type::TEXT => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::String(_) => ScalarValue::Utf8(data), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::String(_) => ScalarValue::Utf8(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Utf8(data) } } &Type::BOOL => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Boolean(data) } } &Type::INT2 => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), - ConcreteDataType::Timestamp(unit) => { - to_timestamp_scalar_value(data, unit, server_type)? - } - ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Timestamp(unit) => { + to_timestamp_scalar_value(data, unit, server_type)? + } + ConcreteDataType::DateTime(_) => { + ScalarValue::Date64(data.map(|d| d as i64)) + } + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Int16(data) } } &Type::INT4 => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), - ConcreteDataType::Timestamp(unit) => { - to_timestamp_scalar_value(data, unit, server_type)? - } - ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Timestamp(unit) => { + to_timestamp_scalar_value(data, unit, server_type)? + } + ConcreteDataType::DateTime(_) => { + ScalarValue::Date64(data.map(|d| d as i64)) + } + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Int32(data) } } &Type::INT8 => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), - ConcreteDataType::Timestamp(unit) => { - to_timestamp_scalar_value(data, unit, server_type)? - } - ConcreteDataType::DateTime(_) => ScalarValue::Date64(data), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Timestamp(unit) => { + to_timestamp_scalar_value(data, unit, server_type)? + } + ConcreteDataType::DateTime(_) => ScalarValue::Date64(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Int64(data) } } &Type::FLOAT4 => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), - ConcreteDataType::Float32(_) => ScalarValue::Float32(data), - ConcreteDataType::Float64(_) => ScalarValue::Float64(data.map(|n| n as f64)), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Float32(_) => ScalarValue::Float32(data), + ConcreteDataType::Float64(_) => { + ScalarValue::Float64(data.map(|n| n as f64)) + } + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Float32(data) } } &Type::FLOAT8 => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), - ConcreteDataType::Float32(_) => ScalarValue::Float32(data.map(|n| n as f32)), - ConcreteDataType::Float64(_) => ScalarValue::Float64(data), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Float32(_) => { + ScalarValue::Float32(data.map(|n| n as f32)) + } + ConcreteDataType::Float64(_) => ScalarValue::Float64(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::Float64(data) } } &Type::TIMESTAMP => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Timestamp(unit) => match *unit { - TimestampType::Second(_) => ScalarValue::TimestampSecond( - data.map(|ts| ts.and_utc().timestamp()), - None, - ), - TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond( - data.map(|ts| ts.and_utc().timestamp_millis()), - None, - ), - TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond( - data.map(|ts| ts.and_utc().timestamp_micros()), - None, - ), - TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond( - data.map(|ts| ts.and_utc().timestamp_micros()), - None, - ), - }, - ConcreteDataType::DateTime(_) => { - ScalarValue::Date64(data.map(|d| d.and_utc().timestamp_millis())) - } - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )) + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Timestamp(unit) => match *unit { + TimestampType::Second(_) => ScalarValue::TimestampSecond( + data.map(|ts| ts.and_utc().timestamp()), + None, + ), + TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond( + data.map(|ts| ts.and_utc().timestamp_millis()), + None, + ), + TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond( + data.map(|ts| ts.and_utc().timestamp_micros()), + None, + ), + TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond( + data.map(|ts| ts.and_utc().timestamp_micros()), + None, + ), + }, + ConcreteDataType::DateTime(_) => { + ScalarValue::Date64(data.map(|d| d.and_utc().timestamp_millis())) + } + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )) + } } + } else { + ScalarValue::TimestampMillisecond( + data.map(|ts| ts.and_utc().timestamp_millis()), + None, + ) } } &Type::DATE => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| { - (d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32 - })), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )); + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| { + (d - NaiveDate::from(NaiveDateTime::UNIX_EPOCH)).num_days() as i32 + })), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )); + } } + } else { + ScalarValue::Date32(data.map(|d| { + (d - NaiveDate::from(NaiveDateTime::UNIX_EPOCH)).num_days() as i32 + })) } } &Type::INTERVAL => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Interval(_) => { - ScalarValue::IntervalMonthDayNano(data.map(|i| Interval::from(i).to_i128())) - } - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )); + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Interval(_) => ScalarValue::IntervalMonthDayNano( + data.map(|i| Interval::from(i).to_i128()), + ), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )); + } } + } else { + ScalarValue::IntervalMonthDayNano(data.map(|i| Interval::from(i).to_i128())) } } &Type::BYTEA => { let data = portal.parameter::>(idx, &client_type)?; - match server_type { - ConcreteDataType::String(_) => { - ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string())) - } - ConcreteDataType::Binary(_) => ScalarValue::Binary(data), - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )); + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::String(_) => { + ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string())) + } + ConcreteDataType::Binary(_) => ScalarValue::Binary(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )); + } } + } else { + ScalarValue::Binary(data) } } &Type::JSONB => { let data = portal.parameter::(idx, &client_type)?; - match server_type { - ConcreteDataType::Binary(_) => { - ScalarValue::Binary(data.map(|d| jsonb::Value::from(d).to_vec())) - } - _ => { - return Err(invalid_parameter_error( - "invalid_parameter_type", - Some(&format!( - "Expected: {}, found: {}", - server_type, client_type - )), - )); + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Binary(_) => { + ScalarValue::Binary(data.map(|d| jsonb::Value::from(d).to_vec())) + } + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )); + } } + } else { + ScalarValue::Binary(data.map(|d| jsonb::Value::from(d).to_vec())) } } _ => Err(invalid_parameter_error( "unsupported_parameter_value", - Some(&format!("Found type: {}", client_type)), + Some(format!("Found type: {}", client_type)), ))?, };