Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(servers): improve postgres error message #4463

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/servers/src/postgres/auth_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use snafu::IntoError;
use super::PostgresServerHandler;
use crate::error::{AuthSnafu, Result};
use crate::metrics::METRIC_AUTH_FAILURE;
use crate::postgres::types::PgErrorCode;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;

pub(crate) struct PgLoginVerifier {
Expand Down Expand Up @@ -141,7 +142,11 @@ impl StartupHandler for PostgresServerHandler {
PgWireFrontendMessage::Startup(ref startup) => {
// check ssl requirement
if !client.is_secure() && self.force_tls {
send_error(client, "FATAL", "28000", "No encryption".to_owned()).await?;
send_error(
client,
PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
)
.await?;
return Ok(());
}

Expand All @@ -155,7 +160,7 @@ impl StartupHandler for PostgresServerHandler {
let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
}
DbResolution::NotFound(msg) => {
send_error(client, "FATAL", "3D000", msg).await?;
send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
return Ok(());
}
}
Expand Down Expand Up @@ -193,9 +198,8 @@ impl StartupHandler for PostgresServerHandler {
} else {
return send_error(
client,
"FATAL",
"28P01",
"password authentication failed".to_owned(),
PgErrorCode::Ec28P01
.to_err_info("password authentication failed".to_string()),
)
.await;
}
Expand All @@ -206,13 +210,13 @@ impl StartupHandler for PostgresServerHandler {
}
}

async fn send_error<C>(client: &mut C, level: &str, code: &str, message: String) -> PgWireResult<()>
async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let error = ErrorResponse::from(ErrorInfo::new(level.to_owned(), code.to_owned(), message));
let error = ErrorResponse::from(err_info);
client
.feed(PgWireBackendMessage::ErrorResponse(error))
.await?;
Expand Down
14 changes: 5 additions & 9 deletions src/servers/src/postgres/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,9 @@ fn output_to_query_response<'a>(
e
);
};
Ok(Response::Error(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"XX000".to_string(),
e.output_msg(),
))))
Ok(Response::Error(Box::new(
PgErrorCode::from(status_code).to_err_info(e.output_msg()),
)))
}
}
}
Expand Down Expand Up @@ -184,10 +182,8 @@ impl QueryParser for DefaultQueryParser {
ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
if stmts.len() != 1 {
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"42P14".to_owned(),
"invalid_prepared_statement_definition".to_owned(),
Err(PgWireError::UserError(Box::new(ErrorInfo::from(
PgErrorCode::Ec42P14,
))))
} else {
let stmt = stmts.remove(0);
Expand Down
26 changes: 21 additions & 5 deletions src/servers/src/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod bytea;
mod datetime;
mod error;
mod interval;

use std::collections::HashMap;
Expand All @@ -28,15 +29,16 @@ use datatypes::types::TimestampType;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::error::{PgWireError, PgWireResult};
use query::plan::LogicalPlan;
use session::context::QueryContextRef;
use session::session_config::PGByteaOutputValue;

use self::bytea::{EscapeOutputBytea, HexOutputBytea};
use self::datetime::{StylingDate, StylingDateTime};
pub use self::error::PgErrorCode;
use self::interval::PgInterval;
use crate::error::{self, Error, Result};
use crate::error::{self as server_error, Error, Result};
use crate::SqlPlan;

pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
Expand Down Expand Up @@ -154,7 +156,7 @@ pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
&ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC),
&ConcreteDataType::Duration(_)
| &ConcreteDataType::List(_)
| &ConcreteDataType::Dictionary(_) => error::UnsupportedDataTypeSnafu {
| &ConcreteDataType::Dictionary(_) => server_error::UnsupportedDataTypeSnafu {
data_type: origin,
reason: "not implemented",
}
Expand All @@ -177,7 +179,7 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
)),
&Type::DATE => Ok(ConcreteDataType::date_datatype()),
&Type::TIME => Ok(ConcreteDataType::datetime_datatype()),
_ => error::InternalSnafu {
_ => server_error::InternalSnafu {
err_msg: format!("unimplemented datatype {origin:?}"),
}
.fail(),
Expand Down Expand Up @@ -236,7 +238,7 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir
}

pub(super) fn invalid_parameter_error(msg: &str, detail: Option<&str>) -> PgWireError {
let mut error_info = ErrorInfo::new("ERROR".to_owned(), "22023".to_owned(), msg.to_owned());
let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
error_info.detail = detail.map(|s| s.to_owned());
PgWireError::UserError(Box::new(error_info))
}
Expand Down Expand Up @@ -829,4 +831,18 @@ mod test {
}
}
}

#[test]
fn test_invalid_parameter() {
// test for refactor with PgErrorCode
let msg = "invalid_parameter_count";
let error = invalid_parameter_error(msg, None);
if let PgWireError::UserError(value) = error {
assert_eq!("ERROR", value.severity);
assert_eq!("22023", value.code);
assert_eq!(msg, value.message);
} else {
panic!("test_invalid_parameter failed");
}
}
}
Loading
Loading