diff --git a/src/ast/dml/select.rs b/src/ast/dml/select.rs index b6a159c..c2847e6 100644 --- a/src/ast/dml/select.rs +++ b/src/ast/dml/select.rs @@ -226,3 +226,65 @@ impl From for SQLExpression { SQLExpression::Subquery(SubqueryExpression::Select(Box::new(value))) } } + +#[cfg(test)] +#[allow(non_snake_case)] +mod tests { + use crate::ast::{ + dml::expressions::call::CallExpression, + types::{AggregateFunction, BuiltInFunction, Function}, + }; + + use super::*; + + #[test] + fn test_from_table() { + let select_query = SelectQuery::builder() + .set_from_table(TableName::new(None, "table".into())) + .build(); + + assert!(select_query.has_from_table()); + } + + #[test] + fn test_has_group_by() { + let select_query = SelectQuery::builder().build(); + + assert_eq!(select_query.has_group_by(), false); + + let select_query = SelectQuery::builder() + .add_group_by(GroupByItem { + item: SelectColumn { + table_name: None, + column_name: "foo".into(), + }, + }) + .build(); + + assert_eq!(select_query.has_group_by(), true); + } + + #[test] + fn test_get_aggregate_column() { + let select_query = SelectQuery::builder() + .add_select_item( + SelectItem::builder() + .set_item(SQLExpression::FunctionCall(CallExpression { + function: Function::BuiltIn(BuiltInFunction::Aggregate( + AggregateFunction::Count, + )), + arguments: vec![SQLExpression::SelectColumn(SelectColumn { + table_name: None, + column_name: "bar".into(), + })], + })) + .build(), + ) + .build(); + + let aggregate_columns = select_query.get_aggregate_column(); + + assert_eq!(aggregate_columns.len(), 1); + assert_eq!(aggregate_columns[0].column_name, "bar"); + } +}