Skip to content

Commit

Permalink
为聚合函数添加验证
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick16262 committed Jun 11, 2024
1 parent 8f1a4ad commit fe0843b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
62 changes: 41 additions & 21 deletions src/observer/sql/expr/expression_resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,21 +541,10 @@ RC ProjectExpressionResovler::resolve_projection_list(
return rc;
}

assert(refactor.aggregate_childs().size() == refactor.aggregate_cells().size());
assert(refactor.aggregate_types().size() == refactor.aggregate_cells().size());
vector aggregate_childs = std::move(refactor.aggregate_childs());
vector aggregate_cells = std::move(refactor.aggregate_cells());
vector aggregate_types = refactor.aggregate_types();

for (int i = 0; i < aggregate_childs.size(); i++) {
unique_ptr<Expression> aggr_child_expr;
rc = generator_.generate_expression(aggregate_childs[i].get(), aggr_child_expr);
if (rc != RC::SUCCESS) {
LOG_WARN("generate aggregate child expression failed rc = %d:%s", rc, strrc(rc));
return rc;
}
aggregate_desc_.push_back(
std::make_unique<AggregateDesc>(aggregate_types[i], std::move(aggr_child_expr), aggregate_cells[i]));
rc = push_groupping(refactor);
if (rc != RC::SUCCESS) {
LOG_WARN("push groupping failed rc = %d:%s", rc, strrc(rc));
return rc;
}

query_exprs.push_back(std::move(expr));
Expand All @@ -581,11 +570,6 @@ RC ProjectExpressionResovler::resolve_projection_list(
}
}

if (aggregate_desc_.size() != sql_nodes.size() && !aggregate_desc_.empty()) {
LOG_WARN("aggregate function cannot be mixed with other expressions");
return RC::INVALID_AGGREGATE;
}

return RC::SUCCESS;
}

Expand Down Expand Up @@ -654,4 +638,40 @@ RC ProjectExpressionResovler::wildcard_fields(
}

return RC::SUCCESS;
}
}

RC ProjectExpressionResovler::push_groupping(ExpressionStructRefactor &refactor)
{
RC rc;

assert(refactor.aggregate_childs().size() == refactor.aggregate_cells().size());
assert(refactor.aggregate_types().size() == refactor.aggregate_cells().size());
vector aggregate_childs = std::move(refactor.aggregate_childs());
vector aggregate_cells = std::move(refactor.aggregate_cells());
vector aggregate_types = refactor.aggregate_types();

if ((must_not_aggregate && !aggregate_childs.empty()) || (must_aggregate && aggregate_childs.empty())) {
LOG_WARN("aggregate function with non-aggregate fields is not allowed");
return RC::INVALID_AGGREGATE;
}

if (aggregate_childs.empty()) {
must_not_aggregate = true;
return RC::SUCCESS;
} else {
must_aggregate = true;
}

for (int i = 0; i < aggregate_childs.size(); i++) {
unique_ptr<Expression> aggr_child_expr;
rc = generator_.generate_expression(aggregate_childs[i].get(), aggr_child_expr);
if (rc != RC::SUCCESS) {
LOG_WARN("generate aggregate child expression failed rc = %d:%s", rc, strrc(rc));
return rc;
}
aggregate_desc_.push_back(
std::make_unique<AggregateDesc>(aggregate_types[i], std::move(aggr_child_expr), aggregate_cells[i]));
}

return RC::SUCCESS;
}
26 changes: 17 additions & 9 deletions src/observer/sql/expr/expression_resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ class ExpressionGenerator
}
virtual ~ExpressionGenerator() = default;

void set_attr_schema(std::vector<TupleCellSpec> current_tuple_schema) { current_tuple_schema_ = std::move(current_tuple_schema); }
void set_attr_schema(std::vector<TupleCellSpec> current_tuple_schema)
{
current_tuple_schema_ = std::move(current_tuple_schema);
}

RC generate_expression(const ExpressionSqlNode *sql_node, std::unique_ptr<Expression> &expr);

Expand All @@ -60,10 +63,10 @@ class ExpressionGenerator
RC generate_expression(const ExistsExpressionSqlNode *sql_node, std::unique_ptr<Expression> &expr);

private:
std::unordered_multimap<std::string, TableFactorDesc> field_table_map_; // 字段->表
std::unordered_set<std::string> outter_alias_set_; // 父查询中的别名
std::vector<TupleCellSpec> current_tuple_schema_; // 当前查询的schema
Db *db_ = nullptr; // nullable
std::unordered_multimap<std::string, TableFactorDesc> field_table_map_; // 字段->表
std::unordered_set<std::string> outter_alias_set_; // 父查询中的别名
std::vector<TupleCellSpec> current_tuple_schema_; // 当前查询的schema
Db *db_ = nullptr; // nullable
};

/**
Expand Down Expand Up @@ -107,9 +110,10 @@ class WhereConditionExpressionResolver
public:
WhereConditionExpressionResolver(
Db *db, std::vector<TableFactorDesc> table_desc, std::vector<TupleCellSpec> tuple_schema)
: generator_(db, table_desc), db_(db), table_desc_(std::move(table_desc)), tuple_schema_(tuple_schema){
generator_.set_attr_schema(tuple_schema);
};
: generator_(db, table_desc), db_(db), table_desc_(std::move(table_desc)), tuple_schema_(tuple_schema)
{
generator_.set_attr_schema(tuple_schema);
};
virtual ~WhereConditionExpressionResolver() = default;

RC resolve(ExpressionSqlNode *sql_node, std::unique_ptr<Expression> &expr);
Expand Down Expand Up @@ -158,7 +162,7 @@ class ProjectExpressionResovler
virtual ~ProjectExpressionResovler() = default;

RC resolve_projection_list(
const vector<ExpressionWithAliasSqlNode *> &sql_nodes, vector<unique_ptr<Expression>> &query_exprs);
const vector<ExpressionWithAliasSqlNode *> &sql_nodes, vector<unique_ptr<Expression>> &query_exprs);

public:
std::vector<SubqueryType> &subquery_types() { return subquery_types_; }
Expand All @@ -169,6 +173,7 @@ class ProjectExpressionResovler

private:
RC wildcard_fields(FieldExpressionSqlNode *wildcard_expression, vector<unique_ptr<Expression>> &query_exprs);
RC push_groupping(ExpressionStructRefactor &refactor);

private:
std::vector<SubqueryType> subquery_types_;
Expand All @@ -186,4 +191,7 @@ class ProjectExpressionResovler
std::vector<TableFactorDesc> table_desc_;
std::vector<TupleCellSpec> outter_tuple_;
std::vector<ExpressionSqlNode *> group_exprs_;

bool must_aggregate = false;
bool must_not_aggregate = false;
};

0 comments on commit fe0843b

Please sign in to comment.