Skip to content

Commit

Permalink
feat: map type in yaml testing framework
Browse files Browse the repository at this point in the history
allow map data type in YAML testing framework for input definition, you
may define input table's schema as:

  columns: ["col1 int", "col2 map<int, string>"]

which is exactly the same as the SQL syntax in CREATE TABLE statement.

And define table values with `inputs[*].inserts`:

  inserts:
    - insert into t1 values (1, map(12, "abc"))

just write down the insert statement SQL.

*LIMITATIONS*

- unsupported: map data type in `expect` field
- INSERT statement with custom columns definition, implementation
  limits, each insert value expression must matches exactly to the table column definition
  • Loading branch information
aceforeverd committed Feb 21, 2024
1 parent 395ab2b commit 7a9c66b
Show file tree
Hide file tree
Showing 23 changed files with 398 additions and 84 deletions.
19 changes: 19 additions & 0 deletions cases/query/udf_query.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,22 @@ cases:
sql: |
select c1 + 8 from (select 9 as c1)
- id: 18
mode: request-unsupport
inputs:
- name: t1
columns: ["col1 int32", "std_ts timestamp", "col2 map<int,string>"]
indexs: ["index1:col1:std_ts"]
# insert map values by insert stmts only
inserts:
- "insert into t1 values (1, timestamp(1000), map(12, 'abc'))"
- "insert into t1 values (2, timestamp(1000), map(13, 'abc'))"
sql: |
select col1, col2[12] as out from t1;
expect:
order: col1
columns: ["col1 int", "out string"]
data: |
1, abc
2, null
4 changes: 1 addition & 3 deletions hybridse/examples/toydb/src/cmd/toydb_run_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>

