Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: returning warning instead of error on unsupported SET statement #4761

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use datafusion_expr::LogicalPlan;
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
use query::parser::QueryStatement;
use query::QueryEngineRef;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
Expand Down Expand Up @@ -338,10 +338,18 @@ impl StatementExecutor {

"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
_ => {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
// of connection establishment
//
if query_ctx.channel() == Channel::Postgres {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name));
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
.fail()
}
}
Ok(Output::new_with_affected_rows(0))
Expand Down
42 changes: 37 additions & 5 deletions src/servers/src/postgres/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
Expand All @@ -23,7 +24,7 @@ use common_telemetry::{debug, error, tracing};
use datafusion_common::ParamValues;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::SchemaRef;
use futures::{future, stream, Stream, StreamExt};
use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
Expand All @@ -32,6 +33,7 @@ use pgwire::api::results::{
use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use query::query_engine::DescribeResult;
use session::context::QueryContextRef;
use session::Session;
Expand All @@ -49,11 +51,13 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
#[tracing::instrument(skip_all, fields(protocol = "postgres"))]
async fn do_query<'a, C>(
&self,
_client: &mut C,
client: &mut C,
query: &'a str,
) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Unpin + Send + Sync,
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let query_ctx = self.session.new_query_context();
let db = query_ctx.get_db_string();
Expand All @@ -67,6 +71,7 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
}

if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
Ok(resps)
} else {
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
Expand All @@ -79,11 +84,34 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
results.push(resp);
}

send_warning_opt(client, query_ctx).await?;
Ok(results)
}
}
}

async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
where
C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
if let Some(warning) = query_context.warning() {
client
.feed(PgWireBackendMessage::NoticeResponse(
ErrorInfo::new(
PgErrorSeverity::Warning.to_string(),
PgErrorCode::Ec01000.code(),
warning.to_string(),
)
.into(),
))
.await?;
}

Ok(())
}

pub(crate) fn output_to_query_response<'a>(
query_ctx: QueryContextRef,
output: Result<Output>,
Expand Down Expand Up @@ -247,12 +275,14 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {

async fn do_query<'a, C>(
&self,
_client: &mut C,
client: &mut C,
portal: &'a Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response<'a>>
where
C: ClientInfo + Unpin + Send + Sync,
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let query_ctx = self.session.new_query_context();
let db = query_ctx.get_db_string();
Expand All @@ -268,6 +298,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
}

if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
// if the statement matches our predefined rules, return it early
return Ok(resps.remove(0));
}
Expand Down Expand Up @@ -297,6 +328,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
.remove(0)
};

send_warning_opt(client, query_ctx.clone()).await?;
output_to_query_response(query_ctx, output, &portal.result_column_format)
}

Expand Down
2 changes: 1 addition & 1 deletion src/servers/src/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use session::session_config::PGByteaOutputValue;

use self::bytea::{EscapeOutputBytea, HexOutputBytea};
use self::datetime::{StylingDate, StylingDateTime};
pub use self::error::PgErrorCode;
pub use self::error::{PgErrorCode, PgErrorSeverity};
use self::interval::PgInterval;
use crate::error::{self as server_error, Error, Result};
use crate::SqlPlan;
Expand Down
40 changes: 20 additions & 20 deletions src/servers/src/postgres/types/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use strum::{AsRefStr, Display, EnumIter, EnumMessage};

