Skip to content

Commit

Permalink
refactor: unify mysql execute through cli and protocol (#5038)
Browse files Browse the repository at this point in the history
refactor: mysql execute
  • Loading branch information
CookiePieWw authored Nov 22, 2024
1 parent 1578c00 commit 1255638
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 121 deletions.
10 changes: 9 additions & 1 deletion src/servers/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,13 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Prepare statement not found: {}", name))]
PrepareStatementNotFound {
name: String,
#[snafu(implicit)]
location: Location,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -643,7 +650,8 @@ impl ErrorExt for Error {
| TimestampOverflow { .. }
| OpenTelemetryLog { .. }
| UnsupportedJsonDataTypeForTag { .. }
| InvalidTableName { .. } => StatusCode::InvalidArguments,
| InvalidTableName { .. }
| PrepareStatementNotFound { .. } => StatusCode::InvalidArguments,

Catalog { source, .. } => source.status_code(),
RowWriter { source, .. } => source.status_code(),
Expand Down
227 changes: 107 additions & 120 deletions src/servers/src/mysql/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ use crate::SqlPlan;
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";

/// Parameters for the prepared statement
enum Params<'a> {
/// Parameters passed through protocol
ProtocolParams(Vec<ParamValue<'a>>),
/// Parameters passed through cli
CliParams(Vec<sql::ast::Expr>),
}

impl Params<'_> {
fn len(&self) -> usize {
match self {
Params::ProtocolParams(params) => params.len(),
Params::CliParams(params) => params.len(),
}
}
}

// An intermediate shim for executing MySQL queries.
pub struct MysqlInstanceShim {
query_handler: ServerSqlQueryHandlerRef,
Expand Down Expand Up @@ -143,9 +160,9 @@ impl MysqlInstanceShim {
}

/// Retrieve the query and logical plan by a given statement key
fn plan(&self, stmt_key: String) -> Option<SqlPlan> {
fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
let guard = self.prepared_stmts.read();
guard.get(&stmt_key).cloned()
guard.get(stmt_key).cloned()
}

/// Save the prepared statement and return the parameters and result columns
Expand Down Expand Up @@ -227,6 +244,66 @@ impl MysqlInstanceShim {
Ok((params, columns))
}

async fn do_execute<'a>(
&mut self,
query_ctx: QueryContextRef,
stmt_key: String,
params: Params<'a>,
) -> Result<Vec<std::result::Result<Output, error::Error>>> {
let sql_plan = match self.plan(&stmt_key) {
None => {
return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
}
Some(sql_plan) => sql_plan,
};

let outputs = match sql_plan.plan {
Some(plan) => {
let param_types = plan
.get_parameter_types()
.context(DataFrameSnafu)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
.collect::<HashMap<_, _>>();

if params.len() != param_types.len() {
return error::InternalSnafu {
err_msg: "Prepare statement params number mismatch".to_string(),
}
.fail();
}

let plan = match params {
Params::ProtocolParams(params) => {
replace_params_with_values(&plan, param_types, &params)
}
Params::CliParams(params) => {
replace_params_with_exprs(&plan, param_types, &params)
}
}?;

debug!("Mysql execute prepared plan: {}", plan.display_indent());
vec![
self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
.await,
]
}
None => {
let param_strs = match params {
Params::ProtocolParams(params) => {
params.iter().map(convert_param_value_to_string).collect()
}
Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
};
let query = replace_params(param_strs, sql_plan.query);
debug!("Mysql execute replaced query: {}", query);
self.do_query(&query, query_ctx.clone()).await
}
};

Ok(outputs)
}

/// Remove the prepared statement by a given statement key
fn do_close(&mut self, stmt_key: String) {
let mut guard = self.prepared_stmts.write();
Expand Down Expand Up @@ -356,62 +433,20 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi

let params: Vec<ParamValue> = p.into_iter().collect();
let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
let sql_plan = match self.plan(stmt_key) {
None => {
w.error(
ErrorKind::ER_UNKNOWN_STMT_HANDLER,
b"prepare statement not found",
)
.await?;
return Ok(());
}
Some(sql_plan) => sql_plan,
};

let outputs = match sql_plan.plan {
Some(plan) => {
let param_types = plan
.get_parameter_types()
.context(DataFrameSnafu)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
.collect::<HashMap<_, _>>();

if params.len() != param_types.len() {
return error::InternalSnafu {
err_msg: "prepare statement params number mismatch".to_string(),
}
.fail();
}

let plan = match replace_params_with_values(&plan, param_types, &params) {
Ok(plan) => plan,
Err(e) => {
let (kind, err) = handle_err(e, query_ctx);
debug!(
"Failed to replace params on execute, kind: {:?}, err: {}",
kind, err
);
w.error(kind, err.as_bytes()).await?;

return Ok(());
}
};

debug!("Mysql execute prepared plan: {}", plan.display_indent());
vec![
self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
.await,
]
}
None => {
let param_strs = params
.iter()
.map(|x| convert_param_value_to_string(x))
.collect();
let query = replace_params(param_strs, sql_plan.query);
debug!("Mysql execute replaced query: {}", query);
self.do_query(&query, query_ctx.clone()).await
let outputs = match self
.do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
.await
{
Ok(outputs) => outputs,
Err(e) => {
let (kind, err) = handle_err(e, query_ctx);
debug!(
"Failed to execute prepared statement, kind: {:?}, err: {}",
kind, err
);
w.error(kind, err.as_bytes()).await?;
return Ok(());
}
};

Expand Down Expand Up @@ -469,67 +504,19 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
}
} else if query_upcase.starts_with("EXECUTE ") {
match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
// TODO: similar to on_execute, refactor this
Ok((stmt_name, params)) => {
let sql_plan = match self.plan(stmt_name) {
None => {
writer
.error(
ErrorKind::ER_UNKNOWN_STMT_HANDLER,
b"prepare statement not found",
)
.await?;
return Ok(());
}
Some(sql_plan) => sql_plan,
};

let outputs = match sql_plan.plan {
Some(plan) => {
let param_types = plan
.get_parameter_types()
.context(DataFrameSnafu)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
.collect::<HashMap<_, _>>();

if params.len() != param_types.len() {
writer
.error(
ErrorKind::ER_SP_BADSTATEMENT,
b"prepare statement params number mismatch",
)
.await?;
return Ok(());
}

let plan = match replace_params_with_exprs(&plan, param_types, &params)
{
Ok(plan) => plan,
Err(e) => {
let (kind, err) = handle_err(e, query_ctx);
debug!(
"Failed to replace params on query, kind: {:?}, err: {}",
kind, err
);
writer.error(kind, err.as_bytes()).await?;

return Ok(());
}
};

debug!("Mysql execute prepared plan: {}", plan.display_indent());
vec![
self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
.await,
]
}
None => {
let param_strs = params.iter().map(|x| x.to_string()).collect();
let query = replace_params(param_strs, sql_plan.query);
debug!("Mysql execute replaced query: {}", query);
let outputs = self.do_query(&query, query_ctx.clone()).await;
writer::write_output(writer, query_ctx, outputs).await?;
let outputs = match self
.do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
.await
{
Ok(outputs) => outputs,
Err(e) => {
let (kind, err) = handle_err(e, query_ctx);
debug!(
"Failed to execute prepared statement, kind: {:?}, err: {}",
kind, err
);
writer.error(kind, err.as_bytes()).await?;
return Ok(());
}
};
Expand Down Expand Up @@ -623,8 +610,8 @@ fn convert_param_value_to_string(param: &ParamValue) -> String {
ValueInner::Double(u) => u.to_string(),
ValueInner::NULL => "NULL".to_string(),
ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
ValueInner::Date(_) => NaiveDate::from(param.value).to_string(),
ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(),
ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
ValueInner::Time(_) => format_duration(Duration::from(param.value)),
}
}
Expand All @@ -643,7 +630,7 @@ fn format_duration(duration: Duration) -> String {
let seconds = duration.as_secs() % 60;
let minutes = (duration.as_secs() / 60) % 60;
let hours = (duration.as_secs() / 60) / 60;
format!("{}:{}:{}", hours, minutes, seconds)
format!("'{}:{}:{}'", hours, minutes, seconds)
}

fn replace_params_with_values(
Expand Down

0 comments on commit 1255638

Please sign in to comment.