diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index 320454bc431..6af1c99b2b4 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -619,3 +619,22 @@ cases: sql: | select c1 + 8 from (select 9 as c1) + - id: 18 + mode: request-unsupport + inputs: + - name: t1 + columns: ["col1 int32", "std_ts timestamp", "col2 map"] + indexs: ["index1:col1:std_ts"] + # insert map values by insert stmts only + inserts: + - "insert into t1 values (1, timestamp(1000), map(12, 'abc'))" + - "insert into t1 values (2, timestamp(1000), map(13, 'abc'))" + sql: | + select col1, col2[12] as out from t1; + expect: + order: col1 + columns: ["col1 int", "out string"] + data: | + 1, abc + 2, null + diff --git a/hybridse/examples/toydb/src/cmd/toydb_run_engine.cc b/hybridse/examples/toydb/src/cmd/toydb_run_engine.cc index dd892c636d0..a3da3076968 100644 --- a/hybridse/examples/toydb/src/cmd/toydb_run_engine.cc +++ b/hybridse/examples/toydb/src/cmd/toydb_run_engine.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include "absl/strings/match.h" #include "testing/toydb_engine_test_base.h" @@ -104,8 +103,7 @@ int RunSingle(const std::string& yaml_path) { int main(int argc, char** argv) { ::google::ParseCommandLineFlags(&argc, &argv, false); - InitializeNativeTarget(); - InitializeNativeTargetAsmPrinter(); + ::hybridse::vm::Engine::InitializeGlobalLLVM(); if (FLAGS_yaml_path != "") { return ::hybridse::vm::RunSingle(FLAGS_yaml_path); } else { diff --git a/hybridse/examples/toydb/src/testing/toydb_engine_test.cc b/hybridse/examples/toydb/src/testing/toydb_engine_test.cc index 02438aeebac..b6e82038597 100644 --- a/hybridse/examples/toydb/src/testing/toydb_engine_test.cc +++ b/hybridse/examples/toydb/src/testing/toydb_engine_test.cc @@ -15,7 +15,7 @@ */ #include "gtest/gtest.h" -#include "gtest/internal/gtest-param-util.h" +#include "vm/engine.h" #include "testing/toydb_engine_test_base.h" using namespace llvm; // NOLINT (build/namespaces) @@ -126,8 +126,7 @@ TEST_P(BatchRequestEngineTest, TestClusterBatchRequestEngine) { } // namespace hybridse int main(int argc, char** argv) { - InitializeNativeTarget(); - InitializeNativeTargetAsmPrinter(); + ::hybridse::vm::Engine::InitializeGlobalLLVM(); ::testing::InitGoogleTest(&argc, argv); // ::hybridse::vm::CoreAPI::EnableSignalTraceback(); return RUN_ALL_TESTS(); diff --git a/hybridse/examples/toydb/src/testing/toydb_engine_test_base.h b/hybridse/examples/toydb/src/testing/toydb_engine_test_base.h index eedfcca3680..239518c7faa 100644 --- a/hybridse/examples/toydb/src/testing/toydb_engine_test_base.h +++ b/hybridse/examples/toydb/src/testing/toydb_engine_test_base.h @@ -26,6 +26,7 @@ #include #include +#include "absl/strings/substitute.h" #include "case/case_data_mock.h" #include "case/sql_case.h" #include "glog/logging.h" @@ -98,6 +99,14 @@ class ToydbBatchEngineTestRunner : public BatchEngineTestRunner { CheckSqliteCompatible(sql_case_, GetSession()->GetSchema(), output_rows); } } + const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override { + auto it = name_table_map_.find(std::make_pair(std::string(db), std::string(table))); + if (it == name_table_map_.end()) { + return nullptr; + } + + return &it->second->GetTableDef(); + } private: std::shared_ptr catalog_; @@ -143,6 +152,15 @@ class ToydbRequestEngineTestRunner : public RequestEngineTestRunner { return table->Put(reinterpret_cast(row.buf()), row.size()); } + const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override { + auto it = name_table_map_.find(std::make_pair(std::string(db), std::string(table))); + if (it == name_table_map_.end()) { + return nullptr; + } + + return &it->second->GetTableDef(); + } + private: std::shared_ptr catalog_; std::map, std::shared_ptr<::hybridse::storage::Table>> name_table_map_; @@ -189,6 +207,15 @@ class ToydbBatchRequestEngineTestRunner : public BatchRequestEngineTestRunner { return table->Put(reinterpret_cast(row.buf()), row.size()); } + const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override { + auto it = name_table_map_.find(std::make_pair(std::string(db), std::string(table))); + if (it == name_table_map_.end()) { + return nullptr; + } + + return &it->second->GetTableDef(); + } + private: std::shared_ptr catalog_; std::map, std::shared_ptr<::hybridse::storage::Table>> name_table_map_; diff --git a/hybridse/include/case/sql_case.h b/hybridse/include/case/sql_case.h index 7cc05bba1d5..cb2d9907b37 100644 --- a/hybridse/include/case/sql_case.h +++ b/hybridse/include/case/sql_case.h @@ -118,10 +118,10 @@ class SqlCase { const codec::Schema ExtractParameterTypes() const; bool ExtractInputData(std::vector& rows, // NOLINT - int32_t input_idx = 0) const; - bool ExtractInputData( - const TableInfo& info, - std::vector& rows) const; // NOLINT + int32_t input_idx, const codec::Schema& sc) const; + bool ExtractInputData(const TableInfo& info, + std::vector& rows, // NOLINT + const codec::Schema&) const; bool ExtractOutputData( std::vector& rows) const; // NOLINT @@ -331,6 +331,10 @@ std::vector InitCases(std::string yaml_path, std::vector f void InitCases(std::string yaml_path, std::vector& cases); // NOLINT void InitCases(std::string yaml_path, std::vector& cases, // NOLINT const std::vector& filters); + +// TODO(someone): consider move the function to production code so others will take usage. +absl::StatusOr> ExtractInsertRow(vm::HybridSeJitWrapper*, absl::string_view insert, + const codec::Schema* table_schema); } // namespace sqlcase } // namespace hybridse #endif // HYBRIDSE_INCLUDE_CASE_SQL_CASE_H_ diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index ee81c222714..a5c8579a9d9 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -358,6 +358,22 @@ class SqlNode : public NodeBase { bool Equals(const SqlNode *node) const override; + // Return this node cast as a NodeType. + // Use only when this node is known to be that type, otherwise, behavior is undefined. + template + const NodeType *GetAsOrNull() const { + static_assert(std::is_base_of::value, + "NodeType must be a member of the SqlNode class hierarchy"); + return dynamic_cast(this); + } + + template + NodeType *GetAsOrNull() { + static_assert(std::is_base_of::value, + "NodeType must be a member of the SqlNode class hierarchy"); + return dynamic_cast(this); + } + SqlNodeType type_; private: @@ -1883,6 +1899,8 @@ class ColumnSchemaNode : public SqlNode { bool not_null() const { return not_null_; } const ExprNode *default_value() const { return default_value_; } + absl::Status GetProtoColumnSchema(type::ColumnSchema *) const; + std::string DebugString() const; private: @@ -1902,6 +1920,8 @@ class ColumnDefNode : public SqlNode { const ColumnSchemaNode *schema() const { return schema_; } + absl::Status GetProtoColumnDef(type::ColumnDef *) const; + // deprecated, use ColumnDefNode::schema instead DataType GetColumnType() const { return schema_->type(); } @@ -2025,8 +2045,11 @@ class CreateStmt : public SqlNode { ~CreateStmt() {} - NodePointVector* MutableColumnDefList() { return &column_desc_list_; } - const NodePointVector &GetColumnDefList() const { return column_desc_list_; } + NodePointVector* MutableTableElementList() { return &column_desc_list_; } + const NodePointVector &GetTableElementList() const { return column_desc_list_; } + + // collect the column definitions in column_desc_list_, and convert to the proto representation type::ColumnDef + absl::StatusOr GetColumnDefListAsSchema() const; std::string GetTableName() const { return table_name_; } std::string GetDbName() const { return db_name_; } diff --git a/hybridse/include/plan/plan_api.h b/hybridse/include/plan/plan_api.h index 1e4f3b74845..371d0dd32fa 100644 --- a/hybridse/include/plan/plan_api.h +++ b/hybridse/include/plan/plan_api.h @@ -29,6 +29,8 @@ using hybridse::base::Status; using hybridse::node::NodeManager; using hybridse::node::NodePointVector; using hybridse::node::PlanNodeList; + +// TODO(someone): rm class PlanAPI class PlanAPI { public: // parse SQL string to logic plan. ASTNode and LogicNode saved in SqlContext @@ -47,6 +49,22 @@ class PlanAPI { static const std::string GenerateName(const std::string prefix, int id); }; +// Parse the input str and SQL type and convert to TypeNode representation +// +// unimplemnted, reserved for later usage +absl::StatusOr ParseType(absl::string_view, NodeManager*); + +// parse the input string as table elements and extract those element that is table_column_definition, +// then returns the corresponding proto representation. +// +// it expect input `str` joined every element by comma(,), then a CREATE TABLE SQL is created with the +// format of 'CREATE TABLE t1 ( {str} )'. +// SQL parse allows three kind of table element, which is: +// - table_column_definition +// - table_index_definition +// - table_constraint_definition +// while this method extract table_column_definition only +absl::StatusOr ParseTableColumSchema(absl::string_view str); } // namespace plan } // namespace hybridse #endif // HYBRIDSE_INCLUDE_PLAN_PLAN_API_H_ diff --git a/hybridse/src/case/sql_case.cc b/hybridse/src/case/sql_case.cc index 8af73741caa..eb4dfae181f 100644 --- a/hybridse/src/case/sql_case.cc +++ b/hybridse/src/case/sql_case.cc @@ -16,6 +16,7 @@ #include "case/sql_case.h" +#include #include #include #include @@ -23,7 +24,6 @@ #include #include -#include "absl/cleanup/cleanup.h" #include "absl/strings/ascii.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" @@ -34,7 +34,11 @@ #include "codec/fe_row_codec.h" #include "glog/logging.h" #include "node/sql_node.h" -#include "yaml-cpp/yaml.h" +#include "plan/plan_api.h" +#include "vm/engine.h" +#include "zetasql/parser/parser.h" +#include "planv2/ast_node_converter.h" +#include "vm/jit_wrapper.h" namespace hybridse { namespace sqlcase { @@ -44,6 +48,23 @@ static absl::Mutex mtx; // working directory, where the yaml file lives static std::filesystem::path working_dir; +static vm::HybridSeJitWrapper* createJitWrapper() { + hybridse::vm::Engine::InitializeGlobalLLVM(); + base::Status s; + auto wrapper = vm::HybridSeJitWrapper::CreateWithDefaultSymbols(&s); + if (!s.isOK()) { + LOG(ERROR) << "fail to get jit: " << s; + std::exit(1); + } + return wrapper; +} + +static vm::HybridSeJitWrapper* getJitWrapper() { + // ensuare call once by local static + static vm::HybridSeJitWrapper* wrapper = createJitWrapper(); + return wrapper; +} + bool SqlCase::TTLParse(const std::string& org_type_str, std::vector& ttls) { std::string type_str = org_type_str; @@ -271,6 +292,20 @@ bool SqlCase::ExtractSchema(const std::vector& columns, LOG(WARNING) << "Invalid Schema Format"; return false; } + // expect the string of table column schema: + // 1. {name}{type} + // legacy specification, issues may arrise when dealing with type contains options + // 2. {name}{type} + // favored specification, which is exactly the same as SQL table column schema, + // you may nest any composed type like map in {type} + + auto column_list = absl::StrJoin(columns, ","); + auto rs = plan::ParseTableColumSchema(column_list); + if (rs.ok()) { + table.mutable_columns()->CopyFrom(rs.value()); + return true; + } + // fallback legacy approach try { for (auto col : columns) { boost::trim(col); @@ -409,13 +444,15 @@ bool SqlCase::AddInput(const TableInfo& table_data) { return true; } bool SqlCase::ExtractInputData(std::vector& rows, - int32_t input_idx) const { - return ExtractInputData(inputs_[input_idx], rows); + int32_t input_idx, + const codec::Schema& sc) const { + return ExtractInputData(inputs_[input_idx], rows, sc); } bool SqlCase::ExtractInputData(const TableInfo& input, - std::vector& rows) const { + std::vector& rows, + const codec::Schema& sc) const { try { - if (input.data_.empty() && input.rows_.empty()) { + if (input.data_.empty() && input.rows_.empty() && input.inserts_.empty()) { LOG(WARNING) << "Empty Data String"; return false; } @@ -425,7 +462,20 @@ bool SqlCase::ExtractInputData(const TableInfo& input, return false; } - if (!input.data_.empty()) { + if (!input.inserts_.empty()) { + for (auto& sql : input.inserts_) { + // shit happens, resue jit cause duplcate symbols + // FIXME(#3748): use getJitWrapper + auto jit = std::unique_ptr(createJitWrapper()); + auto rs = ExtractInsertRow(jit.get(), sql, &sc); + if (!rs.ok()) { + LOG(ERROR) << rs.status(); + return false; + } + + rows.insert(rows.end(), rs.value().begin(), rs.value().end()); + } + } else if (!input.data_.empty()) { if (!ExtractRows(table.columns(), input.data_, rows)) { return false; } @@ -1713,5 +1763,48 @@ std::string SqlCase::SqlCaseBaseDir() { } return ""; } + +absl::StatusOr> ExtractInsertRow(vm::HybridSeJitWrapper* jit, absl::string_view insert, + const codec::Schema* table_schema) { + zetasql::ParserOptions parser_opts; + zetasql::LanguageOptions language_opts; + language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE); + parser_opts.set_language_options(&language_opts); + std::unique_ptr parser_output; + auto zetasql_status = zetasql::ParseStatement(insert, parser_opts, &parser_output); + CHECK_ABSL_STATUS(zetasql_status); + + node::SqlNode* sql_node = nullptr; + node::NodeManager nm; + CHECK_STATUS_TO_ABSL(plan::ConvertStatement(parser_output->statement(), &nm, &sql_node)); + + auto* insert_stmt = sql_node->GetAsOrNull(); + if (insert_stmt == nullptr) { + return absl::FailedPreconditionError("not a insert statement"); + } + + if (!insert_stmt->columns_.empty()) { + // implementation limitation + return absl::UnimplementedError("insert with custom columns not support"); + } + + codec::RowBuilder2 builder(jit, std::vector{*table_schema}); + + CHECK_STATUS_TO_ABSL(builder.Init()); + + std::vector rows; + for (auto expr : insert_stmt->values_) { + auto expr_list = expr->GetAsOrNull(); + if (expr_list == nullptr) { + return absl::FailedPreconditionError( + absl::Substitute("unexpected insert statement value: $0", expr->GetExprString())); + } + codec::Row row; + CHECK_STATUS_TO_ABSL(builder.Build(expr_list->children_, &row)); + rows.push_back(row); + } + + return rows; +} } // namespace sqlcase } // namespace hybridse diff --git a/hybridse/src/case/sql_case_test.cc b/hybridse/src/case/sql_case_test.cc index 5e9dc333da2..5e70251e37b 100644 --- a/hybridse/src/case/sql_case_test.cc +++ b/hybridse/src/case/sql_case_test.cc @@ -538,10 +538,10 @@ TEST_F(SqlCaseTest, ExtractSqlCase) { // Check Data { type::TableDef output_table; + ASSERT_TRUE(sql_case.ExtractInputTableDef(output_table)); std::vector rows; - ASSERT_TRUE(sql_case.ExtractInputData(rows)); + ASSERT_TRUE(sql_case.ExtractInputData(rows, 0, output_table.columns())); ASSERT_EQ(5u, rows.size()); - sql_case.ExtractInputTableDef(output_table); hybridse::codec::RowView row_view(output_table.columns()); { diff --git a/hybridse/src/codec/fe_row_codec.cc b/hybridse/src/codec/fe_row_codec.cc index e92c66bbd83..f761c28fbe4 100644 --- a/hybridse/src/codec/fe_row_codec.cc +++ b/hybridse/src/codec/fe_row_codec.cc @@ -19,12 +19,12 @@ #include #include +#include "absl/strings/str_join.h" #include "codec/type_codec.h" #include "gflags/gflags.h" #include "codegen/insert_row_builder.h" #include "glog/logging.h" #include "proto/fe_common.pb.h" -#include "vm/engine.h" DECLARE_bool(enable_spark_unsaferow_format); @@ -1103,7 +1103,11 @@ base::Status RowBuilder2::Build(const std::vector& values, code auto expect_cols = std::accumulate(schemas_.begin(), schemas_.end(), 0, [](int val, const auto& e) { return val + e.size(); }); CHECK_TRUE(values.size() == expect_cols, common::kCodegenEncodeError, "pass in expr number do not match, expect ", - expect_cols, " but got ", values.size()); + expect_cols, " but got ", values.size(), ": (", + absl::StrJoin( + values, ", ", + [](std::string* out, const node::ExprNode* expr) { absl::StrAppend(out, expr->GetExprString()); }), + ")"); int col_idx = 0; Row row; diff --git a/hybridse/src/node/node_manager.cc b/hybridse/src/node/node_manager.cc index 5b1d18e5973..77689a61927 100644 --- a/hybridse/src/node/node_manager.cc +++ b/hybridse/src/node/node_manager.cc @@ -439,7 +439,7 @@ CreateStmt *NodeManager::MakeCreateTableNode(bool op_if_not_exist, const std::st const std::string &table_name, SqlNodeList *column_desc_list, SqlNodeList *table_option_list) { CreateStmt *node_ptr = new CreateStmt(db_name, table_name, op_if_not_exist); - FillSqlNodeList2NodeVector(column_desc_list, *(node_ptr->MutableColumnDefList())); + FillSqlNodeList2NodeVector(column_desc_list, *(node_ptr->MutableTableElementList())); FillSqlNodeList2NodeVector(table_option_list, *(node_ptr->MutableTableOptionList())); return RegisterNode(node_ptr); } diff --git a/hybridse/src/node/sql_node.cc b/hybridse/src/node/sql_node.cc index 478805c5f05..3fc8c067ca6 100644 --- a/hybridse/src/node/sql_node.cc +++ b/hybridse/src/node/sql_node.cc @@ -23,6 +23,7 @@ #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -2784,5 +2785,89 @@ void SetOperationNode::Print(std::ostream &output, const std::string &org_tab) c PrintSqlNode(output, org_tab + INDENT + INDENT, node, std::to_string(i), i + 1 == inputs().size()); } } + +absl::StatusOr CreateStmt::GetColumnDefListAsSchema() const { + codec::Schema sc; + + for (auto col : GetTableElementList()) { + auto *col_def = col->GetAsOrNull(); + if (col_def == nullptr) { + continue; + } + + CHECK_ABSL_STATUS(col_def->GetProtoColumnDef(sc.Add())); + } + + return sc; +} +absl::Status ColumnDefNode::GetProtoColumnDef(type::ColumnDef* def) const { + def->set_name(GetColumnName()); + def->set_is_not_null(GetIsNotNull()); + + CHECK_ABSL_STATUS(schema()->GetProtoColumnSchema(def->mutable_schema())); + + if (def->schema().has_base_type()) { + def->set_type(def->schema().base_type()); + } + + return absl::OkStatus(); +} + +absl::Status ColumnSchemaNode::GetProtoColumnSchema(type::ColumnSchema* sc_alloca) const { + sc_alloca->set_is_not_null(not_null()); + + switch (type()) { + case hybridse::node::kBool: + sc_alloca->set_base_type(type::kBool); + break; + case hybridse::node::kInt16: + sc_alloca->set_base_type(type::kInt16); + break; + case hybridse::node::kInt32: + sc_alloca->set_base_type(type::kInt32); + break; + case hybridse::node::kInt64: + sc_alloca->set_base_type(type::kInt64); + break; + case hybridse::node::kFloat: + sc_alloca->set_base_type(type::kFloat); + break; + case hybridse::node::kDouble: + sc_alloca->set_base_type(type::kDouble); + break; + case hybridse::node::kDate: + sc_alloca->set_base_type(type::kDate); + break; + case hybridse::node::kTimestamp: + sc_alloca->set_base_type(type::kTimestamp); + break; + case hybridse::node::kVarchar: + sc_alloca->set_base_type(type::kVarchar); + break; + case hybridse::node::kArray: { + if (generics().size() != 1) { + return absl::FailedPreconditionError( + absl::Substitute("expect generic size = 1 for array type, but got $0", generics().size())); + } + auto* arr = sc_alloca->mutable_array_type(); + CHECK_ABSL_STATUS(generics_[0]->GetProtoColumnSchema(arr->mutable_ele_type())); + break; + } + case hybridse::node::kMap: { + if (generics().size() != 2) { + return absl::FailedPreconditionError( + absl::Substitute("expect generic size = 2 for map type, but got $0", generics().size())); + } + auto* mp = sc_alloca->mutable_map_type(); + CHECK_ABSL_STATUS(generics_[0]->GetProtoColumnSchema(mp->mutable_key_type())); + CHECK_ABSL_STATUS(generics_[1]->GetProtoColumnSchema(mp->mutable_value_type())); + break; + } + default: + return absl::UnimplementedError(absl::StrCat("unsupported type: ", DebugString())); + } + + return absl::OkStatus(); +} } // namespace node } // namespace hybridse diff --git a/hybridse/src/passes/physical/batch_request_optimize_test.cc b/hybridse/src/passes/physical/batch_request_optimize_test.cc index 48259b68ed4..4c910e9e358 100644 --- a/hybridse/src/passes/physical/batch_request_optimize_test.cc +++ b/hybridse/src/passes/physical/batch_request_optimize_test.cc @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "testing/engine_test_base.h" #include "vm/sql_compiler.h" +#include "vm/engine.h" namespace hybridse { namespace vm { @@ -253,7 +254,6 @@ TEST_P(BatchRequestOptimizeTest, test_with_common_columns) { int main(int argc, char** argv) { ::testing::GTEST_FLAG(color) = "yes"; ::testing::InitGoogleTest(&argc, argv); - InitializeNativeTarget(); - InitializeNativeTargetAsmPrinter(); + ::hybridse::vm::Engine::InitializeGlobalLLVM(); return RUN_ALL_TESTS(); } diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index a6656dfc0f3..8dd57b6f9e5 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -452,9 +452,9 @@ base::Status Planner::CreateSetPlanNode(const node::SetNode *root, node::PlanNod base::Status Planner::CreateCreateTablePlan(const node::SqlNode *root, node::PlanNode **output) { CHECK_TRUE(nullptr != root, common::kPlanError, "fail to create table plan with null node") auto create_tree = dynamic_cast(root); - auto* out = node_manager_->MakeCreateTablePlanNode(create_tree->GetDbName(), create_tree->GetTableName(), - create_tree->GetColumnDefList(), create_tree->GetTableOptionList(), - create_tree->GetOpIfNotExist()); + auto *out = node_manager_->MakeCreateTablePlanNode( + create_tree->GetDbName(), create_tree->GetTableName(), create_tree->GetTableElementList(), + create_tree->GetTableOptionList(), create_tree->GetOpIfNotExist()); out->like_clause_ = create_tree->like_clause_; *output = out; return base::Status::OK(); diff --git a/hybridse/src/planv2/ast_node_converter.h b/hybridse/src/planv2/ast_node_converter.h index a40bacc2e10..631569156d2 100644 --- a/hybridse/src/planv2/ast_node_converter.h +++ b/hybridse/src/planv2/ast_node_converter.h @@ -29,6 +29,8 @@ namespace plan { // ======================================================================================= // base::Status ConvertASTScript(const zetasql::ASTScript* body, node::NodeManager* node_manager, node::SqlNodeList** output); +base::Status ConvertStatement(const zetasql::ASTStatement* stmt, node::NodeManager* node_manager, + node::SqlNode** output); // ======================================================================================= // // all interfaces below not consider public, which might moved later @@ -36,8 +38,6 @@ base::Status ConvertASTScript(const zetasql::ASTScript* body, node::NodeManager* base::Status ConvertExprNode(const zetasql::ASTExpression* ast_expression, node::NodeManager* node_manager, node::ExprNode** output); -base::Status ConvertStatement(const zetasql::ASTStatement* stmt, node::NodeManager* node_manager, - node::SqlNode** output); base::Status ConvertOrderBy(const zetasql::ASTOrderBy* order_by, node::NodeManager* node_manager, node::OrderByNode** output); diff --git a/hybridse/src/planv2/plan_api.cc b/hybridse/src/planv2/plan_api.cc index d3f8f7644bf..6f3cdfe5871 100644 --- a/hybridse/src/planv2/plan_api.cc +++ b/hybridse/src/planv2/plan_api.cc @@ -15,6 +15,8 @@ */ #include "plan/plan_api.h" +#include "absl/strings/substitute.h" +#include "planv2/ast_node_converter.h" #include "planv2/planner_v2.h" #include "zetasql/parser/parser.h" #include "zetasql/public/error_helpers.h" @@ -87,5 +89,29 @@ const std::string PlanAPI::GenerateName(const std::string prefix, int id) { return name; } +absl::StatusOr ParseTableColumSchema(absl::string_view str) { + zetasql::ParserOptions parser_opts; + zetasql::LanguageOptions language_opts; + language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE); + parser_opts.set_language_options(&language_opts); + std::unique_ptr parser_output; + auto sql = absl::Substitute("CREATE TABLE t1 ($0)", str); + auto zetasql_status = zetasql::ParseStatement(sql, parser_opts, &parser_output); + if (!zetasql_status.ok()) { + return zetasql_status; + } + + node::SqlNode *sql_node = nullptr; + node::NodeManager nm; + CHECK_STATUS_TO_ABSL(ConvertStatement(parser_output->statement(), &nm, &sql_node)); + + auto* create = sql_node->GetAsOrNull(); + if (create == nullptr) { + return absl::FailedPreconditionError("not a create table statement"); + } + + return create->GetColumnDefListAsSchema(); +} + } // namespace plan } // namespace hybridse diff --git a/hybridse/src/testing/engine_test_base.cc b/hybridse/src/testing/engine_test_base.cc index 3aebea8f2de..f55b5454339 100644 --- a/hybridse/src/testing/engine_test_base.cc +++ b/hybridse/src/testing/engine_test_base.cc @@ -14,17 +14,36 @@ * limitations under the License. */ #include "testing/engine_test_base.h" + +#include + +#include "base/texttable.h" +#include "plan/plan_api.h" +#include "boost/algorithm/string.hpp" #include "vm/sql_compiler.h" +#include "google/protobuf/util/message_differencer.h" namespace hybridse { namespace vm { bool IsNaN(float x) { return x != x; } bool IsNaN(double x) { return x != x; } -void CheckSchema(const vm::Schema& schema, const vm::Schema& exp_schema) { +void CheckSchema(const codec::Schema& schema, const codec::Schema& exp_schema) { ASSERT_EQ(schema.size(), exp_schema.size()); + ::google::protobuf::util::MessageDifferencer differ; + // approximate equal for float values + differ.set_float_comparison(::google::protobuf::util::MessageDifferencer::FloatComparison::APPROXIMATE); + // equivalent avoid the issue that some optional bool fields that may contains a default value + differ.set_message_field_comparison( + ::google::protobuf::util::MessageDifferencer::MessageFieldComparison::EQUIVALENT); for (int i = 0; i < schema.size(); i++) { - ASSERT_EQ(schema.Get(i).DebugString(), exp_schema.Get(i).DebugString()) << "Fail column type at " << i; + std::string diff_str; + differ.ReportDifferencesToString(&diff_str); + ASSERT_TRUE(differ.Compare(schema.Get(i), exp_schema.Get(i))) + << "Fail column type at " << i + << "\ngot: " << schema.Get(i).ShortDebugString() + << "\nbut expect: " << exp_schema.Get(i).ShortDebugString() + << "\ndifference: " << diff_str; } } diff --git a/hybridse/src/testing/engine_test_base.h b/hybridse/src/testing/engine_test_base.h index 0805ff1b3c5..24b66645208 100644 --- a/hybridse/src/testing/engine_test_base.h +++ b/hybridse/src/testing/engine_test_base.h @@ -19,36 +19,15 @@ #include #include #include -#include #include #include #include -#include #include -#include "base/texttable.h" -#include "boost/algorithm/string.hpp" #include "case/sql_case.h" #include "codec/fe_row_codec.h" #include "codec/fe_row_selector.h" -#include "codec/list_iterator_codec.h" -#include "gflags/gflags.h" #include "gtest/gtest.h" -#include "gtest/internal/gtest-param-util.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" -#include "llvm/Transforms/InstCombine/InstCombine.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/GVN.h" -#include "plan/plan_api.h" -#include "sys/time.h" #include "vm/engine.h" #include "testing/test_base.h" #define MAX_DEBUG_LINES_CNT 20 @@ -134,6 +113,8 @@ class EngineTestRunner { int return_code() const { return return_code_; } const SqlCase& sql_case() const { return sql_case_; } + virtual const type::TableDef* GetTableDef(absl::string_view, absl::string_view) = 0; + void RunCheck(); void RunBenchmark(size_t iters); @@ -157,8 +138,15 @@ class BatchEngineTestRunner : public EngineTestRunner { Status PrepareData() override { for (int32_t i = 0; i < sql_case_.CountInputs(); i++) { auto input = sql_case_.inputs()[i]; + std::string table_name = sql_case_.inputs_[i].name_; + std::string table_db_name = sql_case_.inputs_[i].db_.empty() ? sql_case_.db() : sql_case_.inputs_[i].db_; + auto table_def = GetTableDef(table_db_name, table_name); + if (!table_def) { + CHECK_TRUE(false, common::kTableNotFound, "table ", table_name, " not exist"); + } + std::vector rows; - sql_case_.ExtractInputData(rows, i); + sql_case_.ExtractInputData(rows, i, table_def->columns()); size_t repeat = sql_case_.inputs()[i].repeat_; if (repeat > 1) { size_t row_num = rows.size(); @@ -171,9 +159,6 @@ class BatchEngineTestRunner : public EngineTestRunner { } } if (!rows.empty()) { - std::string table_name = sql_case_.inputs_[i].name_; - std::string table_db_name = - sql_case_.inputs_[i].db_.empty() ? sql_case_.db() : sql_case_.inputs_[i].db_; CHECK_TRUE(AddRowsIntoTable(table_db_name, table_name, rows), common::kTablePutFailed, "Fail to add rows into table ", table_name); @@ -211,31 +196,40 @@ class RequestEngineTestRunner : public EngineTestRunner { auto request_session = std::dynamic_pointer_cast(session_); CHECK_TRUE(request_session != nullptr, common::kNullPointer); - std::string request_name = request_session->GetRequestName(); + std::string request_table = request_session->GetRequestName(); std::string request_db_name = request_session->GetRequestDbName().empty() ? sql_case_.db() : request_session->GetRequestDbName(); + auto request_table_def = GetTableDef(request_db_name, request_table); + if (!request_table.empty() && !request_table_def) { + CHECK_TRUE(false, common::kTableNotFound, "table ", request_table, " not exist"); + } if (has_batch_request) { CHECK_TRUE(1 <= sql_case_.batch_request_.rows_.size(), common::kSqlCaseError, "RequestEngine can't handler emtpy rows batch requests"); - CHECK_TRUE(sql_case_.ExtractInputData(sql_case_.batch_request_, - request_rows_), - common::kSqlCaseError, "Extract case request rows failed"); + CHECK_TRUE( + sql_case_.ExtractInputData(sql_case_.batch_request_, request_rows_, request_table_def->columns()), + common::kSqlCaseError, "Extract case request rows failed"); } for (int32_t i = 0; i < sql_case_.CountInputs(); i++) { - std::string input_name = sql_case_.inputs_[i].name_; + std::string table_name = sql_case_.inputs_[i].name_; std::string table_db_name = sql_case_.inputs_[i].db_.empty() ? sql_case_.db() : sql_case_.inputs_[i].db_; - if ((table_db_name == request_db_name) && (input_name == request_name) && !has_batch_request) { - CHECK_TRUE(sql_case_.ExtractInputData(request_rows_, i), + auto table_def = GetTableDef(table_db_name, table_name); + if (!table_def) { + CHECK_TRUE(false, common::kTableNotFound, "table ", table_name, " not exist"); + } + + if ((table_db_name == request_db_name) && (table_name == request_table) && !has_batch_request) { + CHECK_TRUE(sql_case_.ExtractInputData(request_rows_, i, table_def->columns()), common::kSqlCaseError, "Extract case request rows failed"); continue; } else { std::vector rows; if (!sql_case_.inputs_[i].rows_.empty() || !sql_case_.inputs_[i].data_.empty()) { - CHECK_TRUE(sql_case_.ExtractInputData(rows, i), common::kSqlCaseError, + CHECK_TRUE(sql_case_.ExtractInputData(rows, i, table_def->columns()), common::kSqlCaseError, "Extract case request rows failed"); } @@ -247,14 +241,14 @@ class RequestEngineTestRunner : public EngineTestRunner { store_rows.push_back(row); } } - CHECK_TRUE(AddRowsIntoTable(table_db_name, input_name, store_rows), + CHECK_TRUE(AddRowsIntoTable(table_db_name, table_name, store_rows), common::kTablePutFailed, - "Fail to add rows into table ", input_name); + "Fail to add rows into table ", table_name); } else { - CHECK_TRUE(AddRowsIntoTable(table_db_name, input_name, rows), + CHECK_TRUE(AddRowsIntoTable(table_db_name, table_name, rows), common::kTablePutFailed, - "Fail to add rows into table ", input_name); + "Fail to add rows into table ", table_name); } } } @@ -329,11 +323,14 @@ class BatchRequestEngineTestRunner : public EngineTestRunner { for (int32_t i = 0; i < sql_case_.CountInputs(); i++) { auto input = sql_case_.inputs()[i]; std::vector rows; - sql_case_.ExtractInputData(rows, i); + std::string table_name = sql_case_.inputs_[i].name_; + std::string table_db_name = sql_case_.inputs_[i].db_.empty() ? sql_case_.db() : sql_case_.inputs_[i].db_; + auto table_def = GetTableDef(table_db_name, table_name); + if (!table_def) { + CHECK_TRUE(false, common::kTableNotFound, "table ", table_name, " not exist"); + } + sql_case_.ExtractInputData(rows, i, table_def->columns()); if (!rows.empty()) { - std::string table_name = sql_case_.inputs_[i].name_; - std::string table_db_name = - sql_case_.inputs_[i].db_.empty() ? sql_case_.db() : sql_case_.inputs_[i].db_; if ((table_db_name == request_db_name && table_name == request_name) && !has_batch_request) { original_request_data.push_back(rows.back()); diff --git a/hybridse/src/vm/engine_compile_test.cc b/hybridse/src/vm/engine_compile_test.cc index b4a7c715f9b..e30cc569f15 100644 --- a/hybridse/src/vm/engine_compile_test.cc +++ b/hybridse/src/vm/engine_compile_test.cc @@ -713,8 +713,7 @@ TEST_F(EngineCompileTest, ExternalFunctionTest) { } // namespace hybridse int main(int argc, char** argv) { - InitializeNativeTarget(); - InitializeNativeTargetAsmPrinter(); + ::hybridse::vm::Engine::InitializeGlobalLLVM(); ::testing::InitGoogleTest(&argc, argv); // ::hybridse::vm::CoreAPI::EnableSignalTraceback(); return RUN_ALL_TESTS(); diff --git a/hybridse/src/vm/jit_wrapper.cc b/hybridse/src/vm/jit_wrapper.cc index 876df037ef0..8ac09391a10 100644 --- a/hybridse/src/vm/jit_wrapper.cc +++ b/hybridse/src/vm/jit_wrapper.cc @@ -99,6 +99,9 @@ HybridSeJitWrapper* HybridSeJitWrapper::CreateWithDefaultSymbols(udf::UdfLibrary } return jit; } +HybridSeJitWrapper* HybridSeJitWrapper::CreateWithDefaultSymbols(base::Status* status, const JitOptions& options) { + return CreateWithDefaultSymbols(udf::DefaultUdfLibrary::get(), status, options); +} HybridSeJitWrapper* HybridSeJitWrapper::Create(const JitOptions& jit_options) { if (jit_options.IsEnableMcjit()) { diff --git a/hybridse/src/vm/jit_wrapper.h b/hybridse/src/vm/jit_wrapper.h index f7a578a4306..cad22888982 100644 --- a/hybridse/src/vm/jit_wrapper.h +++ b/hybridse/src/vm/jit_wrapper.h @@ -55,6 +55,7 @@ class HybridSeJitWrapper { // create the JIT wrapper with default builtin symbols imported already static HybridSeJitWrapper* CreateWithDefaultSymbols(udf::UdfLibrary*, base::Status*, const JitOptions& jit_options = {}); + static HybridSeJitWrapper* CreateWithDefaultSymbols(base::Status*, const JitOptions& jit_options = {}); static HybridSeJitWrapper* Create(const JitOptions& jit_options); static HybridSeJitWrapper* Create(); diff --git a/src/sdk/mini_cluster_bm.cc b/src/sdk/mini_cluster_bm.cc index 7f771ed3e3c..605e2e1a381 100644 --- a/src/sdk/mini_cluster_bm.cc +++ b/src/sdk/mini_cluster_bm.cc @@ -77,7 +77,7 @@ void BM_RequestQuery(benchmark::State& state, hybridse::sqlcase::SqlCase& sql_ca } std::vector request_rows; - if (!sql_case.ExtractInputData(sql_case.batch_request_, request_rows)) { + if (!sql_case.ExtractInputData(sql_case.batch_request_, request_rows, request_table.columns())) { state.SkipWithError("benchmark error: hybridse case input data invalid"); return; } @@ -243,7 +243,7 @@ void BM_BatchRequestQuery(benchmark::State& state, hybridse::sqlcase::SqlCase& s } std::vector request_rows; - if (!sql_case.ExtractInputData(sql_case.batch_request_, request_rows)) { + if (!sql_case.ExtractInputData(sql_case.batch_request_, request_rows, request_table.columns())) { state.SkipWithError("benchmark error: hybridse case input data invalid"); return; } diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index ac5b8ec5153..b302c080846 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -43,6 +43,7 @@ #include "brpc/channel.h" #include "cmd/display.h" #include "codec/encrypt.h" +#include "codegen/insert_row_builder.h" #include "common/timer.h" #include "glog/logging.h" #include "nameserver/system_table.h" @@ -62,7 +63,6 @@ #include "sdk/split.h" #include "udf/udf.h" #include "vm/catalog.h" -#include "codegen/insert_row_builder.h" DECLARE_string(bucket_size); DECLARE_uint32(replica_num); diff --git a/src/sdk/sql_sdk_base_test.cc b/src/sdk/sql_sdk_base_test.cc index d66ecc8c75f..4df7816c690 100644 --- a/src/sdk/sql_sdk_base_test.cc +++ b/src/sdk/sql_sdk_base_test.cc @@ -16,7 +16,6 @@ #include "sdk/sql_sdk_base_test.h" -#include "absl/random/random.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_replace.h" @@ -411,11 +410,11 @@ void SQLSDKQueryTest::RequestExecuteSQL(hybridse::sqlcase::SqlCase& sql_case, / if (!sql_case.inputs().empty()) { if (!has_batch_request) { ASSERT_TRUE(sql_case.ExtractInputTableDef(insert_table, 0)); - ASSERT_TRUE(sql_case.ExtractInputData(insert_rows, 0)); + ASSERT_TRUE(sql_case.ExtractInputData(insert_rows, 0, insert_table.columns())); sql_case.BuildInsertSqlListFromInput(0, &inserts); } else { ASSERT_TRUE(sql_case.ExtractInputTableDef(sql_case.batch_request_, insert_table)); - ASSERT_TRUE(sql_case.ExtractInputData(sql_case.batch_request_, insert_rows)); + ASSERT_TRUE(sql_case.ExtractInputData(sql_case.batch_request_, insert_rows, insert_table.columns())); } CheckSchema(insert_table.columns(), *(request_row->GetSchema().get())); DLOG(INFO) << "Request Row:\n"; @@ -571,7 +570,7 @@ void SQLSDKQueryTest::BatchRequestExecuteSQLWithCommonColumnIndices(hybridse::sq } else { ASSERT_TRUE(sql_case.ExtractInputTableDef(sql_case.inputs_[0], batch_request_table)); std::vector rows; - ASSERT_TRUE(sql_case.ExtractInputData(sql_case.inputs_[0], rows)); + ASSERT_TRUE(sql_case.ExtractInputData(sql_case.inputs_[0], rows, batch_request_table.columns())); request_rows.push_back(rows.back()); } CheckSchema(batch_request_table.columns(), *(request_row->GetSchema().get())); @@ -798,7 +797,7 @@ void DeploymentEnv::CallDeployProcedure() const { std::vector insert_rows; std::vector inserts; ASSERT_TRUE(sql_case_->ExtractInputTableDef(insert_table, 0)); - ASSERT_TRUE(sql_case_->ExtractInputData(insert_rows, 0)); + ASSERT_TRUE(sql_case_->ExtractInputData(insert_rows, 0, insert_table.columns())); sql_case_->BuildInsertSqlListFromInput(0, &inserts); test::SQLCaseTest::CheckSchema(insert_table.columns(), *(request_row->GetSchema().get())); LOG(INFO) << "Request Row:\n"; @@ -849,7 +848,7 @@ void DeploymentEnv::CallDeployProcedureTiny() const { hybridse::type::TableDef insert_table; std::vector insert_rows; ASSERT_TRUE(sql_case_->ExtractInputTableDef(insert_table, 0)); - ASSERT_TRUE(sql_case_->ExtractInputData(insert_rows, 0)); + ASSERT_TRUE(sql_case_->ExtractInputData(insert_rows, 0, insert_table.columns())); hybridse::codec::RowView row_view(insert_table.columns()); for (size_t i = 0; i < insert_rows.size(); i++) {