#[derive(Display, Debug, PartialEq)]
#[allow(dead_code)]
enum ErrorSeverity {
pub enum PgErrorSeverity {
#[strum(serialize = "INFO")]
Info,
#[strum(serialize = "DEBUG")]
Expand Down Expand Up @@ -335,23 +335,23 @@ pub enum PgErrorCode {
}

impl PgErrorCode {
fn severity(&self) -> ErrorSeverity {
fn severity(&self) -> PgErrorSeverity {
match self {
PgErrorCode::Ec00000 => ErrorSeverity::Info,
PgErrorCode::Ec01000 => ErrorSeverity::Warning,
PgErrorCode::Ec00000 => PgErrorSeverity::Info,
PgErrorCode::Ec01000 => PgErrorSeverity::Warning,

PgErrorCode::EcXX000 | PgErrorCode::Ec42P14 | PgErrorCode::Ec22023 => {
ErrorSeverity::Error
PgErrorSeverity::Error
}
PgErrorCode::Ec28000 | PgErrorCode::Ec28P01 | PgErrorCode::Ec3D000 => {
ErrorSeverity::Fatal
PgErrorSeverity::Fatal
}

_ => ErrorSeverity::Error,
_ => PgErrorSeverity::Error,
}
}

fn code(&self) -> String {
pub(crate) fn code(&self) -> String {
self.as_ref()[2..].to_string()
}

Expand Down Expand Up @@ -428,33 +428,33 @@ mod tests {
use common_error::status_code::StatusCode;
use strum::{EnumMessage, IntoEnumIterator};

use super::{ErrorInfo, ErrorSeverity, PgErrorCode};
use super::{ErrorInfo, PgErrorCode, PgErrorSeverity};

#[test]
fn test_error_severity() {
// test for ErrorSeverity enum
assert_eq!("INFO", ErrorSeverity::Info.to_string());
assert_eq!("DEBUG", ErrorSeverity::Debug.to_string());
assert_eq!("NOTICE", ErrorSeverity::Notice.to_string());
assert_eq!("WARNING", ErrorSeverity::Warning.to_string());
assert_eq!("INFO", PgErrorSeverity::Info.to_string());
assert_eq!("DEBUG", PgErrorSeverity::Debug.to_string());
assert_eq!("NOTICE", PgErrorSeverity::Notice.to_string());
assert_eq!("WARNING", PgErrorSeverity::Warning.to_string());

assert_eq!("ERROR", ErrorSeverity::Error.to_string());
assert_eq!("FATAL", ErrorSeverity::Fatal.to_string());
assert_eq!("PANIC", ErrorSeverity::Panic.to_string());
assert_eq!("ERROR", PgErrorSeverity::Error.to_string());
assert_eq!("FATAL", PgErrorSeverity::Fatal.to_string());
assert_eq!("PANIC", PgErrorSeverity::Panic.to_string());

// test for severity method
for code in PgErrorCode::iter() {
let name = code.as_ref();
assert_eq!("Ec", &name[0..2]);

if name.starts_with("Ec00") {
assert_eq!(ErrorSeverity::Info, code.severity());
assert_eq!(PgErrorSeverity::Info, code.severity());
} else if name.starts_with("Ec01") {
assert_eq!(ErrorSeverity::Warning, code.severity());
assert_eq!(PgErrorSeverity::Warning, code.severity());
} else if name.starts_with("Ec28") || name.starts_with("Ec3D") {
assert_eq!(ErrorSeverity::Fatal, code.severity());
assert_eq!(PgErrorSeverity::Fatal, code.severity());
} else {
assert_eq!(ErrorSeverity::Error, code.severity());
assert_eq!(PgErrorSeverity::Error, code.severity());
}
}
}
Expand Down
56 changes: 41 additions & 15 deletions src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ pub struct QueryContext {
current_catalog: String,
// we use Arc<RwLock>> for modifiable fields
#[builder(default)]
mutable_inner: Arc<RwLock<MutableInner>>,
mutable_session_data: Arc<RwLock<MutableInner>>,
#[builder(default)]
mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
sql_dialect: Arc<dyn Dialect + Send + Sync>,
#[builder(default)]
extensions: HashMap<String, String>,
Expand All @@ -52,6 +54,12 @@ pub struct QueryContext {
channel: Channel,
}

/// This fields hold data that is only valid to current query context
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
warning: Option<String>,
}

impl Display for QueryContext {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
Expand All @@ -65,21 +73,26 @@ impl Display for QueryContext {

impl QueryContextBuilder {
pub fn current_schema(mut self, schema: String) -> Self {
if self.mutable_inner.is_none() {
self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default())));
if self.mutable_session_data.is_none() {
self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
}

// safe for unwrap because previous none check
self.mutable_inner.as_mut().unwrap().write().unwrap().schema = schema;
self.mutable_session_data
.as_mut()
.unwrap()
.write()
.unwrap()
.schema = schema;
self
}

pub fn timezone(mut self, timezone: Timezone) -> Self {
if self.mutable_inner.is_none() {
self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default())));
if self.mutable_session_data.is_none() {
self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
}

self.mutable_inner
self.mutable_session_data
.as_mut()
.unwrap()
.write()
Expand Down Expand Up @@ -120,7 +133,7 @@ impl From<QueryContext> for api::v1::QueryContext {
fn from(
QueryContext {
current_catalog,
mutable_inner,
mutable_session_data: mutable_inner,
extensions,
channel,
..
Expand Down Expand Up @@ -182,11 +195,11 @@ impl QueryContext {
}

pub fn current_schema(&self) -> String {
self.mutable_inner.read().unwrap().schema.clone()
self.mutable_session_data.read().unwrap().schema.clone()
}

pub fn set_current_schema(&self, new_schema: &str) {
self.mutable_inner.write().unwrap().schema = new_schema.to_string();
self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
}

pub fn current_catalog(&self) -> &str {
Expand All @@ -208,19 +221,19 @@ impl QueryContext {
}

pub fn timezone(&self) -> Timezone {
self.mutable_inner.read().unwrap().timezone.clone()
self.mutable_session_data.read().unwrap().timezone.clone()
}

pub fn set_timezone(&self, timezone: Timezone) {
self.mutable_inner.write().unwrap().timezone = timezone;
self.mutable_session_data.write().unwrap().timezone = timezone;
}

pub fn current_user(&self) -> UserInfoRef {
self.mutable_inner.read().unwrap().user_info.clone()
self.mutable_session_data.read().unwrap().user_info.clone()
}

pub fn set_current_user(&self, user: UserInfoRef) {
self.mutable_inner.write().unwrap().user_info = user;
self.mutable_session_data.write().unwrap().user_info = user;
}

pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
Expand Down Expand Up @@ -257,6 +270,18 @@ impl QueryContext {
pub fn set_channel(&mut self, channel: Channel) {
self.channel = channel;
}

pub fn warning(&self) -> Option<String> {
self.mutable_query_context_data
.read()
.unwrap()
.warning
.clone()
}

pub fn set_warning(&self, msg: String) {
self.mutable_query_context_data.write().unwrap().warning = Some(msg);
}
}

impl QueryContextBuilder {
Expand All @@ -266,7 +291,8 @@ impl QueryContextBuilder {
current_catalog: self
.current_catalog
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
mutable_inner: self.mutable_inner.unwrap_or_default(),
mutable_session_data: self.mutable_session_data.unwrap_or_default(),
mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
sql_dialect: self
.sql_dialect
.unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
Expand Down
2 changes: 1 addition & 1 deletion src/session/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl Session {
// catalog is not allowed for update in query context so we use
// string here
.current_catalog(self.catalog.read().unwrap().clone())
.mutable_inner(self.mutable_inner.clone())
.mutable_session_data(self.mutable_inner.clone())
.sql_dialect(self.conn_info.channel.dialect())
.configuration_parameter(self.configuration_variables.clone())
.channel(self.conn_info.channel)
Expand Down