parse and create list type
goldmedal committed Nov 18, 2024
1 parent 7d08a50 commit 3ad3f7e
Showing 2 changed files with 74 additions and 55 deletions.
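In short: map_data_type previously returned a hard-coded mock List&lt;Utf8&gt; for every "array&lt;...&gt;" type string. With this commit the element type is parsed via sqlparser and the list type is built from it, and malformed inputs such as "array", "array&lt;&gt;", or "array(int64)" now return errors instead of silently succeeding. A minimal before/after sketch of the resulting Arrow types (illustrative only, not code from this commit; it assumes "int64" maps to DataType::Int64, as in the crate's existing scalar mapping):

// Illustrative sketch, not part of the commit: the Arrow DataType produced for
// "array<int64>" before and after this change, based on the diff below.
// Assumes a binary with datafusion as a dependency (wren-core uses the same
// Arrow types through datafusion's re-export).
use datafusion::arrow::datatypes::{DataType, Field};
use std::sync::Arc;

fn main() {
    // Before: every "array<...>" string became the same mock list of Utf8,
    // with the element field named "string".
    let before = DataType::List(Arc::new(Field::new("string", DataType::Utf8, false)));

    // After: "array<int64>" is parsed, so the list element is Int64 and the
    // element field is named "element" and non-nullable, matching the
    // Field::new("element", data_type, false) call added in create_list_type.
    let after = DataType::List(Arc::new(Field::new("element", DataType::Int64, false)));

    assert_ne!(before, after);
    println!("before: {before:?}");
    println!("after:  {after:?}");
}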
78 changes: 65 additions & 13 deletions wren-core/core/src/logical_plan/utils.rs
@@ -13,6 +13,7 @@ use datafusion::common::plan_err;
 use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion::datasource::DefaultTableSource;
 use datafusion::error::Result;
+use datafusion::logical_expr::sqlparser::ast::ArrayElemTypeDef;
 use datafusion::logical_expr::sqlparser::dialect::GenericDialect;
 use datafusion::logical_expr::{builder::LogicalTableSource, Expr, TableSource};
 use datafusion::sql::sqlparser::ast;
@@ -23,13 +24,33 @@ use petgraph::Graph;
 use std::collections::HashSet;
 use std::{collections::HashMap, sync::Arc};
 
-fn create_mock_list_type() -> DataType {
-    let string_filed = Arc::new(Field::new("string", DataType::Utf8, false));
-    DataType::List(string_filed)
+fn create_list_type(array_type: &str) -> Result<DataType> {
+    if let ast::DataType::Array(value) = parse_type(array_type)? {
+        let data_type = match value {
+            ArrayElemTypeDef::None => {
+                return plan_err!("Array type must have an element type")
+            }
+            ArrayElemTypeDef::AngleBracket(data_type) => {
+                map_data_type(&data_type.to_string())?
+            }
+            ArrayElemTypeDef::SquareBracket(_, _) => {
+                unreachable!()
+            }
+            ArrayElemTypeDef::Parenthesis(_) => {
+                return plan_err!(
+                    "The format of the array type should be 'array<element_type>'"
+                )
+            }
+        };
+        return Ok(DataType::List(Arc::new(Field::new(
+            "element", data_type, false,
+        ))));
+    }
+    unreachable!()
 }
 
 fn create_struct_type(struct_type: &str) -> Result<DataType> {
-    let sql_type = parse_struct_type(struct_type).unwrap();
+    let sql_type = parse_type(struct_type)?;
     let mut builder = SchemaBuilder::new();
     let mut counter = 0;
     match sql_type {
@@ -59,7 +80,7 @@ fn create_struct_type(struct_type: &str) -> Result<DataType> {
     Ok(DataType::Struct(fields))
 }
 
-fn parse_struct_type(struct_type: &str) -> Result<ast::DataType> {
+fn parse_type(struct_type: &str) -> Result<ast::DataType> {
     let dialect = GenericDialect {};
     Ok(Parser::new(&dialect)
         .try_with_sql(struct_type)?
@@ -72,7 +93,7 @@ pub fn map_data_type(data_type: &str) -> Result<DataType> {
     // Currently, we don't care about the element type of the array or struct.
     // We only care about the array or struct itself.
     if data_type.starts_with("array") {
-        return Ok(create_mock_list_type());
+        return create_list_type(data_type);
     }
     if data_type.starts_with("struct") {
         return create_struct_type(data_type);
@@ -280,7 +301,9 @@ pub fn expr_to_columns(
 
 #[cfg(test)]
 mod test {
-    use crate::logical_plan::utils::{create_mock_list_type, create_struct_type};
+    use crate::logical_plan::utils::{
+        create_list_type, create_struct_type, map_data_type,
+    };
     use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
     use datafusion::common::Result;
 
@@ -338,19 +361,48 @@ mod test {
             ("null", DataType::Null),
             ("geography", DataType::Utf8),
             ("range", DataType::Utf8),
-            ("array<int64>", create_mock_list_type()),
+            ("array<int64>", create_list_type("array<int64>")?),
             (
                 "struct<name string, age int>",
                 create_struct_type("struct<name string, age int>")?,
             ),
         ];
         for (data_type, expected) in test_cases {
-            let result = super::map_data_type(data_type)?;
+            let result = map_data_type(data_type)?;
             assert_eq!(result, expected);
             // test case insensitivity
-            let result = super::map_data_type(&data_type.to_uppercase())?;
+            let result = map_data_type(&data_type.to_uppercase())?;
             assert_eq!(result, expected);
         }
 
+        let _ = map_data_type("array").map_err(|e| {
+            assert_eq!(
+                e.to_string(),
+                "SQL error: ParserError(\"Expected: <, found: EOF\")"
+            );
+        });
+
+        let _ = map_data_type("array<>").map_err(|e| {
+            assert_eq!(
+                e.to_string(),
+                "SQL error: ParserError(\"Expected: <, found: <> at Line: 1, Column: 6\")"
+            );
+        });
+
+        let _ = map_data_type("array(int64)").map_err(|e| {
+            assert_eq!(
+                e.to_string(),
+                "SQL error: ParserError(\"Expected: <, found: ( at Line: 1, Column: 6\")"
+            );
+        });
+
+        let _ = map_data_type("struct").map_err(|e| {
+            assert_eq!(
+                e.to_string(),
+                "Error during planning: struct must have at least one field"
+            );
+        });
+
         Ok(())
     }

@@ -376,12 +428,12 @@ mod test {
         let expected = DataType::Struct(fields);
         assert_eq!(result, expected);
         let struct_string = "STRUCT<>";
-        create_struct_type(struct_string).map_err(|e| {
+        let _ = create_struct_type(struct_string).map_err(|e| {
             assert_eq!(
                 e.to_string(),
                 "Error during planning: struct must have at least one field"
-            );
-        }).unwrap();
+            )
+        });
         Ok(())
     }
 }
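The new create_list_type relies on parse_type, which drives sqlparser's Parser directly and then matches on ArrayElemTypeDef. A standalone sketch of that parsing step (illustrative, not part of the commit; it assumes a direct dependency on the sqlparser crate, which wren-core otherwise reaches through datafusion's re-export):

// Illustrative sketch, not part of the commit: how an "array<element_type>"
// string parses with sqlparser's GenericDialect, mirroring the match arms in
// create_list_type above.
use sqlparser::ast::{ArrayElemTypeDef, DataType};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dialect = GenericDialect {};
    // Inputs like "array", "array<>", or "array(int64)" fail right here with a
    // ParserError, which is what the new tests in utils.rs assert on.
    let mut parser = Parser::new(&dialect).try_with_sql("array<int64>")?;
    let parsed = parser.parse_data_type()?;

    match parsed {
        // "array<int64>" is the angle-bracket form; create_list_type feeds the
        // inner type back into map_data_type to build the Arrow element type.
        DataType::Array(ArrayElemTypeDef::AngleBracket(elem)) => {
            println!("element type: {elem}");
        }
        other => println!("unexpected shape: {other:?}"),
    }
    Ok(())
}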
51 changes: 9 additions & 42 deletions wren-core/core/src/mdl/mod.rs
@@ -416,10 +416,8 @@ mod test {
     use crate::mdl::manifest::Manifest;
     use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL};
     use datafusion::arrow::array::{
-        ArrayRef, Float64Array, Int64Array, ListArray, RecordBatch, StringArray,
-        StructArray, TimestampNanosecondArray,
+        ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray,
     };
-    use datafusion::arrow::datatypes::{DataType, Field, Fields, Int32Type, TimeUnit};
     use datafusion::assert_batches_eq;
     use datafusion::common::not_impl_err;
     use datafusion::common::Result;
@@ -1101,7 +1099,6 @@
     #[tokio::test]
     async fn test_list() -> Result<()> {
         let ctx = SessionContext::new();
-        ctx.register_batch("list_table", list_table()?)?;
         let manifest = ManifestBuilder::new()
             .catalog("wren")
             .schema("test")
@@ -1124,7 +1121,6 @@
     #[tokio::test]
     async fn test_struct() -> Result<()> {
         let ctx = SessionContext::new();
-        ctx.register_batch("struct_table", struct_table()?)?;
         let manifest = ManifestBuilder::new()
             .catalog("wren")
             .schema("test")
@@ -1167,12 +1163,14 @@
             .build();
         let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
         let sql = "select struct_col.float_field from wren.test.struct_table";
-        transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await.map_err(|e| {
-            assert_eq!(
-                e.to_string(),
-                "Error during planning: struct must have at least one field"
-            )
-        }).unwrap();
+        let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql)
+            .await
+            .map_err(|e| {
+                assert_eq!(
+                    e.to_string(),
+                    "Error during planning: struct must have at least one field"
+                )
+            });
         Ok(())
     }
 
@@ -1239,35 +1237,4 @@
             ])
             .unwrap()
     }
-
-    fn list_table() -> Result<RecordBatch> {
-        let data = vec![
-            Some(vec![Some(0), Some(1), Some(2)]),
-            None,
-            Some(vec![Some(3), None, Some(5)]),
-            Some(vec![Some(6), Some(7)]),
-        ];
-        let list_array: ArrayRef =
-            Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(data));
-        Ok(RecordBatch::try_from_iter(vec![("list_col", list_array)]).unwrap())
-    }
-
-    fn struct_table() -> Result<RecordBatch> {
-        let field: Fields = vec![
-            Field::new("float_field", DataType::Float64, true),
-            Field::new(
-                "time_field",
-                DataType::Timestamp(TimeUnit::Nanosecond, None),
-                true,
-            ),
-        ]
-        .into();
-        let f64arr: ArrayRef = Arc::new(Float64Array::from(vec![1.0])) as ArrayRef;
-        let timearr: ArrayRef =
-            Arc::new(TimestampNanosecondArray::from(vec![1])) as ArrayRef;
-
-        let struct_arr: ArrayRef =
-            Arc::new(StructArray::try_new(field, vec![f64arr, timearr], None)?);
-        Ok(RecordBatch::try_from_iter(vec![("struct_col", struct_arr)]).unwrap())
-    }
 }
