Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support map data type in yaml testing framework #3765

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cases/query/udf_query.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,22 @@ cases:
sql: |
select c1 + 8 from (select 9 as c1)

- id: 18
mode: request-unsupport
inputs:
- name: t1
columns: ["col1 int32", "std_ts timestamp", "col2 map<int,string>"]
indexs: ["index1:col1:std_ts"]
# insert map values by insert stmts only
inserts:
- "insert into t1 values (1, timestamp(1000), map(12, 'abc'))"
- "insert into t1 values (2, timestamp(1000), map(13, 'abc'))"
sql: |
select col1, col2[12] as out from t1;
expect:
order: col1
columns: ["col1 int", "out string"]
data: |
1, abc
2, null

4 changes: 1 addition & 3 deletions hybridse/examples/toydb/src/cmd/toydb_run_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>

#include "absl/strings/match.h"
#include "testing/toydb_engine_test_base.h"
Expand Down Expand Up @@ -104,8 +103,7 @@ int RunSingle(const std::string& yaml_path) {

int main(int argc, char** argv) {
::google::ParseCommandLineFlags(&argc, &argv, false);
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
::hybridse::vm::Engine::InitializeGlobalLLVM();
if (FLAGS_yaml_path != "") {
return ::hybridse::vm::RunSingle(FLAGS_yaml_path);
} else {
Expand Down
5 changes: 2 additions & 3 deletions hybridse/examples/toydb/src/testing/toydb_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

#include "gtest/gtest.h"
#include "gtest/internal/gtest-param-util.h"
#include "vm/engine.h"
#include "testing/toydb_engine_test_base.h"

using namespace llvm; // NOLINT (build/namespaces)
Expand Down Expand Up @@ -126,8 +126,7 @@ TEST_P(BatchRequestEngineTest, TestClusterBatchRequestEngine) {
} // namespace hybridse

int main(int argc, char** argv) {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
::hybridse::vm::Engine::InitializeGlobalLLVM();
::testing::InitGoogleTest(&argc, argv);
// ::hybridse::vm::CoreAPI::EnableSignalTraceback();
return RUN_ALL_TESTS();
Expand Down
27 changes: 27 additions & 0 deletions hybridse/examples/toydb/src/testing/toydb_engine_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "case/case_data_mock.h"
#include "case/sql_case.h"
#include "glog/logging.h"
Expand Down Expand Up @@ -98,6 +99,14 @@ class ToydbBatchEngineTestRunner : public BatchEngineTestRunner {
CheckSqliteCompatible(sql_case_, GetSession()->GetSchema(), output_rows);
}
}
// Looks up the table stored in name_table_map_ under the (db, table) pair and returns a
// pointer to its table definition, or nullptr when no entry exists for that pair.
const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override {
    const auto key = std::make_pair(std::string(db), std::string(table));
    const auto entry = name_table_map_.find(key);
    return entry != name_table_map_.end() ? &entry->second->GetTableDef() : nullptr;
}

private:
std::shared_ptr<tablet::TabletCatalog> catalog_;
Expand Down Expand Up @@ -143,6 +152,15 @@ class ToydbRequestEngineTestRunner : public RequestEngineTestRunner {
return table->Put(reinterpret_cast<char*>(row.buf()), row.size());
}

// Looks up the table stored in name_table_map_ under the (db, table) pair and returns a
// pointer to its table definition, or nullptr when no entry exists for that pair.
const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override {
    const auto key = std::make_pair(std::string(db), std::string(table));
    const auto entry = name_table_map_.find(key);
    return entry != name_table_map_.end() ? &entry->second->GetTableDef() : nullptr;
}

private:
std::shared_ptr<tablet::TabletCatalog> catalog_;
std::map<std::pair<std::string, std::string>, std::shared_ptr<::hybridse::storage::Table>> name_table_map_;
Expand Down Expand Up @@ -189,6 +207,15 @@ class ToydbBatchRequestEngineTestRunner : public BatchRequestEngineTestRunner {
return table->Put(reinterpret_cast<char*>(row.buf()), row.size());
}

// Looks up the table stored in name_table_map_ under the (db, table) pair and returns a
// pointer to its table definition, or nullptr when no entry exists for that pair.
const type::TableDef* GetTableDef(absl::string_view db, absl::string_view table) override {
    const auto key = std::make_pair(std::string(db), std::string(table));
    const auto entry = name_table_map_.find(key);
    return entry != name_table_map_.end() ? &entry->second->GetTableDef() : nullptr;
}

private:
std::shared_ptr<tablet::TabletCatalog> catalog_;
std::map<std::pair<std::string, std::string>, std::shared_ptr<::hybridse::storage::Table>> name_table_map_;
Expand Down
12 changes: 8 additions & 4 deletions hybridse/include/case/sql_case.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ class SqlCase {
const codec::Schema ExtractParameterTypes() const;

bool ExtractInputData(std::vector<hybridse::codec::Row>& rows, // NOLINT
int32_t input_idx = 0) const;
bool ExtractInputData(
const TableInfo& info,
std::vector<hybridse::codec::Row>& rows) const; // NOLINT
int32_t input_idx, const codec::Schema& sc) const;
bool ExtractInputData(const TableInfo& info,
std::vector<hybridse::codec::Row>& rows, // NOLINT
const codec::Schema&) const;
bool ExtractOutputData(
std::vector<hybridse::codec::Row>& rows) const; // NOLINT

Expand Down Expand Up @@ -331,6 +331,10 @@ std::vector<SqlCase> InitCases(std::string yaml_path, std::vector<std::string> f
void InitCases(std::string yaml_path, std::vector<SqlCase>& cases); // NOLINT
void InitCases(std::string yaml_path, std::vector<SqlCase>& cases, // NOLINT
const std::vector<std::string>& filters);

// TODO(someone): consider moving this function to production code so others can make use of it.
absl::StatusOr<std::vector<codec::Row>> ExtractInsertRow(vm::HybridSeJitWrapper*, absl::string_view insert,
const codec::Schema* table_schema);
} // namespace sqlcase
} // namespace hybridse
#endif // HYBRIDSE_INCLUDE_CASE_SQL_CASE_H_
27 changes: 25 additions & 2 deletions hybridse/include/node/sql_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,22 @@ class SqlNode : public NodeBase<SqlNode> {

bool Equals(const SqlNode *node) const override;

// Return this node cast as a NodeType.
// Use only when this node is known to be that type, otherwise, behavior is undefined.
template <typename NodeType>
const NodeType *GetAsOrNull() const {
static_assert(std::is_base_of<SqlNode, NodeType>::value,
"NodeType must be a member of the SqlNode class hierarchy");
return dynamic_cast<const NodeType *>(this);
}

template <typename NodeType>
NodeType *GetAsOrNull() {
static_assert(std::is_base_of<SqlNode, NodeType>::value,
"NodeType must be a member of the SqlNode class hierarchy");
return dynamic_cast<NodeType *>(this);
}

SqlNodeType type_;

private:
Expand Down Expand Up @@ -1883,6 +1899,8 @@ class ColumnSchemaNode : public SqlNode {
bool not_null() const { return not_null_; }
const ExprNode *default_value() const { return default_value_; }

absl::Status GetProtoColumnSchema(type::ColumnSchema *) const;

std::string DebugString() const;

private:
Expand All @@ -1902,6 +1920,8 @@ class ColumnDefNode : public SqlNode {

const ColumnSchemaNode *schema() const { return schema_; }

absl::Status GetProtoColumnDef(type::ColumnDef *) const;

// deprecated, use ColumnDefNode::schema instead
DataType GetColumnType() const { return schema_->type(); }

Expand Down Expand Up @@ -2025,8 +2045,11 @@ class CreateStmt : public SqlNode {

~CreateStmt() {}

NodePointVector* MutableColumnDefList() { return &column_desc_list_; }
const NodePointVector &GetColumnDefList() const { return column_desc_list_; }
NodePointVector* MutableTableElementList() { return &column_desc_list_; }
const NodePointVector &GetTableElementList() const { return column_desc_list_; }

// collect the column definitions in column_desc_list_, and convert to the proto representation type::ColumnDef
absl::StatusOr<codec::Schema> GetColumnDefListAsSchema() const;

std::string GetTableName() const { return table_name_; }
std::string GetDbName() const { return db_name_; }
Expand Down
18 changes: 18 additions & 0 deletions hybridse/include/plan/plan_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ using hybridse::base::Status;
using hybridse::node::NodeManager;
using hybridse::node::NodePointVector;
using hybridse::node::PlanNodeList;

// TODO(someone): rm class PlanAPI
class PlanAPI {
public:
// parse SQL string to logic plan. ASTNode and LogicNode saved in SqlContext
Expand All @@ -47,6 +49,22 @@ class PlanAPI {
static const std::string GenerateName(const std::string prefix, int id);
};

// Parse the input string as a SQL type and convert it to the TypeNode representation
//
// unimplemented, reserved for later usage
absl::StatusOr<node::TypeNode*> ParseType(absl::string_view, NodeManager*);

// Parse the input string as table elements and extract those elements that are table_column_definition,
// then return the corresponding proto representation.
//
// It expects the input `str` to have its elements joined by commas (,); a CREATE TABLE SQL is then created
// in the format 'CREATE TABLE t1 ( {str} )'.
// The SQL parser allows three kinds of table element:
// - table_column_definition
// - table_index_definition
// - table_constraint_definition
// while this method extracts table_column_definition only.
absl::StatusOr<codec::Schema> ParseTableColumSchema(absl::string_view str);
} // namespace plan
} // namespace hybridse
#endif // HYBRIDSE_INCLUDE_PLAN_PLAN_API_H_
107 changes: 100 additions & 7 deletions hybridse/src/case/sql_case.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

#include "case/sql_case.h"

#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/ascii.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
Expand All @@ -34,7 +34,11 @@
#include "codec/fe_row_codec.h"
#include "glog/logging.h"
#include "node/sql_node.h"
#include "yaml-cpp/yaml.h"
#include "plan/plan_api.h"
#include "vm/engine.h"
#include "zetasql/parser/parser.h"
#include "planv2/ast_node_converter.h"
#include "vm/jit_wrapper.h"

namespace hybridse {
namespace sqlcase {
Expand All @@ -44,6 +48,23 @@ static absl::Mutex mtx;
// working directory, where the yaml file lives
static std::filesystem::path working_dir;

// Creates a new JIT wrapper with the default symbols, after making sure the global LLVM
// state has been initialized. If creation reports a non-OK status, the error is logged
// and the process exits with code 1, so a returned pointer always comes from a
// successful creation. The pointer is returned raw; the caller decides how to manage it.
static vm::HybridSeJitWrapper* createJitWrapper() {
    hybridse::vm::Engine::InitializeGlobalLLVM();

    base::Status status;
    vm::HybridSeJitWrapper* jit = vm::HybridSeJitWrapper::CreateWithDefaultSymbols(&status);
    if (status.isOK()) {
        return jit;
    }

    LOG(ERROR) << "fail to get jit: " << status;
    std::exit(1);
}

// Returns a shared JIT wrapper. The instance is held in a function-local static, so
// createJitWrapper() runs only on the first call and later calls reuse the same pointer.
static vm::HybridSeJitWrapper* getJitWrapper() {
    static vm::HybridSeJitWrapper* const instance = createJitWrapper();
    return instance;
}

bool SqlCase::TTLParse(const std::string& org_type_str,
std::vector<int64_t>& ttls) {
std::string type_str = org_type_str;
Expand Down Expand Up @@ -271,6 +292,20 @@ bool SqlCase::ExtractSchema(const std::vector<std::string>& columns,
LOG(WARNING) << "Invalid Schema Format";
return false;
}
// expect the string of a table column schema in one of two forms:
// 1. {name}<colon>{type}
// legacy specification, issues may arise when dealing with types that contain options
// 2. {name}<any space>{type}
// favored specification, which is exactly the same as a SQL table column schema,
// you may nest any composed type like map in {type}

auto column_list = absl::StrJoin(columns, ",");
auto rs = plan::ParseTableColumSchema(column_list);
if (rs.ok()) {
table.mutable_columns()->CopyFrom(rs.value());
return true;
}
// fallback legacy approach
try {
for (auto col : columns) {
boost::trim(col);
Expand Down Expand Up @@ -409,13 +444,15 @@ bool SqlCase::AddInput(const TableInfo& table_data) {
return true;
}
bool SqlCase::ExtractInputData(std::vector<Row>& rows,
int32_t input_idx) const {
return ExtractInputData(inputs_[input_idx], rows);
int32_t input_idx,
const codec::Schema& sc) const {
return ExtractInputData(inputs_[input_idx], rows, sc);
}
bool SqlCase::ExtractInputData(const TableInfo& input,
std::vector<Row>& rows) const {
std::vector<Row>& rows,
const codec::Schema& sc) const {
try {
if (input.data_.empty() && input.rows_.empty()) {
if (input.data_.empty() && input.rows_.empty() && input.inserts_.empty()) {
LOG(WARNING) << "Empty Data String";
return false;
}
Expand All @@ -425,7 +462,20 @@ bool SqlCase::ExtractInputData(const TableInfo& input,
return false;
}

if (!input.data_.empty()) {
if (!input.inserts_.empty()) {
for (auto& sql : input.inserts_) {
// reusing a single jit instance causes duplicate symbols, so create a fresh one per statement
// FIXME(#3748): use getJitWrapper
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(createJitWrapper());
auto rs = ExtractInsertRow(jit.get(), sql, &sc);
if (!rs.ok()) {
LOG(ERROR) << rs.status();
return false;
}

rows.insert(rows.end(), rs.value().begin(), rs.value().end());
}
} else if (!input.data_.empty()) {
if (!ExtractRows(table.columns(), input.data_, rows)) {
return false;
}
Expand Down Expand Up @@ -1713,5 +1763,48 @@ std::string SqlCase::SqlCaseBaseDir() {
}
return "";
}

// Parses `insert` as a single SQL statement, which must be an INSERT without an explicit
// column list, and encodes each of its value lists into a row using `table_schema`.
//
// jit:          JIT wrapper handed to codec::RowBuilder2, which builds the rows.
// insert:       SQL text of the insert statement.
// table_schema: schema of the target table; must not be null.
//
// Returns the encoded rows in the order the value lists appear in the statement.
// Returns an error status when `table_schema` is null, when parsing or conversion of the
// statement fails, when the statement is not an INSERT, when it names explicit columns
// (not supported by this implementation), when a value entry is not an expression list,
// or when the row builder fails to initialize or to build a row.
absl::StatusOr<std::vector<codec::Row>> ExtractInsertRow(vm::HybridSeJitWrapper* jit, absl::string_view insert,
                                                         const codec::Schema* table_schema) {
    // The schema is dereferenced below; reject a null pointer instead of crashing.
    if (table_schema == nullptr) {
        return absl::InvalidArgumentError("table schema is null");
    }

    zetasql::ParserOptions parser_opts;
    zetasql::LanguageOptions language_opts;
    language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE);
    parser_opts.set_language_options(&language_opts);
    std::unique_ptr<zetasql::ParserOutput> parser_output;
    auto zetasql_status = zetasql::ParseStatement(insert, parser_opts, &parser_output);
    CHECK_ABSL_STATUS(zetasql_status);

    // Nodes produced by the conversion are handed out through `nm`, so it is kept
    // alive for as long as `sql_node` and its children are used below.
    node::SqlNode* sql_node = nullptr;
    node::NodeManager nm;
    CHECK_STATUS_TO_ABSL(plan::ConvertStatement(parser_output->statement(), &nm, &sql_node));
    if (sql_node == nullptr) {
        return absl::FailedPreconditionError("statement conversion produced no node");
    }

    auto* insert_stmt = sql_node->GetAsOrNull<node::InsertStmt>();
    if (insert_stmt == nullptr) {
        return absl::FailedPreconditionError("not a insert statement");
    }

    if (!insert_stmt->columns_.empty()) {
        // implementation limitation
        return absl::UnimplementedError("insert with custom columns not support");
    }

    codec::RowBuilder2 builder(jit, std::vector<codec::Schema>{*table_schema});

    CHECK_STATUS_TO_ABSL(builder.Init());

    std::vector<codec::Row> rows;
    for (auto expr : insert_stmt->values_) {
        if (expr == nullptr) {
            return absl::FailedPreconditionError("unexpected null value in insert statement");
        }
        auto expr_list = expr->GetAsOrNull<node::ExprListNode>();
        if (expr_list == nullptr) {
            return absl::FailedPreconditionError(
                absl::Substitute("unexpected insert statement value: $0", expr->GetExprString()));
        }
        // Build directly into the output vector to avoid copying the finished row.
        // On failure the whole vector is discarded by the early return.
        rows.emplace_back();
        CHECK_STATUS_TO_ABSL(builder.Build(expr_list->children_, &rows.back()));
    }

    return rows;
}
} // namespace sqlcase
} // namespace hybridse
Loading
Loading