#include "absl/strings/match.h"
#include "testing/toydb_engine_test_base.h"
Expand Down Expand Up @@ -104,8 +103,7 @@ int RunSingle(const std::string& yaml_path) {

int main(int argc, char** argv) {
::google::ParseCommandLineFlags(&argc, &argv, false);
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
::hybridse::vm::Engine::InitializeGlobalLLVM();
if (FLAGS_yaml_path != "") {
return ::hybridse::vm::RunSingle(FLAGS_yaml_path);
} else {
Expand Down
5 changes: 2 additions & 3 deletions hybridse/examples/toydb/src/testing/toydb_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

#include "gtest/gtest.h"
#include "gtest/internal/gtest-param-util.h"
#include "vm/engine.h"
#include "testing/toydb_engine_test_base.h"

using namespace llvm; // NOLINT (build/namespaces)
Expand Down Expand Up @@ -126,8 +126,7 @@ TEST_P(BatchRequestEngineTest, TestClusterBatchRequestEngine) {
} // namespace hybridse

int main(int argc, char** argv) {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
::hybridse::vm::Engine::InitializeGlobalLLVM();
::testing::InitGoogleTest(&argc, argv);
// ::hybridse::vm::CoreAPI::EnableSignalTraceback();
return RUN_ALL_TESTS();
Expand Down
27 changes: 27 additions & 0 deletions hybridse/examples/toydb/src/testing/toydb_engine_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "case/case_data_mock.h"
#include "case/sql_case.h"
#include "glog/logging.h"
Expand Down Expand Up @@ -98,6 +99,14 @@ class ToydbBatchEngineTestRunner : public BatchEngineTestRunner {
CheckSqliteCompatible(sql_case_, GetSession()->GetSchema(), output_rows);
}
}
const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override {
auto it = name_table_map_.find(std::make_pair(std::string(db), std::string(table)));
if (it == name_table_map_.end()) {
return nullptr;
}

return &it->second->GetTableDef();
}

private:
std::shared_ptr<tablet::TabletCatalog> catalog_;
Expand Down Expand Up @@ -143,6 +152,15 @@ class ToydbRequestEngineTestRunner : public RequestEngineTestRunner {
return table->Put(reinterpret_cast<char*>(row.buf()), row.size());
}

const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override {
auto it = name_table_map_.find(std::make_pair(std::string(db), std::string(table)));
if (it == name_table_map_.end()) {
return nullptr;
}

return &it->second->GetTableDef();
}

private:
std::shared_ptr<tablet::TabletCatalog> catalog_;
std::map<std::pair<std::string, std::string>, std::shared_ptr<::hybridse::storage::Table>> name_table_map_;
Expand Down Expand Up @@ -189,6 +207,15 @@ class ToydbBatchRequestEngineTestRunner : public BatchRequestEngineTestRunner {
return table->Put(reinterpret_cast<char*>(row.buf()), row.size());
}

const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override {
auto it = name_table_map_.find(std::make_pair(std::string(db), std::string(table)));
if (it == name_table_map_.end()) {
return nullptr;
}

return &it->second->GetTableDef();
}

private:
std::shared_ptr<tablet::TabletCatalog> catalog_;
std::map<std::pair<std::string, std::string>, std::shared_ptr<::hybridse::storage::Table>> name_table_map_;
Expand Down
12 changes: 8 additions & 4 deletions hybridse/include/case/sql_case.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ class SqlCase {
const codec::Schema ExtractParameterTypes() const;

bool ExtractInputData(std::vector<hybridse::codec::Row>& rows, // NOLINT
int32_t input_idx = 0) const;
bool ExtractInputData(
const TableInfo& info,
std::vector<hybridse::codec::Row>& rows) const; // NOLINT
int32_t input_idx, const codec::Schema& sc) const;
bool ExtractInputData(const TableInfo& info,
std::vector<hybridse::codec::Row>& rows, // NOLINT
const codec::Schema&) const;
bool ExtractOutputData(
std::vector<hybridse::codec::Row>& rows) const; // NOLINT

Expand Down Expand Up @@ -331,6 +331,10 @@ std::vector<SqlCase> InitCases(std::string yaml_path, std::vector<std::string> f
void InitCases(std::string yaml_path, std::vector<SqlCase>& cases); // NOLINT
void InitCases(std::string yaml_path, std::vector<SqlCase>& cases, // NOLINT
const std::vector<std::string>& filters);

// TODO(someone): consider move the function to production code so others will take usage.
absl::StatusOr<std::vector<codec::Row>> ExtractInsertRow(vm::HybridSeJitWrapper*, absl::string_view insert,
const codec::Schema* table_schema);
} // namespace sqlcase
} // namespace hybridse
#endif // HYBRIDSE_INCLUDE_CASE_SQL_CASE_H_
27 changes: 25 additions & 2 deletions hybridse/include/node/sql_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,22 @@ class SqlNode : public NodeBase<SqlNode> {

bool Equals(const SqlNode *node) const override;

// Return this node cast as a NodeType.
// Use only when this node is known to be that type, otherwise, behavior is undefined.
template <typename NodeType>
const NodeType *GetAsOrNull() const {
static_assert(std::is_base_of<SqlNode, NodeType>::value,
"NodeType must be a member of the SqlNode class hierarchy");
return dynamic_cast<const NodeType *>(this);
}

template <typename NodeType>
NodeType *GetAsOrNull() {
static_assert(std::is_base_of<SqlNode, NodeType>::value,
"NodeType must be a member of the SqlNode class hierarchy");
return dynamic_cast<NodeType *>(this);
}

SqlNodeType type_;

private:
Expand Down Expand Up @@ -1883,6 +1899,8 @@ class ColumnSchemaNode : public SqlNode {
bool not_null() const { return not_null_; }
const ExprNode *default_value() const { return default_value_; }

absl::Status GetProtoColumnSchema(type::ColumnSchema *) const;

std::string DebugString() const;

private:
Expand All @@ -1902,6 +1920,8 @@ class ColumnDefNode : public SqlNode {

const ColumnSchemaNode *schema() const { return schema_; }

absl::Status GetProtoColumnDef(type::ColumnDef *) const;

// deprecated, use ColumnDefNode::schema instead
DataType GetColumnType() const { return schema_->type(); }

Expand Down Expand Up @@ -2025,8 +2045,11 @@ class CreateStmt : public SqlNode {

~CreateStmt() {}

NodePointVector* MutableColumnDefList() { return &column_desc_list_; }
const NodePointVector &GetColumnDefList() const { return column_desc_list_; }
NodePointVector* MutableTableElementList() { return &column_desc_list_; }
const NodePointVector &GetTableElementList() const { return column_desc_list_; }

// collect the column definitions in column_desc_list_, and convert to the proto representation type::ColumnDef
absl::StatusOr<codec::Schema> GetColumnDefListAsSchema() const;

std::string GetTableName() const { return table_name_; }
std::string GetDbName() const { return db_name_; }
Expand Down
18 changes: 18 additions & 0 deletions hybridse/include/plan/plan_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ using hybridse::base::Status;
using hybridse::node::NodeManager;
using hybridse::node::NodePointVector;
using hybridse::node::PlanNodeList;

// TODO(someone): rm class PlanAPI
class PlanAPI {
public:
// parse SQL string to logic plan. ASTNode and LogicNode saved in SqlContext
Expand All @@ -47,6 +49,22 @@ class PlanAPI {
static const std::string GenerateName(const std::string prefix, int id);
};

// Parse the input str and SQL type and convert to TypeNode representation
//
// unimplemnted, reserved for later usage
absl::StatusOr<node::TypeNode*> ParseType(absl::string_view, NodeManager*);

// parse the input string as table elements and extract those element that is table_column_definition,
// then returns the corresponding proto representation.
//
// it expect input `str` joined every element by comma(,), then a CREATE TABLE SQL is created with the
// format of 'CREATE TABLE t1 ( {str} )'.
// SQL parse allows three kind of table element, which is:
// - table_column_definition
// - table_index_definition
// - table_constraint_definition
// while this method extract table_column_definition only
absl::StatusOr<codec::Schema> ParseTableColumSchema(absl::string_view str);
} // namespace plan
} // namespace hybridse
#endif // HYBRIDSE_INCLUDE_PLAN_PLAN_API_H_
107 changes: 100 additions & 7 deletions hybridse/src/case/sql_case.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

#include "case/sql_case.h"

#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/ascii.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
Expand All @@ -34,7 +34,11 @@
#include "codec/fe_row_codec.h"
#include "glog/logging.h"
#include "node/sql_node.h"
#include "yaml-cpp/yaml.h"
#include "plan/plan_api.h"
#include "vm/engine.h"
#include "zetasql/parser/parser.h"
#include "planv2/ast_node_converter.h"
#include "vm/jit_wrapper.h"

namespace hybridse {
namespace sqlcase {
Expand All @@ -44,6 +48,23 @@ static absl::Mutex mtx;
// working directory, where the yaml file lives
static std::filesystem::path working_dir;

static vm::HybridSeJitWrapper* createJitWrapper() {
hybridse::vm::Engine::InitializeGlobalLLVM();
base::Status s;
auto wrapper = vm::HybridSeJitWrapper::CreateWithDefaultSymbols(&s);
if (!s.isOK()) {
LOG(ERROR) << "fail to get jit: " << s;
std::exit(1);
}
return wrapper;
}

static vm::HybridSeJitWrapper* getJitWrapper() {
// ensuare call once by local static
static vm::HybridSeJitWrapper* wrapper = createJitWrapper();
return wrapper;
}

bool SqlCase::TTLParse(const std::string& org_type_str,
std::vector<int64_t>& ttls) {
std::string type_str = org_type_str;
Expand Down Expand Up @@ -271,6 +292,20 @@ bool SqlCase::ExtractSchema(const std::vector<std::string>& columns,
LOG(WARNING) << "Invalid Schema Format";
return false;
}
// expect the string of table column schema:
// 1. {name}<colon>{type}
// legacy specification, issues may arrise when dealing with type contains options
// 2. {name}<any space>{type}
// favored specification, which is exactly the same as SQL table column schema,
// you may nest any composed type like map in {type}

auto column_list = absl::StrJoin(columns, ",");
auto rs = plan::ParseTableColumSchema(column_list);
if (rs.ok()) {
table.mutable_columns()->CopyFrom(rs.value());
return true;
}
// fallback legacy approach
try {
for (auto col : columns) {
boost::trim(col);
Expand Down Expand Up @@ -409,13 +444,15 @@ bool SqlCase::AddInput(const TableInfo& table_data) {
return true;
}
bool SqlCase::ExtractInputData(std::vector<Row>& rows,
int32_t input_idx) const {
return ExtractInputData(inputs_[input_idx], rows);
int32_t input_idx,
const codec::Schema& sc) const {
return ExtractInputData(inputs_[input_idx], rows, sc);
}
bool SqlCase::ExtractInputData(const TableInfo& input,
std::vector<Row>& rows) const {
std::vector<Row>& rows,
const codec::Schema& sc) const {
try {
if (input.data_.empty() && input.rows_.empty()) {
if (input.data_.empty() && input.rows_.empty() && input.inserts_.empty()) {
LOG(WARNING) << "Empty Data String";
return false;
}
Expand All @@ -425,7 +462,20 @@ bool SqlCase::ExtractInputData(const TableInfo& input,
return false;
}

if (!input.data_.empty()) {
if (!input.inserts_.empty()) {
for (auto& sql : input.inserts_) {
// shit happens, resue jit cause duplcate symbols
// FIXME(#3748): use getJitWrapper
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(createJitWrapper());
auto rs = ExtractInsertRow(jit.get(), sql, &sc);
if (!rs.ok()) {
LOG(ERROR) << rs.status();
return false;
}

rows.insert(rows.end(), rs.value().begin(), rs.value().end());
}
} else if (!input.data_.empty()) {
if (!ExtractRows(table.columns(), input.data_, rows)) {
return false;
}
Expand Down Expand Up @@ -1713,5 +1763,48 @@ std::string SqlCase::SqlCaseBaseDir() {
}
return "";
}

absl::StatusOr<std::vector<codec::Row>> ExtractInsertRow(vm::HybridSeJitWrapper* jit, absl::string_view insert,
const codec::Schema* table_schema) {
zetasql::ParserOptions parser_opts;
zetasql::LanguageOptions language_opts;
language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE);
parser_opts.set_language_options(&language_opts);
std::unique_ptr<zetasql::ParserOutput> parser_output;
auto zetasql_status = zetasql::ParseStatement(insert, parser_opts, &parser_output);
CHECK_ABSL_STATUS(zetasql_status);

node::SqlNode* sql_node = nullptr;
node::NodeManager nm;
CHECK_STATUS_TO_ABSL(plan::ConvertStatement(parser_output->statement(), &nm, &sql_node));

auto* insert_stmt = sql_node->GetAsOrNull<node::InsertStmt>();
if (insert_stmt == nullptr) {
return absl::FailedPreconditionError("not a insert statement");
}

if (!insert_stmt->columns_.empty()) {
// implementation limitation
return absl::UnimplementedError("insert with custom columns not support");
}

codec::RowBuilder2 builder(jit, std::vector<codec::Schema>{*table_schema});

CHECK_STATUS_TO_ABSL(builder.Init());

std::vector<codec::Row> rows;
for (auto expr : insert_stmt->values_) {
auto expr_list = expr->GetAsOrNull<node::ExprListNode>();
if (expr_list == nullptr) {
return absl::FailedPreconditionError(
absl::Substitute("unexpected insert statement value: $0", expr->GetExprString()));
}
codec::Row row;
CHECK_STATUS_TO_ABSL(builder.Build(expr_list->children_, &row));
rows.push_back(row);
}

return rows;
}
} // namespace sqlcase
} // namespace hybridse
Loading

0 comments on commit 7a9c66b

Please sign in to comment.