Skip to content

Commit

Permalink
Implement EXPLAIN for mysql, also fix NULLs
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Apr 26, 2024
1 parent f66ee26 commit f057f73
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 19 deletions.
30 changes: 30 additions & 0 deletions nexus/peer-mysql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod ast;
mod client;
mod stream;

use std::fmt::Write;

use peer_cursor::{
CursorManager, CursorModification, QueryExecutor, QueryOutput, RecordStream, Schema,
};
Expand Down Expand Up @@ -61,6 +63,34 @@ impl QueryExecutor for MySqlQueryExecutor {
async fn execute(&self, stmt: &Statement) -> PgWireResult<QueryOutput> {
// only support SELECT statements
match stmt {
Statement::Explain { analyze, format, statement, .. } => {
if let Statement::Query(ref query) = **statement {
let mut query = query.clone();
ast::rewrite_query(&self.peer_name, &mut query);
let mut querystr = String::from("EXPLAIN ");
if *analyze {
querystr.push_str("ANALYZE ");
}
if let Some(format) = format {
write!(querystr, "FORMAT={} ", format).ok();
}
write!(querystr, "{}", query).ok();
tracing::info!("mysql rewritten query: {}", query);

let cursor = self.query(querystr).await?;
Ok(QueryOutput::Stream(Box::pin(cursor)))
} else {
let error = format!(
"only EXPLAIN SELECT statements are supported in mysql. got: {}",
statement
);
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"fdw_error".to_owned(),
error,
))))
}
}
Statement::Query(query) => {
let mut query = query.clone();
ast::rewrite_query(&self.peer_name, &mut query);
Expand Down
45 changes: 26 additions & 19 deletions nexus/peer-mysql/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,50 +104,57 @@ impl MyRecordStream {
}
}

pub fn mysql_row_to_values(row: &Row) -> Vec<Value> {
row.columns_ref()
.iter()
.enumerate()
.map(|(i, col)| match col.column_type() {
pub fn mysql_row_to_values(row: Row) -> Vec<Value> {
use mysql_async::from_value;
let columns = row.columns();
row.unwrap()
.into_iter()
.zip(columns.iter())
.map(|(val, col)|
if val == mysql_async::Value::NULL {
Value::Null
} else {
match col.column_type() {
ColumnType::MYSQL_TYPE_NULL | ColumnType::MYSQL_TYPE_UNKNOWN => Value::Null,
ColumnType::MYSQL_TYPE_TINY => Value::TinyInt(row.get(i).unwrap()),
ColumnType::MYSQL_TYPE_TINY => Value::TinyInt(from_value(val)),
ColumnType::MYSQL_TYPE_SHORT | ColumnType::MYSQL_TYPE_YEAR => {
Value::SmallInt(row.get(i).unwrap())
Value::SmallInt(from_value(val))
}
ColumnType::MYSQL_TYPE_LONG | ColumnType::MYSQL_TYPE_INT24 => {
Value::Integer(row.get(i).unwrap())
Value::Integer(from_value(val))
}
ColumnType::MYSQL_TYPE_LONGLONG => Value::BigInt(row.get(i).unwrap()),
ColumnType::MYSQL_TYPE_FLOAT => Value::Float(row.get(i).unwrap()),
ColumnType::MYSQL_TYPE_DOUBLE => Value::Double(row.get(i).unwrap()),
ColumnType::MYSQL_TYPE_LONGLONG => Value::BigInt(from_value(val)),
ColumnType::MYSQL_TYPE_FLOAT => Value::Float(from_value(val)),
ColumnType::MYSQL_TYPE_DOUBLE => Value::Double(from_value(val)),
ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
Value::Numeric(row.get(i).unwrap())
Value::Numeric(from_value(val))
}
ColumnType::MYSQL_TYPE_VARCHAR
| ColumnType::MYSQL_TYPE_VAR_STRING
| ColumnType::MYSQL_TYPE_STRING
| ColumnType::MYSQL_TYPE_ENUM
| ColumnType::MYSQL_TYPE_SET => Value::Text(row.get(i).unwrap()),
| ColumnType::MYSQL_TYPE_SET => Value::Text(from_value(val)),
ColumnType::MYSQL_TYPE_TINY_BLOB
| ColumnType::MYSQL_TYPE_MEDIUM_BLOB
| ColumnType::MYSQL_TYPE_LONG_BLOB
| ColumnType::MYSQL_TYPE_BLOB
| ColumnType::MYSQL_TYPE_BIT
| ColumnType::MYSQL_TYPE_GEOMETRY => {
Value::Binary(row.get::<Vec<u8>, usize>(i).unwrap().into())
Value::Binary(from_value::<Vec<u8>>(val).into())
}
ColumnType::MYSQL_TYPE_DATE | ColumnType::MYSQL_TYPE_NEWDATE => {
Value::Date(row.get(i).unwrap())
Value::Date(from_value(val))
}
ColumnType::MYSQL_TYPE_TIME | ColumnType::MYSQL_TYPE_TIME2 => {
Value::Time(row.get(i).unwrap())
Value::Time(from_value(val))
}
ColumnType::MYSQL_TYPE_TIMESTAMP
| ColumnType::MYSQL_TYPE_TIMESTAMP2
| ColumnType::MYSQL_TYPE_DATETIME
| ColumnType::MYSQL_TYPE_DATETIME2 => Value::PostgresTimestamp(row.get(i).unwrap()),
ColumnType::MYSQL_TYPE_JSON => Value::JsonB(row.get(i).unwrap()),
| ColumnType::MYSQL_TYPE_DATETIME2 => Value::PostgresTimestamp(from_value(val)),
ColumnType::MYSQL_TYPE_JSON => Value::JsonB(from_value(val)),
ColumnType::MYSQL_TYPE_TYPED_ARRAY => Value::Null,
}
})
.collect()
}
Expand All @@ -158,7 +165,7 @@ impl Stream for MyRecordStream {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let row_stream = &mut self.stream;
match Pin::new(row_stream).poll_next(cx) {
Poll::Ready(Some(client::Response::Row(ref row))) => Poll::Ready(Some(Ok(Record {
Poll::Ready(Some(client::Response::Row(row))) => Poll::Ready(Some(Ok(Record {
schema: self.schema.clone(),
values: mysql_row_to_values(row),
}))),
Expand Down

0 comments on commit f057f73

Please sign in to comment.