diff --git a/cases/function/join/test_lastjoin_simple.yaml b/cases/function/join/test_lastjoin_simple.yaml index 4d23b312ef2..589e98bd05b 100644 --- a/cases/function/join/test_lastjoin_simple.yaml +++ b/cases/function/join/test_lastjoin_simple.yaml @@ -1067,4 +1067,4 @@ cases: rows: - [ "aa",2,131,1590738990000 ] - [ "bb",21,NULL,NULL ] - - [ "dd", 41, NULL, NULL ] \ No newline at end of file + - [ "dd", 41, NULL, NULL ] diff --git a/cases/plan/back_quote_identifier.yaml b/cases/plan/back_quote_identifier.yaml index cafce9e5b2d..4743634c370 100644 --- a/cases/plan/back_quote_identifier.yaml +++ b/cases/plan/back_quote_identifier.yaml @@ -131,12 +131,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: a-1 | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: b-1 | | +-column_type: string - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [a-1, b-1] diff --git a/cases/plan/create.yaml b/cases/plan/create.yaml index 66bb1ee548c..6210401ee9d 100644 --- a/cases/plan/create.yaml +++ b/cases/plan/create.yaml @@ -163,12 +163,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: a | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: b | | +-column_type: string - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [a, b] @@ -218,12 +216,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: a | | +-column_type: int16 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: b | | +-column_type: float - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [a] @@ -274,12 +270,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: a | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: b | | +-column_type: timestamp - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [a] @@ -627,12 +621,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: a | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: b | | +-column_type: timestamp - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [a] @@ -685,33 +677,27 @@ cases: | +-0: | | +-node[kColumnDesc] | | +-column_name: column1 - | | +-column_type: int32 - | | +-NOT NULL: 1 + | | +-column_type: int32 NOT NULL | +-1: | | +-node[kColumnDesc] | | +-column_name: column2 - | | +-column_type: int16 - | | +-NOT NULL: 1 + | | +-column_type: int16 NOT NULL | +-2: | | +-node[kColumnDesc] | | +-column_name: column5 - | | +-column_type: string - | | +-NOT NULL: 1 + | | +-column_type: string NOT NULL | +-3: | | +-node[kColumnDesc] | | +-column_name: column6 - | | +-column_type: string - | | +-NOT NULL: 1 + | | +-column_type: string NOT NULL | +-4: | | +-node[kColumnDesc] | | +-column_name: std_ts - | | +-column_type: timestamp - | | +-NOT NULL: 1 + | | +-column_type: timestamp NOT NULL | +-5: | | +-node[kColumnDesc] | | +-column_name: std_date - | | +-column_type: date - | | +-NOT NULL: 1 + | | +-column_type: date NOT NULL | +-6: | +-node[kColumnIndex] | +-keys: [column2] @@ -743,33 +729,27 @@ cases: | +-0: | | +-node[kColumnDesc] | | +-column_name: column1 - | | +-column_type: int32 - | | +-NOT NULL: 1 + | | +-column_type: int32 NOT NULL | +-1: | | +-node[kColumnDesc] | | +-column_name: column2 - | | +-column_type: int16 - | | +-NOT NULL: 1 + | | +-column_type: int16 NOT NULL | +-2: | | +-node[kColumnDesc] | | +-column_name: column5 - | | +-column_type: string - | | +-NOT NULL: 1 + | | +-column_type: string NOT NULL | +-3: | | +-node[kColumnDesc] | | +-column_name: column6 - | | +-column_type: string - | | +-NOT NULL: 1 + | | +-column_type: string NOT NULL | +-4: | | +-node[kColumnDesc] | | +-column_name: std_ts - | | +-column_type: timestamp - | | +-NOT NULL: 1 + | | +-column_type: timestamp NOT NULL | +-5: | | +-node[kColumnDesc] | | +-column_name: std_date - | | +-column_type: date - | | +-NOT NULL: 1 + | | +-column_type: date NOT NULL | +-6: | +-node[kColumnIndex] | +-keys: [column2] @@ -796,17 +776,11 @@ cases: | +-0: | | +-node[kColumnDesc] | | +-column_name: column1 - | | +-column_type: int32 - | | +-NOT NULL: 0 - | | +-default_value: - | | +-expr[primary] - | | +-value: 1 - | | +-type: int32 + | | +-column_type: int32 DEFAULT 1 | +-1: | +-node[kColumnDesc] | +-column_name: column2 | +-column_type: int32 - | +-NOT NULL: 0 +-table_option_list: [] - id: 27 desc: Column default value with explicit type @@ -824,20 +798,11 @@ cases: | +-0: | | +-node[kColumnDesc] | | +-column_name: column1 - | | +-column_type: string - | | +-NOT NULL: 0 - | | +-default_value: - | | +-expr[cast] - | | +-cast_type: string - | | +-expr: - | | +-expr[primary] - | | +-value: 1 - | | +-type: int32 + | | +-column_type: string DEFAULT string(1) | +-1: | +-node[kColumnDesc] | +-column_name: column3 | +-column_type: int32 - | +-NOT NULL: 0 +-table_option_list: [] - id: 28 desc: Create table with database.table @@ -856,12 +821,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: column1 | | +-column_type: string - | | +-NOT NULL: 0 | +-1: | +-node[kColumnDesc] | +-column_name: column3 | +-column_type: int32 - | +-NOT NULL: 0 +-table_option_list: [] - id: 29 desc: create index with db name prefix @@ -898,12 +861,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: column1 | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: column2 | | +-column_type: timestamp - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [column1] @@ -934,12 +895,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: a | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: b | | +-column_type: timestamp - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [a] @@ -1049,12 +1008,10 @@ cases: | | +-node[kColumnDesc] | | +-column_name: column1 | | +-column_type: int32 - | | +-NOT NULL: 0 | +-1: | | +-node[kColumnDesc] | | +-column_name: column2 | | +-column_type: timestamp - | | +-NOT NULL: 0 | +-2: | +-node[kColumnIndex] | +-keys: [column1] @@ -1068,3 +1025,45 @@ cases: +-0: +-node[kCompressType] +-compress_type: snappy + - id: 35 + desc: Create table with array & map type + sql: | + create table t1 (id int, + member ARRAY NOT NULL, + attrs MAP NOT NULL); + expect: + node_tree_str: | + +-node[CREATE] + +-table: t1 + +-IF NOT EXIST: 0 + +-column_desc_list[list]: + | +-0: + | | +-node[kColumnDesc] + | | +-column_name: id + | | +-column_type: int32 + | +-1: + | | +-node[kColumnDesc] + | | +-column_name: member + | | +-column_type: array NOT NULL + | +-2: + | +-node[kColumnDesc] + | +-column_name: attrs + | +-column_type: map NOT NULL + +-table_option_list: [] + plan_tree_str: | + +-[kCreatePlan] + +-table: t1 + +-column_desc_list[list]: + | +-0: + | | +-node[kColumnDesc] + | | +-column_name: id + | | +-column_type: int32 + | +-1: + | | +-node[kColumnDesc] + | | +-column_name: member + | | +-column_type: array NOT NULL + | +-2: + | +-node[kColumnDesc] + | +-column_name: attrs + | +-column_type: map NOT NULL + +-table_option_list: [] diff --git a/cases/plan/simple_query.yaml b/cases/plan/simple_query.yaml index 66cc542fbc0..f7a439cc011 100644 --- a/cases/plan/simple_query.yaml +++ b/cases/plan/simple_query.yaml @@ -644,3 +644,35 @@ cases: +-[kTablePlan] +-table: t +-alias: t1 + + - id: map_data_type + desc: access map value with []operator + sql: | + select map(1, 2)[1] + expect: + node_tree_str: | + +-node[kQuery]: kQuerySelect + +-distinct_opt: false + +-where_expr: null + +-group_expr_list: null + +-having_expr: null + +-order_expr_list: null + +-limit: null + +-select_list[list]: + | +-0: + | +-node[kResTarget] + | +-val: + | | map(1, 2)[1] + | +-name: + +-tableref_list: [] + +-window_list: [] + plan_tree_str: | + +-[kQueryPlan] + +-[kProjectPlan] + +-table: + +-project_list_vec[list]: + +-[kProjectList] + +-projects on table [list]: + +-[kProjectNode] + +-[0]map(1, 2)[1]: map(1, 2)[1] + null diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index ded80e003ce..218e791ab5a 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -554,3 +554,35 @@ cases: - c1 bool data: | true, false + + # ================================================================ + # Map data type + # ================================================================ + - id: 13 + sql: | + select + map(1, "2")[1] as e1, + map("abc", 100)["abc"] as e2, + map(1, "2", 3, "4")[5] as e3, + map("c", 99, "d", 101)["d"] as e4, + map(date("2012-12-12"), "e", date("2013-11-11"), "f", date("2014-10-10"), "g")[date("2013-11-11")] as e5, + map(timestamp(88), timestamp(1000), timestamp(99), timestamp(2000)) [timestamp(99)] as e6, + map('1', 2, '3', 4, '5', 6, '7', 8, '9', 10, '11', 12)['9'] as e7, + map('1', 2, '3', 4, '5', 6, '7', 8, '9', 10, '11', 12)['10'] as e8, + # first match on duplicate keys + map('1', 2, '1', 4, '1', 6, '7', 8, '9', 10, '11', 12)['1'] as e9, + map("c", 99, "d", NULL)["d"] as e10, + expect: + columns: ["e1 string", "e2 int", "e3 string", "e4 int", "e5 string", "e6 timestamp", "e7 int", "e8 int", "e9 int", "e10 int"] + data: | + 2, 100, NULL, 101, f, 2000, 10, NULL, 2, NULL + - id: 14 + sql: | + select + array_contains(map_keys(map(1, '2', 3, '4')), 1) as e1, + array_contains(map_keys(map('1', 2, '3', 4)), '2') as e2, + array_contains(map_keys(map(timestamp(88), timestamp(1000), timestamp(99), timestamp(2000))) , timestamp(99)) as e3, + expect: + columns: ["e1 bool", "e2 bool", "e3 bool"] + data: | + true, false, true diff --git a/hybridse/include/base/fe_status.h b/hybridse/include/base/fe_status.h index b91b8d8fb16..8f11a16a8c8 100644 --- a/hybridse/include/base/fe_status.h +++ b/hybridse/include/base/fe_status.h @@ -16,11 +16,12 @@ #ifndef HYBRIDSE_INCLUDE_BASE_FE_STATUS_H_ #define HYBRIDSE_INCLUDE_BASE_FE_STATUS_H_ + +#include #include #include -#include "glog/logging.h" + #include "proto/fe_common.pb.h" -#include "proto/fe_type.pb.h" namespace hybridse { namespace base { diff --git a/hybridse/include/node/expr_node.h b/hybridse/include/node/expr_node.h index 442064b6873..490e4d48c28 100644 --- a/hybridse/include/node/expr_node.h +++ b/hybridse/include/node/expr_node.h @@ -18,7 +18,6 @@ #define HYBRIDSE_INCLUDE_NODE_EXPR_NODE_H_ #include -#include #include "base/fe_status.h" #include "codec/fe_row_codec.h" diff --git a/hybridse/include/node/node_base.h b/hybridse/include/node/node_base.h index 8aa678c90a8..c6894f2b682 100644 --- a/hybridse/include/node/node_base.h +++ b/hybridse/include/node/node_base.h @@ -22,7 +22,6 @@ #include #include "base/fe_object.h" -#include "glog/logging.h" #include "node/node_enum.h" namespace hybridse { diff --git a/hybridse/include/node/node_enum.h b/hybridse/include/node/node_enum.h index 7c9ebf0ecbe..38d8336258f 100644 --- a/hybridse/include/node/node_enum.h +++ b/hybridse/include/node/node_enum.h @@ -17,9 +17,6 @@ #ifndef HYBRIDSE_INCLUDE_NODE_NODE_ENUM_H_ #define HYBRIDSE_INCLUDE_NODE_NODE_ENUM_H_ -#include -#include "proto/fe_common.pb.h" -#include "proto/fe_type.pb.h" namespace hybridse { namespace node { @@ -98,6 +95,7 @@ enum SqlNodeType { kAlterTableStmt, kShowStmt, kCompressType, + kColumnSchema, kSqlNodeTypeLast, // debug type }; @@ -143,7 +141,8 @@ enum ExprType { kExprIn, kExprEscaped, kExprArray, - kExprFake, // not a real one + kExprArrayElement, // extract value from a array or map, with `[]` operator + kExprFake, // not a real one kExprLast = kExprFake, }; @@ -175,9 +174,21 @@ enum DataType { kArray, // fixed size. In SQL: [1, 2, 3] or ARRAY[1, 2, 3] kDataTypeFake, // not a data type, for testing purpose only kLastDataType = kDataTypeFake, + // the tree type are not moved above kLastDataType for compatibility // it may necessary to do it in the further + + // kVoid + // A distinct data type: signifies no value or meaningful result. + // Typically used for function that does not returns value. kVoid = 100, + // kNull + // A special marker representing the absence of a value. + // Not a true data type but a placeholder for missing or unknown information. + // A `NULL` literal can be eventually resolved to: + // - NULL of void type, if no extra info provided: 'SELECT NULL' + // - NULL of int (or any other) type, extra information provided, e.g with 'CAST' operator + // 'SELECT CAST(NULL as INT)' kNull = 101, kPlaceholder = 102 }; diff --git a/hybridse/include/node/node_manager.h b/hybridse/include/node/node_manager.h index 6949faf6f88..de0bedfe14b 100644 --- a/hybridse/include/node/node_manager.h +++ b/hybridse/include/node/node_manager.h @@ -21,7 +21,6 @@ #ifndef HYBRIDSE_INCLUDE_NODE_NODE_MANAGER_H_ #define HYBRIDSE_INCLUDE_NODE_NODE_MANAGER_H_ -#include #include #include #include @@ -172,10 +171,6 @@ class NodeManager { const std::string &table_name, SqlNodeList *column_desc_list, SqlNodeList *partition_meta_list); - SqlNode *MakeColumnDescNode(const std::string &column_name, - const DataType data_type, - bool op_not_null, - ExprNode* default_value = nullptr); SqlNode *MakeColumnIndexNode(SqlNodeList *keys, SqlNode *ts, SqlNode *ttl, SqlNode *version); SqlNode *MakeColumnIndexNode(SqlNodeList *index_item_list); diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 8d641ad8283..f4070773b12 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -450,9 +450,7 @@ class ExprNode : public SqlNode { uint32_t GetChildNum() const { return children_.size(); } const ExprType GetExprType() const { return expr_type_; } - void PushBack(ExprNode *node_ptr) { children_.push_back(node_ptr); } - std::vector children_; void Print(std::ostream &output, const std::string &org_tab) const override; virtual const std::string GetExprString() const; virtual const std::string GenerateExpressionName() const; @@ -542,6 +540,8 @@ class ExprNode : public SqlNode { static Status RlikeTypeAccept(node::NodeManager* nm, const TypeNode* lhs, const TypeNode* rhs, const TypeNode** output); + std::vector children_; + private: const TypeNode *output_type_ = nullptr; bool nullable_ = true; @@ -570,10 +570,26 @@ class ArrayExpr : public ExprNode { Status InferAttr(ExprAnalysisContext *ctx) override; - // array type may specific already in SQL, e.g. ARRAY[1,2,3] + // array type may specified type in SQL already, e.g. ARRAY[1,2,3] TypeNode* specific_type_ = nullptr; }; +// extract value from array or map value, using '[]' operator +class ArrayElementExpr : public ExprNode { + public: + ArrayElementExpr(ExprNode *array, ExprNode *pos) ABSL_ATTRIBUTE_NONNULL(); + ~ArrayElementExpr() override {} + + ExprNode *array() const; + ExprNode *position() const; + + void Print(std::ostream &output, const std::string &org_tab) const override; + const std::string GetExprString() const override; + ArrayElementExpr *ShadowCopy(NodeManager *nm) const override; + + Status InferAttr(ExprAnalysisContext *ctx) override; +}; + class FnNode : public SqlNode { public: FnNode() : SqlNode(kFn, 0, 0), indent(0) {} @@ -1836,31 +1852,53 @@ class ResTarget : public SqlNode { NodePointVector indirection_; /* subscripts, field names, and '*', or NIL */ }; +class ColumnSchemaNode : public SqlNode { + public: + ColumnSchemaNode(DataType type, bool attr_not_null, const ExprNode *default_val) + : SqlNode(kColumnSchema, 0, 0), type_(type), not_null_(attr_not_null), default_value_(default_val) {} + + ColumnSchemaNode(DataType type, absl::Span generics, bool attr_not_null, + const ExprNode *default_val) + : SqlNode(kColumnSchema, 0, 0), + type_(type), + generics_(generics.begin(), generics.end()), + not_null_(attr_not_null), + default_value_(default_val) {} + ~ColumnSchemaNode() override {} + + DataType type() const { return type_; } + absl::Span generics() const { return generics_; } + bool not_null() const { return not_null_; } + const ExprNode *default_value() const { return default_value_; } + + std::string DebugString() const; + + private: + DataType type_; + std::vector generics_; + bool not_null_; + const ExprNode* default_value_ = nullptr; +}; + class ColumnDefNode : public SqlNode { public: - ColumnDefNode() : SqlNode(kColumnDesc, 0, 0), column_name_(""), column_type_() {} - ColumnDefNode(const std::string &name, const DataType &data_type, bool op_not_null, ExprNode *default_value) - : SqlNode(kColumnDesc, 0, 0), - column_name_(name), - column_type_(data_type), - op_not_null_(op_not_null), - default_value_(default_value) {} + ColumnDefNode(const std::string &name, const ColumnSchemaNode *schema) + : SqlNode(kColumnDesc, 0, 0), column_name_(name), schema_(schema) {} ~ColumnDefNode() {} std::string GetColumnName() const { return column_name_; } - DataType GetColumnType() const { return column_type_; } + DataType GetColumnType() const { return schema_->type(); } - ExprNode* GetDefaultValue() const { return default_value_; } + const ExprNode* GetDefaultValue() const { return schema_->default_value(); } + + bool GetIsNotNull() const { return schema_->not_null(); } - bool GetIsNotNull() const { return op_not_null_; } void Print(std::ostream &output, const std::string &org_tab) const; private: std::string column_name_; - DataType column_type_; - bool op_not_null_; - ExprNode* default_value_ = nullptr; + const ColumnSchemaNode* schema_; }; class InsertStmt : public SqlNode { diff --git a/hybridse/include/node/type_node.h b/hybridse/include/node/type_node.h index e27ef34ce46..110b6329e59 100644 --- a/hybridse/include/node/type_node.h +++ b/hybridse/include/node/type_node.h @@ -21,6 +21,7 @@ #include #include "codec/fe_row_codec.h" +#include "node/expr_node.h" #include "node/sql_node.h" #include "vm/schemas_context.h" @@ -31,7 +32,7 @@ class NodeManager; class TypeNode : public SqlNode { public: - TypeNode() : SqlNode(node::kType, 0, 0), base_(hybridse::node::kNull) {} + TypeNode() : SqlNode(node::kType, 0, 0), base_(hybridse::node::kVoid) {} explicit TypeNode(hybridse::node::DataType base) : SqlNode(node::kType, 0, 0), base_(base), generics_({}) {} explicit TypeNode(hybridse::node::DataType base, const TypeNode *v1) @@ -48,44 +49,44 @@ class TypeNode : public SqlNode { generics_nullable_({false, false}) {} ~TypeNode() override {} - friend bool operator==(const TypeNode& lhs, const TypeNode& rhs); + friend bool operator==(const TypeNode &lhs, const TypeNode &rhs); + + // Return this node cast as a NodeType. + // Use only when this node is known to be that type, otherwise, behavior is undefined. + template + const NodeType *GetAsOrNull() const { + static_assert(std::is_base_of::value, + "NodeType must be a member of the TypeNode class hierarchy"); + return dynamic_cast(this); + } + + template + NodeType *GetAsOrNull() { + static_assert(std::is_base_of::value, + "NodeType must be a member of the TypeNode class hierarchy"); + return dynamic_cast(this); + } // canonical name for the type // this affect the function generated by codegen - virtual const std::string GetName() const { - std::string type_name = DataTypeName(base_); - if (!generics_.empty()) { - for (auto type : generics_) { - type_name.append("_"); - type_name.append(type->GetName()); - } - } - return type_name; - } + virtual const std::string GetName() const; // readable string representation virtual std::string DebugString() const; - const hybridse::node::TypeNode *GetGenericType(size_t idx) const { - return generics_[idx]; - } + const hybridse::node::TypeNode *GetGenericType(size_t idx) const; bool IsGenericNullable(size_t idx) const { return generics_nullable_[idx]; } size_t GetGenericSize() const { return generics_.size(); } hybridse::node::DataType base() const { return base_; } - const std::vector &generics() const { - return generics_; - } + const std::vector &generics() const { return generics_; } - void AddGeneric(const node::TypeNode *dtype, bool nullable) { - generics_.push_back(dtype); - generics_nullable_.push_back(nullable); - } + void AddGeneric(const node::TypeNode *dtype, bool nullable); void Print(std::ostream &output, const std::string &org_tab) const override; - virtual bool Equals(const SqlNode *node) const; + bool Equals(const SqlNode *node) const override; TypeNode *ShadowCopy(NodeManager *) const override; TypeNode *DeepCopy(NodeManager *) const override; @@ -105,9 +106,22 @@ class TypeNode : public SqlNode { bool IsFloating() const; bool IsGeneric() const; + virtual bool IsMap() const { return false; } + virtual bool IsArray() const { return base_ == kArray; } + static Status CheckTypeNodeNotNull(const TypeNode *left_type); hybridse::node::DataType base_; + + // generics_ not empty if it is a complex data type: + // 1. base = ARRAY, generics = [ element_type ] + // 2. base = MAP, generics = [ key_type, value_type ] + // 3. base = STRUCT, generics = [ fileld_type, ... ] (unimplemented) + // inner types, not exists in SQL level + // 4. base = LIST, generics = [ element_type ] + // 5. base = ITERATOR, generics = [ element_type ] + // 6. base = TUPLE (like STRUCT), generics = [ element_type, ... ] + // 7. ... (might others, undocumented) std::vector generics_; std::vector generics_nullable_; }; @@ -120,9 +134,7 @@ class OpaqueTypeNode : public TypeNode { size_t bytes() const { return bytes_; } - const std::string GetName() const override { - return "opaque<" + std::to_string(bytes_) + ">"; - } + const std::string GetName() const override; OpaqueTypeNode *ShadowCopy(NodeManager *) const override; @@ -173,11 +185,28 @@ class FixedArrayType : public TypeNode { std::string DebugString() const override; FixedArrayType *ShadowCopy(NodeManager *) const override; + bool IsArray() const override { return true; } + private: const TypeNode* ele_ty_; uint64_t num_elements_; }; +class MapType : public TypeNode { + public: + MapType(const TypeNode *key_ty, const TypeNode *value_ty, bool value_not_null = false) ABSL_ATTRIBUTE_NONNULL(); + ~MapType() override; + + bool IsMap() const override { return true; } + + const TypeNode *key_type() const; + const TypeNode *value_type() const; + bool value_nullable() const; + + // test if input args can safely apply to a map function + static absl::StatusOr InferMapType(NodeManager *, absl::Span types); +}; + } // namespace node } // namespace hybridse #endif // HYBRIDSE_INCLUDE_NODE_TYPE_NODE_H_ diff --git a/hybridse/include/plan/plan_api.h b/hybridse/include/plan/plan_api.h index 0ad45f91f9f..1e4f3b74845 100644 --- a/hybridse/include/plan/plan_api.h +++ b/hybridse/include/plan/plan_api.h @@ -15,9 +15,13 @@ */ #ifndef HYBRIDSE_INCLUDE_PLAN_PLAN_API_H_ #define HYBRIDSE_INCLUDE_PLAN_PLAN_API_H_ + #include #include + #include "node/node_manager.h" +#include "vm/sql_ctx.h" + namespace hybridse { namespace plan { @@ -27,6 +31,10 @@ using hybridse::node::NodePointVector; using hybridse::node::PlanNodeList; class PlanAPI { public: + // parse SQL string to logic plan. ASTNode and LogicNode saved in SqlContext + static base::Status CreatePlanTreeFromScript(vm::SqlContext* ctx); + + // deprecated, use CreatePlanTreeFromScript(vm::SqlContext*) instead static bool CreatePlanTreeFromScript(const std::string& sql, PlanNodeList& plan_trees, // NOLINT NodeManager* node_manager, @@ -34,6 +42,7 @@ class PlanAPI { bool is_batch_mode = true, bool is_cluster = false, bool enable_batch_window_parallelization = false, const std::unordered_map* extra_options = nullptr); + static const int GetPlanLimitCount(node::PlanNode* plan_trees); static const std::string GenerateName(const std::string prefix, int id); }; diff --git a/hybridse/include/vm/sql_ctx.h b/hybridse/include/vm/sql_ctx.h new file mode 100644 index 00000000000..25182b86647 --- /dev/null +++ b/hybridse/include/vm/sql_ctx.h @@ -0,0 +1,91 @@ +/** + * Copyright (c) 2023 OpenMLDB Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_INCLUDE_VM_SQL_CTX_H_ +#define HYBRIDSE_INCLUDE_VM_SQL_CTX_H_ + +#include +#include +#include + +#include "node/node_manager.h" +#include "vm/engine_context.h" + +namespace zetasql { +class ParserOutput; +} + +namespace hybridse { +namespace vm { + +class HybridSeJitWrapper; +class ClusterJob; + +struct SqlContext { + // mode: batch|request|batch request + ::hybridse::vm::EngineMode engine_mode; + bool is_cluster_optimized = false; + bool is_batch_request_optimized = false; + bool enable_expr_optimize = false; + bool enable_batch_window_parallelization = true; + bool enable_window_column_pruning = false; + + // the sql content + std::string sql; + // the database + std::string db; + + std::unique_ptr ast_node; + // the logical plan + ::hybridse::node::PlanNodeList logical_plan; + ::hybridse::vm::PhysicalOpNode* physical_plan = nullptr; + + std::shared_ptr cluster_job; + // TODO(wangtaize) add a light jit engine + // eg using bthead to compile ir + hybridse::vm::JitOptions jit_options; + std::shared_ptr jit = nullptr; + Schema schema; + Schema request_schema; + std::string request_db_name; + std::string request_name; + Schema parameter_types; + uint32_t row_size; + uint32_t limit_cnt = 0; + std::string ir; + std::string logical_plan_str; + std::string physical_plan_str; + std::string encoded_schema; + std::string encoded_request_schema; + ::hybridse::node::NodeManager nm; + ::hybridse::udf::UdfLibrary* udf_library = nullptr; + + ::hybridse::vm::BatchRequestInfo batch_request_info; + + std::shared_ptr> options; + + // [ALPHA] SQL diagnostic infos + // not standardized, only index hints, no error, no warning, no other hint/info + std::shared_ptr index_hints; + + SqlContext(); + ~SqlContext(); +}; + +} // namespace vm +} // namespace hybridse + +#endif // HYBRIDSE_INCLUDE_VM_SQL_CTX_H_ diff --git a/hybridse/src/codegen/aggregate_ir_builder.cc b/hybridse/src/codegen/aggregate_ir_builder.cc index 19e2a4f5cc3..22de3d3d742 100644 --- a/hybridse/src/codegen/aggregate_ir_builder.cc +++ b/hybridse/src/codegen/aggregate_ir_builder.cc @@ -21,10 +21,10 @@ #include #include +#include "codegen/buf_ir_builder.h" #include "codegen/expr_ir_builder.h" #include "codegen/ir_base_builder.h" #include "codegen/variable_ir_builder.h" -#include "gflags/gflags.h" #include "glog/logging.h" namespace hybridse { namespace codegen { diff --git a/hybridse/src/codegen/array_ir_builder.cc b/hybridse/src/codegen/array_ir_builder.cc index 5bf1bf06e99..5f3d22edc5c 100644 --- a/hybridse/src/codegen/array_ir_builder.cc +++ b/hybridse/src/codegen/array_ir_builder.cc @@ -17,26 +17,26 @@ #include "codegen/array_ir_builder.h" #include + +#include "codegen/context.h" #include "codegen/ir_base_builder.h" namespace hybridse { namespace codegen { +#define SZ_IDX 2 +#define RAW_IDX 0 +#define NULL_IDX 1 + ArrayIRBuilder::ArrayIRBuilder(::llvm::Module* m, llvm::Type* ele_ty) : StructTypeIRBuilder(m), element_type_(ele_ty) { InitStructType(); } -ArrayIRBuilder::ArrayIRBuilder(::llvm::Module* m, llvm::Type* ele_ty, llvm::Value* num_ele) - : StructTypeIRBuilder(m), element_type_(ele_ty), num_elements_(num_ele) { - InitStructType(); -} - void ArrayIRBuilder::InitStructType() { // name must unique between different array type std::string name = absl::StrCat("fe.array_", GetLlvmObjectString(element_type_)); - ::llvm::StringRef sr(name); - ::llvm::StructType* stype = m_->getTypeByName(sr); + ::llvm::StructType* stype = m_->getTypeByName(name); if (stype != NULL) { struct_type_ = stype; return; @@ -46,29 +46,36 @@ void ArrayIRBuilder::InitStructType() { ::llvm::Type* arr_type = element_type_->getPointerTo(); ::llvm::Type* nullable_type = ::llvm::IntegerType::getInt1Ty(m_->getContext())->getPointerTo(); ::llvm::Type* size_type = ::llvm::IntegerType::getInt64Ty(m_->getContext()); - std::vector<::llvm::Type*> elements = {arr_type, nullable_type, size_type}; - stype->setBody(::llvm::ArrayRef<::llvm::Type*>(elements)); + stype->setBody({arr_type, nullable_type, size_type}); struct_type_ = stype; } -base::Status ArrayIRBuilder::NewFixedArray(llvm::BasicBlock* bb, const std::vector& elements, - NativeValue* output) const { - // TODO(ace): reduce IR size with loop block - - CHECK_TRUE(num_elements_ != nullptr, common::kCodegenError, "num elements unknown"); - +absl::StatusOr ArrayIRBuilder::Construct(CodeGenContext* ctx, + absl::Span elements) const { + auto bb = ctx->GetCurrentBlock(); // alloc array struct llvm::Value* array_alloca = nullptr; - CHECK_TRUE(Create(bb, &array_alloca), common::kCodegenError, "can't create struct type for array"); + if (!Allocate(bb, &array_alloca)) { + return absl::InternalError("can't create struct type for array"); + } // ============================ // Init array elements // ============================ llvm::IRBuilder<> builder(bb); + auto num_elements = ctx->GetBuilder()->getInt64(elements.size()); + if (!Set(bb, array_alloca, SZ_IDX, num_elements)) { + return absl::InternalError("fail to set array size"); + } + + if (elements.empty()) { + // empty array + return NativeValue::Create(array_alloca); + } // init raw array and nullable array - auto* raw_array_ptr = builder.CreateAlloca(element_type_, num_elements_); - auto* nullables_ptr = builder.CreateAlloca(builder.getInt1Ty(), num_elements_); + auto* raw_array_ptr = builder.CreateAlloca(element_type_, num_elements); + auto* nullables_ptr = builder.CreateAlloca(builder.getInt1Ty(), num_elements); // fullfill the array struct auto* idx_val_ptr = builder.CreateAlloca(builder.getInt64Ty()); @@ -88,41 +95,26 @@ base::Status ArrayIRBuilder::NewFixedArray(llvm::BasicBlock* bb, const std::vect } // Set raw array - CHECK_TRUE(Set(bb, array_alloca, 0, raw_array_ptr), common::kCodegenError); + if (!Set(bb, array_alloca, RAW_IDX, raw_array_ptr)) { + return absl::InternalError("fail to set array values"); + } // Set nullable list - CHECK_TRUE(Set(bb, array_alloca, 1, nullables_ptr), common::kCodegenError); - - ::llvm::Value* array_sz = builder.CreateLoad(idx_val_ptr); - CHECK_TRUE(Set(bb, array_alloca, 2, array_sz), common::kCodegenError); - - *output = NativeValue::Create(array_alloca); - return base::Status::OK(); -} - - -base::Status ArrayIRBuilder::NewEmptyArray(llvm::BasicBlock* bb, NativeValue* output) const { - llvm::Value* array_alloca = nullptr; - CHECK_TRUE(Create(bb, &array_alloca), common::kCodegenError, "can't create struct type for array"); - - llvm::IRBuilder<> builder(bb); - - ::llvm::Value* array_sz = builder.getInt64(0); - CHECK_TRUE(Set(bb, array_alloca, 2, array_sz), common::kCodegenError); - - *output = NativeValue::Create(array_alloca); + if (!Set(bb, array_alloca, NULL_IDX, nullables_ptr)) { + return absl::InternalError("fail to set array nulls"); + } - return base::Status::OK(); + return NativeValue::Create(array_alloca); } bool ArrayIRBuilder::CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) { llvm::Value* array_alloca = nullptr; - if (!Create(block, &array_alloca)) { + if (!Allocate(block, &array_alloca)) { return false; } llvm::IRBuilder<> builder(block); ::llvm::Value* array_sz = builder.getInt64(0); - if (!Set(block, array_alloca, 2, array_sz)) { + if (!Set(block, array_alloca, SZ_IDX, array_sz)) { return false; } diff --git a/hybridse/src/codegen/array_ir_builder.h b/hybridse/src/codegen/array_ir_builder.h index 66ef2fe05da..b6ff275ac45 100644 --- a/hybridse/src/codegen/array_ir_builder.h +++ b/hybridse/src/codegen/array_ir_builder.h @@ -17,9 +17,6 @@ #ifndef HYBRIDSE_SRC_CODEGEN_ARRAY_IR_BUILDER_H_ #define HYBRIDSE_SRC_CODEGEN_ARRAY_IR_BUILDER_H_ -#include - -#include "absl/base/attributes.h" #include "codegen/struct_ir_builder.h" namespace hybridse { @@ -29,27 +26,15 @@ namespace codegen { // - Array of raw values: T* // - Array of nullable values: bool* // - array size: int64 -class ArrayIRBuilder : public StructTypeIRBuilder { +class ArrayIRBuilder : public StructTypeIRBuilder { public: // Array builder with num elements unknown ArrayIRBuilder(::llvm::Module* m, llvm::Type* ele_ty); - // Array builder with num elements known at some point - ArrayIRBuilder(::llvm::Module* m, llvm::Type* ele_ty, llvm::Value* num_ele); - ~ArrayIRBuilder() override {} // create a new array from `elements` as value - ABSL_MUST_USE_RESULT - base::Status NewFixedArray(llvm::BasicBlock* bb, const std::vector& elements, - NativeValue* output) const; - - ABSL_MUST_USE_RESULT - base::Status NewEmptyArray(llvm::BasicBlock* bb, NativeValue* output) const; - - void InitStructType() override; - - bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; + absl::StatusOr Construct(CodeGenContext* ctx, absl::Span args) const override; bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) override { return true; } @@ -57,9 +42,13 @@ class ArrayIRBuilder : public StructTypeIRBuilder { CHECK_TRUE(false, common::kCodegenError, "casting to array un-implemented"); }; + private: + void InitStructType() override; + + bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; + private: ::llvm::Type* element_type_ = nullptr; - ::llvm::Value* num_elements_ = nullptr; }; } // namespace codegen diff --git a/hybridse/src/codegen/block_ir_builder.cc b/hybridse/src/codegen/block_ir_builder.cc index 6f53e80aa40..818229553ca 100644 --- a/hybridse/src/codegen/block_ir_builder.cc +++ b/hybridse/src/codegen/block_ir_builder.cc @@ -15,15 +15,15 @@ */ #include "codegen/block_ir_builder.h" + #include "codegen/context.h" #include "codegen/expr_ir_builder.h" +#include "codegen/ir_base_builder.h" #include "codegen/list_ir_builder.h" #include "codegen/struct_ir_builder.h" #include "codegen/type_ir_builder.h" #include "codegen/variable_ir_builder.h" #include "glog/logging.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/IR/CFG.h" #include "llvm/IR/IRBuilder.h" using ::hybridse::common::kCodegenError; diff --git a/hybridse/src/codegen/date_ir_builder.cc b/hybridse/src/codegen/date_ir_builder.cc index 19bf319d7c3..1bfb1d31160 100644 --- a/hybridse/src/codegen/date_ir_builder.cc +++ b/hybridse/src/codegen/date_ir_builder.cc @@ -55,7 +55,7 @@ bool DateIRBuilder::NewDate(::llvm::BasicBlock* block, ::llvm::Value** output) { return false; } ::llvm::Value* date; - if (!Create(block, &date)) { + if (!Allocate(block, &date)) { return false; } if (!SetDate(block, date, @@ -73,7 +73,7 @@ bool DateIRBuilder::NewDate(::llvm::BasicBlock* block, ::llvm::Value* days, return false; } ::llvm::Value* date; - if (!Create(block, &date)) { + if (!Allocate(block, &date)) { return false; } if (!SetDate(block, date, days)) { diff --git a/hybridse/src/codegen/date_ir_builder.h b/hybridse/src/codegen/date_ir_builder.h index d9004d48da1..1d51cc98ceb 100644 --- a/hybridse/src/codegen/date_ir_builder.h +++ b/hybridse/src/codegen/date_ir_builder.h @@ -28,8 +28,6 @@ class DateIRBuilder : public StructTypeIRBuilder { explicit DateIRBuilder(::llvm::Module* m); ~DateIRBuilder(); - void InitStructType() override; - bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) override; base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override; @@ -46,6 +44,9 @@ class DateIRBuilder : public StructTypeIRBuilder { ::llvm::Value** output, base::Status& status); // NOLINT bool Year(::llvm::BasicBlock* block, ::llvm::Value* date, ::llvm::Value** output, base::Status& status); // NOLINT + private: + void InitStructType() override; + bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; }; } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/expr_ir_builder.cc b/hybridse/src/codegen/expr_ir_builder.cc index 6b95bfb8ce1..ccf3838cbcf 100644 --- a/hybridse/src/codegen/expr_ir_builder.cc +++ b/hybridse/src/codegen/expr_ir_builder.cc @@ -19,8 +19,10 @@ #include #include #include +#include #include "base/numeric.h" +#include "codegen/arithmetic_expr_ir_builder.h" #include "codegen/array_ir_builder.h" #include "codegen/buf_ir_builder.h" #include "codegen/cond_select_ir_builder.h" @@ -28,11 +30,19 @@ #include "codegen/date_ir_builder.h" #include "codegen/ir_base_builder.h" #include "codegen/list_ir_builder.h" +#include "codegen/map_ir_builder.h" +#include "codegen/predicate_expr_ir_builder.h" +#include "codegen/scope_var.h" #include "codegen/timestamp_ir_builder.h" #include "codegen/type_ir_builder.h" #include "codegen/udf_ir_builder.h" +#include "codegen/variable_ir_builder.h" #include "codegen/window_ir_builder.h" #include "glog/logging.h" +#include "llvm/IR/IRBuilder.h" +#include "node/node_manager.h" +#include "node/type_node.h" +#include "passes/resolve_fn_and_attrs.h" #include "proto/fe_common.pb.h" #include "udf/default_udf_library.h" #include "vm/schemas_context.h" @@ -199,6 +209,10 @@ Status ExprIRBuilder::Build(const ::hybridse::node::ExprNode* node, CHECK_STATUS(BuildArrayExpr(dynamic_cast(node), output)); break; } + case ::hybridse::node::kExprArrayElement: { + CHECK_STATUS(BuildArrayElement(dynamic_cast(node), output)); + break; + } default: { return Status(kCodegenError, "Expression Type " + @@ -1157,13 +1171,6 @@ Status ExprIRBuilder::BuildArrayExpr(const ::hybridse::node::ArrayExpr* node, Na llvm::IRBuilder<> builder(ctx_->GetCurrentBlock()); - if (node->GetChildNum() == 0) { - // build empty array - ArrayIRBuilder ir_builder(ctx_->GetModule(), ele_type); - CHECK_STATUS(ir_builder.NewEmptyArray(ctx_->GetCurrentBlock(), output)); - return Status::OK(); - } - CastExprIRBuilder cast_builder(ctx_->GetCurrentBlock()); std::vector elements; for (auto& ele : node->children_) { @@ -1178,11 +1185,46 @@ Status ExprIRBuilder::BuildArrayExpr(const ::hybridse::node::ArrayExpr* node, Na } } - ::llvm::Value* num_elements = builder.getInt64(elements.size()); - ArrayIRBuilder array_builder(ctx_->GetModule(), ele_type, num_elements); - CHECK_STATUS(array_builder.NewFixedArray(ctx_->GetCurrentBlock(), elements, output)); + ArrayIRBuilder array_builder(ctx_->GetModule(), ele_type); + auto rs = array_builder.Construct(ctx_, elements); + if (!rs.ok()) { + FAIL_STATUS(kCodegenError, rs.status()); + } + + *output = rs.value(); return Status::OK(); } +Status ExprIRBuilder::BuildArrayElement(const ::hybridse::node::ArrayElementExpr* expr, NativeValue* output) { + auto* arr_type = expr->array()->GetOutputType(); + NativeValue arr_val; + CHECK_STATUS(Build(expr->array(), &arr_val)); + + NativeValue pos_val; + CHECK_STATUS(Build(expr->position(), &pos_val)); + + std::unique_ptr type_builder; + + if (arr_type->IsMap()) { + auto* map_type = arr_type->GetAsOrNull(); + ::llvm::Type* key_type = nullptr; + ::llvm::Type* value_type = nullptr; + CHECK_TRUE(GetLlvmType(ctx_->GetModule(), map_type->key_type(), &key_type), kCodegenError); + CHECK_TRUE(GetLlvmType(ctx_->GetModule(), map_type->value_type(), &value_type), kCodegenError); + type_builder.reset(new MapIRBuilder(ctx_->GetModule(), key_type, value_type)); + } else if (arr_type->IsArray()) { + ::llvm::Type* ele_type = nullptr; + CHECK_TRUE(GetLlvmType(ctx_->GetModule(), arr_type->GetGenericType(0), &ele_type), kCodegenError); + type_builder.reset(new ArrayIRBuilder(ctx_->GetModule(), ele_type)); + } else { + return {common::kCodegenError, absl::StrCat("can't get element from type ", arr_type->DebugString())}; + } + + auto res = type_builder->ExtractElement(ctx_, arr_val, pos_val); + CHECK_TRUE(res.ok(), common::kCodegenError, res.status().ToString()); + *output = res.value(); + + return {}; +} } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/expr_ir_builder.h b/hybridse/src/codegen/expr_ir_builder.h index 6838d96a88b..051c9a32bfd 100644 --- a/hybridse/src/codegen/expr_ir_builder.h +++ b/hybridse/src/codegen/expr_ir_builder.h @@ -17,24 +17,12 @@ #ifndef HYBRIDSE_SRC_CODEGEN_EXPR_IR_BUILDER_H_ #define HYBRIDSE_SRC_CODEGEN_EXPR_IR_BUILDER_H_ -#include -#include #include #include + #include "base/fe_status.h" -#include "codegen/arithmetic_expr_ir_builder.h" -#include "codegen/buf_ir_builder.h" -#include "codegen/predicate_expr_ir_builder.h" -#include "codegen/row_ir_builder.h" -#include "codegen/scope_var.h" -#include "codegen/variable_ir_builder.h" -#include "codegen/window_ir_builder.h" -#include "llvm/IR/IRBuilder.h" -#include "node/node_manager.h" +#include "codegen/context.h" #include "node/sql_node.h" -#include "node/type_node.h" -#include "passes/resolve_fn_and_attrs.h" -#include "vm/schemas_context.h" namespace hybridse { namespace codegen { @@ -117,6 +105,8 @@ class ExprIRBuilder { Status BuildArrayExpr(const ::hybridse::node::ArrayExpr* node, NativeValue* output); + Status BuildArrayElement(const ::hybridse::node::ArrayElementExpr*, NativeValue*); + private: CodeGenContext* ctx_; diff --git a/hybridse/src/codegen/fn_let_ir_builder.cc b/hybridse/src/codegen/fn_let_ir_builder.cc index 362e4a83df6..6d8e86e3933 100644 --- a/hybridse/src/codegen/fn_let_ir_builder.cc +++ b/hybridse/src/codegen/fn_let_ir_builder.cc @@ -15,13 +15,14 @@ */ #include "codegen/fn_let_ir_builder.h" + #include "codegen/aggregate_ir_builder.h" +#include "codegen/buf_ir_builder.h" #include "codegen/context.h" #include "codegen/expr_ir_builder.h" #include "codegen/ir_base_builder.h" #include "codegen/variable_ir_builder.h" #include "glog/logging.h" -#include "vm/transform.h" using ::hybridse::common::kCodegenError; diff --git a/hybridse/src/codegen/ir_base_builder.cc b/hybridse/src/codegen/ir_base_builder.cc index 992d41d0998..81fadbfdd3d 100644 --- a/hybridse/src/codegen/ir_base_builder.cc +++ b/hybridse/src/codegen/ir_base_builder.cc @@ -556,7 +556,24 @@ bool GetFullType(node::NodeManager* nm, ::llvm::Type* type, return false; } case hybridse::node::kMap: { - LOG(WARNING) << "fail to get type for map"; + if (type->isPointerTy()) { + auto type_pointee = type->getPointerElementType(); + if (type_pointee->isStructTy()) { + auto* key_type = type_pointee->getStructElementType(1); + const node::TypeNode* key = nullptr; + if (key_type->isPointerTy() && !GetFullType(nm, key_type->getPointerElementType(), &key)) { + return false; + } + const node::TypeNode* value = nullptr; + auto* value_type = type_pointee->getStructElementType(2); + if (value_type->isPointerTy() && !GetFullType(nm, value_type->getPointerElementType(), &value)) { + return false; + } + + *type_node = nm->MakeNode(key, value); + return true; + } + } return false; } default: { @@ -643,6 +660,9 @@ bool GetBaseType(::llvm::Type* type, ::hybridse::node::DataType* output) { } else if (struct_name.startswith("fe.array_")) { *output = hybridse::node::kArray; return true; + } else if (struct_name.startswith("fe.map_")) { + *output = hybridse::node::kMap; + return true; } LOG(WARNING) << "no mapping pointee_ty for llvm pointee_ty " << pointee_ty->getStructName().str(); diff --git a/hybridse/src/codegen/ir_base_builder_test.h b/hybridse/src/codegen/ir_base_builder_test.h index 478d8ae5ea3..af29e4fd56c 100644 --- a/hybridse/src/codegen/ir_base_builder_test.h +++ b/hybridse/src/codegen/ir_base_builder_test.h @@ -22,8 +22,8 @@ #include #include +#include "codegen/ir_base_builder.h" #include "llvm/IR/Verifier.h" -#include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" #include "base/fe_status.h" @@ -34,8 +34,7 @@ #include "passes/resolve_fn_and_attrs.h" #include "udf/default_udf_library.h" #include "udf/literal_traits.h" -#include "udf/udf.h" -#include "vm/sql_compiler.h" +#include "vm/jit_wrapper.h" namespace hybridse { namespace codegen { @@ -360,8 +359,7 @@ void ModuleFunctionBuilderWithFullInfo::ExpandApplyArg( ::llvm::Value* alloca; if (TypeIRBuilder::IsStructPtr(expect_ty)) { auto struct_builder = - StructTypeIRBuilder::CreateStructTypeIRBuilder( - function->getEntryBlock().getModule(), expect_ty); + StructTypeIRBuilder::CreateStructTypeIRBuilder(function->getEntryBlock().getModule(), expect_ty); struct_builder->CreateDefault(&function->getEntryBlock(), &alloca); arg = builder.CreateSelect( diff --git a/hybridse/src/codegen/map_ir_builder.cc b/hybridse/src/codegen/map_ir_builder.cc new file mode 100644 index 00000000000..8945c88f9b7 --- /dev/null +++ b/hybridse/src/codegen/map_ir_builder.cc @@ -0,0 +1,326 @@ +/* + * Copyright 2022 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "codegen/map_ir_builder.h" + +#include + +#include "absl/status/status.h" +#include "codegen/array_ir_builder.h" +#include "codegen/cast_expr_ir_builder.h" +#include "codegen/context.h" +#include "codegen/ir_base_builder.h" +#include "codegen/cond_select_ir_builder.h" +#include "codegen/predicate_expr_ir_builder.h" + +namespace hybridse { +namespace codegen { + +static const char* PREFIX = "fe.map"; +#define SZ_IDX 0 +#define KEY_VEC_IDX 1 +#define VALUE_VEC_IDX 2 +#define VALUE_NULL_VEC_IDX 3 + +MapIRBuilder::MapIRBuilder(::llvm::Module* m, ::llvm::Type* key_ty, ::llvm::Type* value_ty) + : StructTypeIRBuilder(m), key_type_(key_ty), value_type_(value_ty) { + InitStructType(); +} + +void MapIRBuilder::InitStructType() { + std::string name = + absl::StrCat(PREFIX, "__", GetLlvmObjectString(key_type_), "_", GetLlvmObjectString(value_type_), "__"); + ::llvm::StringRef sr(name); + ::llvm::StructType* stype = m_->getTypeByName(sr); + if (stype != NULL) { + struct_type_ = stype; + return; + } + stype = ::llvm::StructType::create(m_->getContext(), name); + + // %map__{key}_{value}__ = { size, vec, vec, vec } + ::llvm::Type* size_type = ::llvm::IntegerType::getInt64Ty(m_->getContext()); + // ::llvm::Type* key_vec = ::llvm::VectorType::get(key_type_, {MIN_VEC_SIZE, true}); + LOG(INFO) << "key vec is " << GetLlvmObjectString(key_type_); + ::llvm::Type* key_vec = key_type_->getPointerTo(); + ::llvm::Type* value_vec = value_type_->getPointerTo(); + ::llvm::Type* value_null_type = ::llvm::IntegerType::getInt1Ty(m_->getContext())->getPointerTo(); + stype->setBody({size_type, key_vec, value_vec, value_null_type}); + struct_type_ = stype; +} + +absl::StatusOr MapIRBuilder::Construct(CodeGenContext* ctx, absl::Span args) const { + EnsureOK(); + + ::llvm::Value* map_alloca = nullptr; + if (!Allocate(ctx->GetCurrentBlock(), &map_alloca)) { + return absl::FailedPreconditionError(absl::StrCat("unable to allocate ", GetLlvmObjectString(struct_type_))); + } + + auto builder = ctx->GetBuilder(); + auto* original_size = builder->getInt64(args.size() / 2); + auto* key_vec = builder->CreateAlloca(key_type_, original_size, "key_vec"); + auto* value_vec = builder->CreateAlloca(value_type_, original_size, "value_vec"); + auto* value_nulls_vec = builder->CreateAlloca(builder->getInt1Ty(), original_size, "value_nulls_vec"); + + // creating raw values for map + + CastExprIRBuilder cast_builder(ctx->GetCurrentBlock()); + + // original vector, may contains duplicate keys + auto* original_keys = builder->CreateAlloca(key_type_, original_size, "original_keys"); + auto* original_keys_is_null = builder->CreateAlloca(builder->getInt1Ty(), original_size, "original_keys_is_null"); + auto* original_values = builder->CreateAlloca(value_type_, original_size, "original_values"); + auto* original_values_is_null = + builder->CreateAlloca(builder->getInt1Ty(), original_size, "original_values_is_null"); + for (size_t i = 0; i < args.size(); i += 2) { + auto* update_idx = builder->getInt64(i / 2); + NativeValue key = args[i]; + if (key.GetValue(builder)->getType() != key_type_) { + auto s = cast_builder.Cast(key, key_type_, &key); + if (!s.isOK()) { + return absl::InternalError(absl::StrCat("fail to case map key: ", s.str())); + } + } + NativeValue value = args[i + 1]; + if (value.GetValue(builder)->getType() != value_type_) { + auto s = cast_builder.Cast(value, value_type_, &value); + if (!s.isOK()) { + return absl::InternalError(absl::StrCat("fail to case map value: ", s.str())); + } + } + builder->CreateStore(key.GetIsNull(ctx), builder->CreateGEP(original_keys_is_null, update_idx)); + builder->CreateStore(key.GetValue(ctx), builder->CreateGEP(original_keys, update_idx)); + builder->CreateStore(value.GetIsNull(ctx), builder->CreateGEP(original_values_is_null, update_idx)); + builder->CreateStore(value.GetValue(ctx), builder->CreateGEP(original_values, update_idx)); + } + + ::llvm::Value* update_idx_ptr = builder->CreateAlloca(builder->getInt64Ty(), nullptr, "update_idx"); + builder->CreateStore(builder->getInt64(0), update_idx_ptr); + ::llvm::Value* true_idx_ptr = builder->CreateAlloca(builder->getInt64Ty(), nullptr, "true_idx"); + builder->CreateStore(builder->getInt64(0), true_idx_ptr); + + auto s = ctx->CreateWhile( + [&](llvm::Value** cond) -> base::Status { + *cond = builder->CreateAnd( + builder->CreateICmpSLT(builder->CreateLoad(update_idx_ptr), original_size, "if_while_true"), + builder->CreateICmpSLT(builder->CreateLoad(true_idx_ptr), original_size)); + return {}; + }, + [&]() -> base::Status { + auto idx = builder->CreateLoad(update_idx_ptr, "update_idx_value"); + auto true_idx = builder->CreateLoad(true_idx_ptr, "true_idx_value"); + CHECK_STATUS(ctx->CreateBranchNot( + builder->CreateLoad(builder->CreateGEP(original_keys_is_null, idx)), [&]() -> base::Status { + // write to map if key is not null + builder->CreateStore(builder->CreateLoad(builder->CreateGEP(original_keys, idx)), + builder->CreateGEP(key_vec, true_idx)); + builder->CreateStore(builder->CreateLoad(builder->CreateGEP(original_values, idx)), + builder->CreateGEP(value_vec, true_idx)); + builder->CreateStore(builder->CreateLoad(builder->CreateGEP(original_values_is_null, idx)), + builder->CreateGEP(value_nulls_vec, true_idx)); + + builder->CreateStore(builder->CreateAdd(builder->getInt64(1), true_idx), true_idx_ptr); + return {}; + })); + + builder->CreateStore(builder->CreateAdd(builder->getInt64(1), idx), update_idx_ptr); + return {}; + }); + if (!s.isOK()) { + return absl::InternalError(s.str()); + } + + auto* final_size = builder->CreateLoad(true_idx_ptr, "true_size"); + auto as = Set(ctx, map_alloca, {final_size, key_vec, value_vec, value_nulls_vec}); + + if (!as.ok()) { + return as; + } + + return NativeValue::Create(map_alloca); +} + +bool MapIRBuilder::CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) { + llvm::Value* map_alloca = nullptr; + if (!Allocate(block, &map_alloca)) { + return false; + } + + llvm::IRBuilder<> builder(block); + ::llvm::Value* size = builder.getInt64(0); + if (!Set(block, map_alloca, SZ_IDX, size)) { + return false; + } + + *output = map_alloca; + return true; +} + +absl::StatusOr MapIRBuilder::ExtractElement(CodeGenContext* ctx, const NativeValue& arr, + const NativeValue& key) const { + EnsureOK(); + + auto builder = ctx->GetBuilder(); + auto* arr_is_null = arr.GetIsNull(ctx); + auto* key_is_null = key.GetIsNull(ctx); + + auto* out_val_alloca = builder->CreateAlloca(value_type_); + builder->CreateStore(::llvm::UndefValue::get(value_type_), out_val_alloca); + auto* out_null_alloca = builder->CreateAlloca(builder->getInt1Ty()); + builder->CreateStore(builder->getInt1(true), out_null_alloca); + + auto s = ctx->CreateBranch( + builder->CreateOr(arr_is_null, key_is_null), + [&]() -> base::Status { + return {}; + }, + [&]() -> base::Status { + NativeValue casted_key = key; + if (key.GetType() != key_type_) { + CastExprIRBuilder cast_builder(ctx->GetCurrentBlock()); + CHECK_STATUS(cast_builder.Cast(key, key_type_, &casted_key)); + } + auto* key_val = casted_key.GetValue(ctx); + + auto* map_ptr = arr.GetValue(ctx); + ::llvm::Value* sz = nullptr; + CHECK_TRUE(Load(ctx->GetCurrentBlock(), map_ptr, SZ_IDX, &sz), common::kCodegenError); + + ::llvm::Value* keys = nullptr; + CHECK_TRUE(Load(ctx->GetCurrentBlock(), map_ptr, KEY_VEC_IDX, &keys), common::kCodegenError); + + ::llvm::Value* idx_alloc = builder->CreateAlloca(builder->getInt64Ty()); + builder->CreateStore(builder->getInt64(0), idx_alloc); + ::llvm::Value* found_idx_alloc = builder->CreateAlloca(builder->getInt64Ty()); + builder->CreateStore(builder->getInt64(-1), found_idx_alloc); + + CHECK_STATUS(ctx->CreateWhile( + [&](::llvm::Value** cond) -> base::Status { + ::llvm::Value* idx = builder->CreateLoad(idx_alloc); + ::llvm::Value* found = builder->CreateLoad(found_idx_alloc); + *cond = builder->CreateAnd(builder->CreateICmpSLT(idx, sz), + builder->CreateICmpSLT(found, builder->getInt64(0))); + return {}; + }, + [&]() -> base::Status { + ::llvm::Value* idx = builder->CreateLoad(idx_alloc); + // key never null + auto* ele = builder->CreateLoad(builder->CreateGEP(keys, idx)); + ::llvm::Value* eq = nullptr; + base::Status s; + PredicateIRBuilder::BuildEqExpr(ctx->GetCurrentBlock(), ele, key_val, &eq, s); + CHECK_STATUS(s); + + ::llvm::Value* update_found_idx = builder->CreateSelect(eq, idx, builder->getInt64(-1)); + + builder->CreateStore(update_found_idx, found_idx_alloc); + builder->CreateStore(builder->CreateAdd(idx, builder->getInt64(1)), idx_alloc); + return {}; + })); + + auto* found_idx = builder->CreateLoad(found_idx_alloc); + + CHECK_STATUS(ctx->CreateBranch( + builder->CreateAnd(builder->CreateICmpSLT(found_idx, sz), + builder->CreateICmpSGE(found_idx, builder->getInt64(0))), + [&]() -> base::Status { + ::llvm::Value* values = nullptr; + CHECK_TRUE(Load(ctx->GetCurrentBlock(), map_ptr, VALUE_VEC_IDX, &values), common::kCodegenError); + + ::llvm::Value* value_nulls = nullptr; + CHECK_TRUE(Load(ctx->GetCurrentBlock(), map_ptr, VALUE_NULL_VEC_IDX, &value_nulls), + common::kCodegenError); + + auto* val = builder->CreateLoad(builder->CreateGEP(values, found_idx)); + auto* val_nullable = builder->CreateLoad(builder->CreateGEP(value_nulls, found_idx)); + + builder->CreateStore(val, out_val_alloca); + builder->CreateStore(val_nullable, out_null_alloca); + return {}; + }, + [&]() -> base::Status { return {}; })); + + return {}; + }); + + if (!s.isOK()) { + return absl::InvalidArgumentError(s.str()); + } + + auto* out_val = builder->CreateLoad(out_val_alloca); + auto* out_null_val = builder->CreateLoad(out_null_alloca); + + return NativeValue::CreateWithFlag(out_val, out_null_val); +} + +absl::StatusOr MapIRBuilder::MapKeys(CodeGenContext* ctx, const NativeValue& in) const { + EnsureOK(); + + auto map_is_null = in.GetIsNull(ctx); + auto map_ptr = in.GetValue(ctx); + + auto builder = ctx->GetBuilder(); + ::llvm::Value* keys_ptr = nullptr; + if (!Load(ctx->GetCurrentBlock(), map_ptr, KEY_VEC_IDX, &keys_ptr)) { + return absl::FailedPreconditionError("failed to extract map keys"); + } + if (!keys_ptr->getType()->isPointerTy()) { + return absl::FailedPreconditionError("map keys entry is not pointer"); + } + ::llvm::Value* size = nullptr; + if (!Load(ctx->GetCurrentBlock(), map_ptr, SZ_IDX, &size)) { + return absl::FailedPreconditionError("failed to extract map size"); + } + + // construct nulls as [false ...] + auto nulls = builder->CreateAlloca(builder->getInt1Ty(), size); + auto idx_ptr = builder->CreateAlloca(builder->getInt64Ty()); + builder->CreateStore(builder->getInt64(0), idx_ptr); + ctx->CreateWhile( + [&](::llvm::Value** cond) -> base::Status { + *cond = builder->CreateICmpSLT(builder->CreateLoad(idx_ptr), size); + return {}; + }, + [&]() -> base::Status { + auto idx = builder->CreateLoad(idx_ptr); + + builder->CreateStore(builder->getInt1(false), builder->CreateGEP(nulls, idx)); + + builder->CreateStore(builder->CreateAdd(idx, builder->getInt64(1)), idx_ptr); + return {}; + }); + + ArrayIRBuilder array_builder(ctx->GetModule(), keys_ptr->getType()->getPointerElementType()); + auto rs = array_builder.ConstructFromRaw(ctx, {keys_ptr, nulls, size}); + + if (!rs.ok()) { + return rs.status(); + } + + NativeValue out; + CondSelectIRBuilder cond_builder; + auto s = cond_builder.Select(ctx->GetCurrentBlock(), NativeValue::Create(map_is_null), + NativeValue::CreateNull(array_builder.GetType()), NativeValue::Create(rs.value()), &out); + + if (!s.isOK()) { + return absl::FailedPreconditionError(s.str()); + } + + return out; +} +} // namespace codegen +} // namespace hybridse diff --git a/hybridse/src/codegen/map_ir_builder.h b/hybridse/src/codegen/map_ir_builder.h new file mode 100644 index 00000000000..478c6cc975b --- /dev/null +++ b/hybridse/src/codegen/map_ir_builder.h @@ -0,0 +1,55 @@ +/* + * Copyright 2022 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ +#define HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ + +#include "codegen/struct_ir_builder.h" + +namespace hybridse { +namespace codegen { + +class MapIRBuilder final : public StructTypeIRBuilder { + public: + MapIRBuilder(::llvm::Module* m, ::llvm::Type* key_ty, ::llvm::Type* value_ty); + ~MapIRBuilder() override {} + + absl::StatusOr Construct(CodeGenContext* ctx, absl::Span args) const override; + + bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) override { return true; } + base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) override { + return {}; + } + + absl::StatusOr ExtractElement(CodeGenContext* ctx, const NativeValue&, + const NativeValue&) const override; + + absl::StatusOr MapKeys(CodeGenContext*, const NativeValue&) const; + + private: + void InitStructType() override; + + bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) override; + + private: + ::llvm::Type* key_type_ = nullptr; + ::llvm::Type* value_type_ = nullptr; +}; + +} // namespace codegen +} // namespace hybridse + +#endif // HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ diff --git a/hybridse/src/codegen/string_ir_builder.cc b/hybridse/src/codegen/string_ir_builder.cc index 8c41d326ee0..083c907fbe4 100644 --- a/hybridse/src/codegen/string_ir_builder.cc +++ b/hybridse/src/codegen/string_ir_builder.cc @@ -66,7 +66,7 @@ bool StringIRBuilder::CreateDefault(::llvm::BasicBlock* block, bool StringIRBuilder::NewString(::llvm::BasicBlock* block, ::llvm::Value** output) { - if (!Create(block, output)) { + if (!Allocate(block, output)) { LOG(WARNING) << "Fail to Create Default String"; return false; } @@ -86,7 +86,7 @@ bool StringIRBuilder::NewString(::llvm::BasicBlock* block, } bool StringIRBuilder::NewString(::llvm::BasicBlock* block, ::llvm::Value* size, ::llvm::Value* data, ::llvm::Value** output) { - if (!Create(block, output)) { + if (!Allocate(block, output)) { LOG(WARNING) << "Fail to Create Default String"; return false; } diff --git a/hybridse/src/codegen/struct_ir_builder.cc b/hybridse/src/codegen/struct_ir_builder.cc index 7adfb5d950f..4b0be401065 100644 --- a/hybridse/src/codegen/struct_ir_builder.cc +++ b/hybridse/src/codegen/struct_ir_builder.cc @@ -15,10 +15,15 @@ */ #include "codegen/struct_ir_builder.h" + +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "codegen/context.h" #include "codegen/date_ir_builder.h" #include "codegen/ir_base_builder.h" #include "codegen/string_ir_builder.h" #include "codegen/timestamp_ir_builder.h" + namespace hybridse { namespace codegen { StructTypeIRBuilder::StructTypeIRBuilder(::llvm::Module* m) @@ -54,6 +59,8 @@ StructTypeIRBuilder* StructTypeIRBuilder::CreateStructTypeIRBuilder(::llvm::Modu } absl::StatusOr StructTypeIRBuilder::CreateNull(::llvm::BasicBlock* block) { + EnsureOK(); + ::llvm::Value* value = nullptr; if (!CreateDefault(block, &value)) { return absl::InternalError(absl::StrCat("fail to construct ", GetLlvmObjectString(GetType()))); @@ -62,16 +69,17 @@ absl::StatusOr StructTypeIRBuilder::CreateNull(::llvm::BasicBlock* return NativeValue::CreateWithFlag(value, builder.getInt1(true)); } -::llvm::Type* StructTypeIRBuilder::GetType() { return struct_type_; } +::llvm::Type* StructTypeIRBuilder::GetType() const { return struct_type_; } -bool StructTypeIRBuilder::Create(::llvm::BasicBlock* block, +bool StructTypeIRBuilder::Allocate(::llvm::BasicBlock* block, ::llvm::Value** output) const { if (block == NULL || output == NULL) { LOG(WARNING) << "the output ptr or block is NULL "; return false; } ::llvm::IRBuilder<> builder(block); - ::llvm::Value* value = CreateAllocaAtHead(&builder, struct_type_, "struct_alloca"); + // value is a pointer to struct type + ::llvm::Value* value = CreateAllocaAtHead(&builder, struct_type_, GetLlvmObjectString(struct_type_)); *output = value; return true; } @@ -96,22 +104,10 @@ bool StructTypeIRBuilder::Set(::llvm::BasicBlock* block, ::llvm::Value* struct_v LOG(WARNING) << "Fail set Struct value: struct pointer is required"; return false; } - if (struct_value->getType()->getPointerElementType() != struct_type_) { - LOG(WARNING) << "Fail set Struct value: struct value type invalid " - << struct_value->getType() - ->getPointerElementType() - ->getStructName() - .str(); - return false; - } + ::llvm::IRBuilder<> builder(block); - builder.getInt64(1); - ::llvm::Value* value_ptr = - builder.CreateStructGEP(struct_type_, struct_value, idx); - if (nullptr == builder.CreateStore(value, value_ptr)) { - LOG(WARNING) << "Fail Set Struct Value idx = " << idx; - return false; - } + ::llvm::Value* value_ptr = builder.CreateStructGEP(struct_type_, struct_value, idx); + builder.CreateStore(value, value_ptr); return true; } @@ -137,5 +133,77 @@ bool StructTypeIRBuilder::Get(::llvm::BasicBlock* block, ::llvm::Value* struct_v *output = builder.CreateStructGEP(struct_type_, struct_value, idx); return true; } +absl::StatusOr StructTypeIRBuilder::Construct(CodeGenContext* ctx, + absl::Span args) const { + return absl::UnimplementedError(absl::StrCat("Construct for type ", GetLlvmObjectString(struct_type_))); +} + +absl::StatusOr<::llvm::Value*> StructTypeIRBuilder::ConstructFromRaw(CodeGenContext* ctx, + absl::Span<::llvm::Value* const> args) const { + EnsureOK(); + + llvm::Value* alloca = nullptr; + if (!Allocate(ctx->GetCurrentBlock(), &alloca)) { + return absl::FailedPreconditionError("failed to allocate array"); + } + + auto s = Set(ctx, alloca, args); + if (!s.ok()) { + return s; + } + + return alloca; +} + +absl::StatusOr StructTypeIRBuilder::ExtractElement(CodeGenContext* ctx, const NativeValue& arr, + const NativeValue& key) const { + return absl::UnimplementedError( + absl::StrCat("extract element unimplemented for ", GetLlvmObjectString(struct_type_))); +} + +void StructTypeIRBuilder::EnsureOK() const { + assert(struct_type_ != nullptr); + // it's a identified type + assert(!struct_type_->getName().empty()); +} +std::string StructTypeIRBuilder::GetTypeDebugString() const { return GetLlvmObjectString(struct_type_); } + +absl::Status StructTypeIRBuilder::Set(CodeGenContext* ctx, ::llvm::Value* struct_value, + absl::Span<::llvm::Value* const> members) const { + if (ctx == nullptr || struct_value == nullptr) { + return absl::InvalidArgumentError("ctx or struct pointer is null"); + } + + if (!IsStructPtr(struct_value->getType())) { + return absl::InvalidArgumentError( + absl::StrCat("value not a struct pointer: ", GetLlvmObjectString(struct_value->getType()))); + } + + if (struct_value->getType()->getPointerElementType() != struct_type_) { + return absl::InvalidArgumentError(absl::Substitute("input value has different type, expect $0 but got $1", + GetLlvmObjectString(struct_type_), + GetLlvmObjectString(struct_value->getType()))); + } + + if (members.size() != struct_type_->getNumElements()) { + return absl::InvalidArgumentError(absl::Substitute("struct $0 requires exact $1 member, but got $2", + GetLlvmObjectString(struct_type_), + struct_type_->getNumElements(), members.size())); + } + + for (unsigned idx = 0; idx < struct_type_->getNumElements(); ++idx) { + auto ele_type = struct_type_->getElementType(idx); + if (ele_type != members[idx]->getType()) { + return absl::InvalidArgumentError(absl::Substitute("$0th member: expect $1 but got $2", idx, + GetLlvmObjectString(ele_type), + GetLlvmObjectString(members[idx]->getType()))); + } + ::llvm::Value* value_ptr = ctx->GetBuilder()->CreateStructGEP(struct_type_, struct_value, idx); + ctx->GetBuilder()->CreateStore(members[idx], value_ptr); + } + + return absl::OkStatus(); +} + } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/struct_ir_builder.h b/hybridse/src/codegen/struct_ir_builder.h index e197665855b..f9b6ca30731 100644 --- a/hybridse/src/codegen/struct_ir_builder.h +++ b/hybridse/src/codegen/struct_ir_builder.h @@ -17,6 +17,8 @@ #ifndef HYBRIDSE_SRC_CODEGEN_STRUCT_IR_BUILDER_H_ #define HYBRIDSE_SRC_CODEGEN_STRUCT_IR_BUILDER_H_ +#include + #include "absl/status/statusor.h" #include "base/fe_status.h" #include "codegen/native_value.h" @@ -27,20 +29,46 @@ namespace codegen { class StructTypeIRBuilder : public TypeIRBuilder { public: + // TODO(ace): construct with CodeGenContext instead of llvm::Module explicit StructTypeIRBuilder(::llvm::Module*); ~StructTypeIRBuilder(); static StructTypeIRBuilder* CreateStructTypeIRBuilder(::llvm::Module*, ::llvm::Type*); static bool StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist); - virtual void InitStructType() = 0; virtual bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) = 0; virtual base::Status CastFrom(::llvm::BasicBlock* block, const NativeValue& src, NativeValue* output) = 0; - virtual bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) = 0; + // construct the default null safe struct absl::StatusOr CreateNull(::llvm::BasicBlock* block); - ::llvm::Type* GetType(); - bool Create(::llvm::BasicBlock* block, ::llvm::Value** output) const; + + virtual bool CreateDefault(::llvm::BasicBlock* block, ::llvm::Value** output) = 0; + + // Allocate and Initialize the struct value from args, each element in list represent exact argument in SQL literal. + // So for map data type, we create it in SQL with `map(key1, value1, ...)`, args is key or value for the result map + virtual absl::StatusOr Construct(CodeGenContext* ctx, absl::Span args) const; + + // construct struct value from llvm values, each element in list represent exact + // llvm struct field at that index + virtual absl::StatusOr<::llvm::Value*> ConstructFromRaw(CodeGenContext* ctx, + absl::Span<::llvm::Value* const> args) const; + + // Extract element value from composite data type + // 1. extract from array type by index + // 2. extract from struct type by field name + // 3. extract from map type by key + virtual absl::StatusOr ExtractElement(CodeGenContext* ctx, const NativeValue& arr, + const NativeValue& key) const; + + ::llvm::Type* GetType() const; + + std::string GetTypeDebugString() const; + + protected: + virtual void InitStructType() = 0; + + // allocate the given struct on current stack, no initialization + bool Allocate(::llvm::BasicBlock* block, ::llvm::Value** output) const; // Load the 'idx' th field into ''*output' // NOTE: not all types are loaded correctly, e.g for array type @@ -50,9 +78,13 @@ class StructTypeIRBuilder : public TypeIRBuilder { // Get the address of 'idx' th field bool Get(::llvm::BasicBlock* block, ::llvm::Value* struct_value, unsigned int idx, ::llvm::Value** output) const; + absl::Status Set(CodeGenContext* ctx, ::llvm::Value* struct_value, absl::Span<::llvm::Value* const> members) const; + + void EnsureOK() const; + protected: ::llvm::Module* m_; - ::llvm::Type* struct_type_; + ::llvm::StructType* struct_type_; }; } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/timestamp_ir_builder.cc b/hybridse/src/codegen/timestamp_ir_builder.cc index c3a8054e1cd..a07c29ee3de 100644 --- a/hybridse/src/codegen/timestamp_ir_builder.cc +++ b/hybridse/src/codegen/timestamp_ir_builder.cc @@ -267,7 +267,7 @@ bool TimestampIRBuilder::NewTimestamp(::llvm::BasicBlock* block, return false; } ::llvm::Value* timestamp; - if (!Create(block, ×tamp)) { + if (!Allocate(block, ×tamp)) { return false; } if (!SetTs(block, timestamp, @@ -286,7 +286,7 @@ bool TimestampIRBuilder::NewTimestamp(::llvm::BasicBlock* block, return false; } ::llvm::Value* timestamp; - if (!Create(block, ×tamp)) { + if (!Allocate(block, ×tamp)) { return false; } if (!SetTs(block, timestamp, ts)) { diff --git a/hybridse/src/codegen/type_ir_builder.cc b/hybridse/src/codegen/type_ir_builder.cc index 07adfb21855..0cba6015b9d 100644 --- a/hybridse/src/codegen/type_ir_builder.cc +++ b/hybridse/src/codegen/type_ir_builder.cc @@ -103,11 +103,7 @@ bool TypeIRBuilder::IsStringPtr(::llvm::Type* type) { } bool TypeIRBuilder::IsStructPtr(::llvm::Type* type) { - if (type->getTypeID() == ::llvm::Type::PointerTyID) { - type = reinterpret_cast<::llvm::PointerType*>(type)->getElementType(); - return type->isStructTy(); - } - return false; + return type->isPointerTy() && type->getPointerElementType()->isStructTy(); } base::Status TypeIRBuilder::UnaryOpTypeInfer( diff --git a/hybridse/src/codegen/udf_ir_builder.cc b/hybridse/src/codegen/udf_ir_builder.cc index 5030f3cd8ae..c9f613e5748 100644 --- a/hybridse/src/codegen/udf_ir_builder.cc +++ b/hybridse/src/codegen/udf_ir_builder.cc @@ -16,6 +16,8 @@ #include "codegen/udf_ir_builder.h" +#include +#include #include #include "codegen/context.h" @@ -172,7 +174,7 @@ Status UdfIRBuilder::BuildCodeGenUdfCall( } NativeValue gen_output; - CHECK_STATUS(gen_impl->gen(ctx_, args, &gen_output)); + CHECK_STATUS(gen_impl->gen(ctx_, args, {fn->GetReturnType(), fn->IsReturnNullable()}, &gen_output)); if (ret_null != nullptr) { if (gen_output.IsNullable()) { diff --git a/hybridse/src/node/expr_node.cc b/hybridse/src/node/expr_node.cc index 44acc336cef..8ad099a98b4 100644 --- a/hybridse/src/node/expr_node.cc +++ b/hybridse/src/node/expr_node.cc @@ -19,8 +19,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "codec/fe_row_codec.h" -#include "codegen/arithmetic_expr_ir_builder.h" -#include "codegen/type_ir_builder.h" #include "node/node_manager.h" #include "node/sql_node.h" #include "passes/expression/expr_pass.h" @@ -210,18 +208,26 @@ Status ExprNode::IsCastAccept(node::NodeManager* nm, const TypeNode* src, // this handles compatible type when both lhs and rhs are basic types // composited types like array, list, tuple are not handled correctly, so do not expect the function to handle those -// types absl::StatusOr ExprNode::CompatibleType(NodeManager* nm, const TypeNode* lhs, const TypeNode* rhs) { if (*lhs == *rhs) { // include Null = Null return rhs; } + + if (lhs->base() == kVoid && rhs->base() == kNull) { + return lhs; + } + + if (lhs->base() == kNull && rhs->base() == kVoid) { + return rhs; + } + if (lhs->IsNull()) { - // NULL + T -> T + // NULL/VOID + T -> T return rhs; } if (rhs->IsNull()) { - // T + NULL -> T + // T + NULL/VOID -> T return lhs; } @@ -845,21 +851,15 @@ Status ArrayExpr::InferAttr(ExprAnalysisContext* ctx) { return Status::OK(); } - // auto top_type = ctx->node_manager()->MakeTypeNode(kArray); TypeNode* top_type = nullptr; auto nm = ctx->node_manager(); - if (children_.empty()) { - FAIL_STATUS(kTypeError, "element type unknown for empty array expression"); - } else { - const TypeNode* ele_type = children_[0]->GetOutputType(); - for (size_t i = 1; i < children_.size() ; ++i) { - auto res = CompatibleType(ctx->node_manager(), ele_type, children_[i]->GetOutputType()); - CHECK_TRUE(res.ok(), kTypeError, res.status()); - ele_type = res.value(); - } - CHECK_TRUE(!ele_type->IsNull(), kTypeError, "unable to infer array type, all elements are null"); - top_type = nm->MakeArrayType(ele_type, children_.size()); + const TypeNode* ele_type = nm->MakeNode(); // void type + for (size_t i = 0; i < children_.size(); ++i) { + auto res = CompatibleType(ctx->node_manager(), ele_type, children_[i]->GetOutputType()); + CHECK_TRUE(res.ok(), kTypeError, res.status()); + ele_type = res.value(); } + top_type = nm->MakeArrayType(ele_type, children_.size()); SetOutputType(top_type); // array is nullable SetNullable(true); @@ -1142,5 +1142,50 @@ ExprNode* ExprNode::DeepCopy(NodeManager* nm) const { return root; } +ArrayElementExpr::ArrayElementExpr(ExprNode* array, ExprNode* pos) : ExprNode(kExprArrayElement) { + AddChild(array); + AddChild(pos); +} + +void ArrayElementExpr::Print(std::ostream& output, const std::string& org_tab) const { + // Print for ExprNode just talk too much, I don't intend impl that + // GetExprString is much simpler + output << org_tab << GetExprString(); +} + +const std::string ArrayElementExpr::GetExprString() const { + return absl::StrCat(array()->GetExprString(), "[", position()->GetExprString(), "]"); +} + +ArrayElementExpr* ArrayElementExpr::ShadowCopy(NodeManager* nm) const { + return nm->MakeNode(array(), position()); +} + +Status ArrayElementExpr::InferAttr(ExprAnalysisContext* ctx) { + auto* arr_type = array()->GetOutputType(); + auto* pos_type = position()->GetOutputType(); + + if (arr_type->IsMap()) { + auto map_type = arr_type->GetAsOrNull(); + CHECK_TRUE(node::ExprNode::IsSafeCast(pos_type, map_type->key_type()), common::kTypeError, + "incompatiable key type for ArrayElement, expect ", map_type->key_type()->DebugString(), ", got ", + pos_type->DebugString()); + + SetOutputType(map_type->value_type()); + SetNullable(map_type->value_nullable()); + } else if (arr_type->IsArray()) { + CHECK_TRUE(pos_type->IsInteger(), common::kTypeError, + "index type mismatch for ArrayElement, expect integer, got ", pos_type->DebugString()); + CHECK_TRUE(arr_type->GetGenericSize() == 1, common::kTypeError, "internal error: array of empty T"); + + SetOutputType(arr_type->GetGenericType(0)); + SetNullable(arr_type->IsGenericNullable(0)); + } else { + FAIL_STATUS(common::kTypeError, "can't get element from ", arr_type->DebugString(), ", expect map or array"); + } + return {}; +} +ExprNode *ArrayElementExpr::array() const { return GetChild(0); } +ExprNode *ArrayElementExpr::position() const { return GetChild(1); } } // namespace node } // namespace hybridse diff --git a/hybridse/src/node/node_manager.cc b/hybridse/src/node/node_manager.cc index 86d51249e19..62aa8ede65f 100644 --- a/hybridse/src/node/node_manager.cc +++ b/hybridse/src/node/node_manager.cc @@ -484,12 +484,6 @@ SqlNode *NodeManager::MakeColumnIndexNode(SqlNodeList *keys, SqlNode *ts, SqlNod return RegisterNode(node_ptr); } -SqlNode *NodeManager::MakeColumnDescNode(const std::string &column_name, const DataType data_type, bool op_not_null, - ExprNode *default_value) { - SqlNode *node_ptr = new ColumnDefNode(column_name, data_type, op_not_null, default_value); - return RegisterNode(node_ptr); -} - SqlNodeList *NodeManager::MakeNodeList() { SqlNodeList *new_list_ptr = new SqlNodeList(); RegisterNode(new_list_ptr); diff --git a/hybridse/src/node/sql_node.cc b/hybridse/src/node/sql_node.cc index 9114bad2d53..f5543d6e8b8 100644 --- a/hybridse/src/node/sql_node.cc +++ b/hybridse/src/node/sql_node.cc @@ -17,7 +17,6 @@ #include "node/sql_node.h" #include -#include #include #include #include @@ -142,6 +141,7 @@ static absl::flat_hash_map CreateExprTypeNamesMap() {kExprOrderExpression, "order"}, {kExprEscaped, "escape"}, {kExprArray, "array"}, + {kExprArrayElement, "array element"}, }; for (auto kind = 0; kind < ExprType::kExprLast; ++kind) { DCHECK(map.find(static_cast(kind)) != map.end()); @@ -1185,6 +1185,7 @@ static absl::flat_hash_map CreateSqlNodeTypeToNa {kDynamicUdafFnDef, "kDynamicUdafFnDef"}, {kWithClauseEntry, "kWithClauseEntry"}, {kAlterTableStmt, "kAlterTableStmt"}, + {kColumnSchema, "kColumnSchema"}, }; for (auto kind = 0; kind < SqlNodeType::kSqlNodeTypeLast; ++kind) { DCHECK(map.find(static_cast(kind)) != map.end()) @@ -1454,19 +1455,35 @@ void CreateTableLikeClause::Print(std::ostream &output, const std::string &tab) output << "\n"; } +std::string ColumnSchemaNode::DebugString() const { + auto res = DataTypeName(type()); + if (!generics().empty()) { + absl::StrAppend(&res, "<", + absl::StrJoin(generics(), ", ", + [](std::string *out, const ColumnSchemaNode *in) { + absl::StrAppend(out, in->DebugString()); + }), + ">"); + } + + if (not_null()) { + absl::StrAppend(&res, " NOT NULL"); + } + + if (default_value()) { + absl::StrAppend(&res, " DEFAULT ", default_value()->GetExprString()); + } + + return res; +} + void ColumnDefNode::Print(std::ostream &output, const std::string &org_tab) const { SqlNode::Print(output, org_tab); const std::string tab = org_tab + INDENT + SPACE_ED; output << "\n"; - PrintValue(output, tab, column_name_, "column_name", false); - output << "\n"; - PrintValue(output, tab, DataTypeName(column_type_), "column_type", false); + PrintValue(output, tab, GetColumnName(), "column_name", false); output << "\n"; - PrintValue(output, tab, std::to_string(op_not_null_), "NOT NULL", !default_value_); - if (default_value_) { - output << "\n"; - PrintSqlNode(output, tab, default_value_, "default_value", true); - } + PrintValue(output, tab, schema_->DebugString(), "column_type", true); } void ColumnIndexNode::SetTTL(ExprListNode *ttl_node_list) { @@ -1995,25 +2012,6 @@ void StructExpr::Print(std::ostream &output, const std::string &org_tab) const { PrintSqlNode(output, tab, methods_, "methods", true); } -void TypeNode::Print(std::ostream &output, const std::string &org_tab) const { - SqlNode::Print(output, org_tab); - const std::string tab = org_tab + INDENT + SPACE_ED; - - output << "\n"; - PrintValue(output, tab, GetName(), "type", true); -} -bool TypeNode::Equals(const SqlNode *node) const { - if (!SqlNode::Equals(node)) { - return false; - } - - const TypeNode *that = dynamic_cast(node); - return this->base_ == that->base_ && - std::equal( - this->generics_.cbegin(), this->generics_.cend(), that->generics_.cbegin(), - [&](const hybridse::node::TypeNode *a, const hybridse::node::TypeNode *b) { return TypeEquals(a, b); }); -} - void JoinNode::Print(std::ostream &output, const std::string &org_tab) const { TableRefNode::Print(output, org_tab); diff --git a/hybridse/src/node/sql_node_test.cc b/hybridse/src/node/sql_node_test.cc index e2938656dcc..98e850806ea 100644 --- a/hybridse/src/node/sql_node_test.cc +++ b/hybridse/src/node/sql_node_test.cc @@ -209,11 +209,11 @@ TEST_F(SqlNodeTest, MakeWindowDefNodetTest) { ExprListNode *partitions = node_manager_->MakeExprList(); ExprNode *ptr1 = node_manager_->MakeColumnRefNode("keycol", ""); - partitions->PushBack(ptr1); + partitions->AddChild(ptr1); ExprNode *ptr2 = node_manager_->MakeOrderExpression(node_manager_->MakeColumnRefNode("col1", ""), true); ExprListNode *orders = node_manager_->MakeExprList(); - orders->PushBack(ptr2); + orders->AddChild(ptr2); int64_t maxsize = 0; SqlNode *frame = @@ -286,28 +286,28 @@ TEST_F(SqlNodeTest, NewFrameNodeTest) { TEST_F(SqlNodeTest, MakeInsertNodeTest) { ExprListNode *column_expr_list = node_manager_->MakeExprList(); ExprNode *ptr1 = node_manager_->MakeColumnRefNode("col1", ""); - column_expr_list->PushBack(ptr1); + column_expr_list->AddChild(ptr1); ExprNode *ptr2 = node_manager_->MakeColumnRefNode("col2", ""); - column_expr_list->PushBack(ptr2); + column_expr_list->AddChild(ptr2); ExprNode *ptr3 = node_manager_->MakeColumnRefNode("col3", ""); - column_expr_list->PushBack(ptr3); + column_expr_list->AddChild(ptr3); ExprNode *ptr4 = node_manager_->MakeColumnRefNode("col4", ""); - column_expr_list->PushBack(ptr4); + column_expr_list->AddChild(ptr4); ExprListNode *value_expr_list = node_manager_->MakeExprList(); ExprNode *value1 = node_manager_->MakeConstNode(1); ExprNode *value2 = node_manager_->MakeConstNode(2.3f); ExprNode *value3 = node_manager_->MakeConstNode(2.3); ExprNode *value4 = node_manager_->MakeParameterExpr(1); - value_expr_list->PushBack(value1); - value_expr_list->PushBack(value2); - value_expr_list->PushBack(value3); - value_expr_list->PushBack(value4); + value_expr_list->AddChild(value1); + value_expr_list->AddChild(value2); + value_expr_list->AddChild(value3); + value_expr_list->AddChild(value4); ExprListNode *insert_values = node_manager_->MakeExprList(); - insert_values->PushBack(value_expr_list); + insert_values->AddChild(value_expr_list); SqlNode *node_ptr = node_manager_->MakeInsertTableNode("", "t1", column_expr_list, insert_values); ASSERT_EQ(kInsertStmt, node_ptr->GetType()); @@ -670,11 +670,17 @@ TEST_F(SqlNodeTest, CreateIndexNodeTest) { ColumnIndexNode *index_node = dynamic_cast(node_manager_->MakeColumnIndexNode(index_items)); CreatePlanNode *node = node_manager_->MakeCreateTablePlanNode( "", "t1", - {node_manager_->MakeColumnDescNode("col1", node::kInt32, true), - node_manager_->MakeColumnDescNode("col2", node::kInt32, true), - node_manager_->MakeColumnDescNode("col3", node::kFloat, true), - node_manager_->MakeColumnDescNode("col4", node::kVarchar, true), - node_manager_->MakeColumnDescNode("col5", node::kTimestamp, true), index_node}, + {node_manager_->MakeNode( + "col1", node_manager_->MakeNode(node::kInt32, true, nullptr)), + node_manager_->MakeNode( + "col2", node_manager_->MakeNode(node::kInt32, true, nullptr)), + node_manager_->MakeNode( + "col3", node_manager_->MakeNode(node::kFloat, true, nullptr)), + node_manager_->MakeNode( + "col4", node_manager_->MakeNode(node::kVarchar, true, nullptr)), + node_manager_->MakeNode( + "col5", node_manager_->MakeNode(node::kTimestamp, true, nullptr)), + index_node}, {node_manager_->MakeReplicaNumNode(3), node_manager_->MakePartitionNumNode(8), node_manager_->MakeNode(kMemory)}, false); diff --git a/hybridse/src/node/type_node.cc b/hybridse/src/node/type_node.cc index e0052fca74c..c3c1015ce8f 100644 --- a/hybridse/src/node/type_node.cc +++ b/hybridse/src/node/type_node.cc @@ -20,7 +20,6 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_cat.h" #include "node/node_manager.h" -#include "vm/physical_op.h" namespace hybridse { namespace node { @@ -52,7 +51,11 @@ bool TypeNode::IsTimestamp() const { return base_ == node::kTimestamp; } bool TypeNode::IsString() const { return base_ == node::kVarchar; } bool TypeNode::IsArithmetic() const { return IsInteger() || IsFloating(); } bool TypeNode::IsNumber() const { return IsInteger() || IsFloating(); } -bool TypeNode::IsNull() const { return base_ == node::kNull; } + +// Better function name ? Note the difference of VOID and NULL, VOID is a data type +// while NULL is a placeholder for missing or unknown information, not a real data type. +bool TypeNode::IsNull() const { return base_ == node::kNull || base_ == node::kVoid; } + bool TypeNode::IsBool() const { return base_ == node::kBool; } bool TypeNode::IsIntegral() const { @@ -137,5 +140,89 @@ FixedArrayType *FixedArrayType::ShadowCopy(NodeManager *nm) const { return nm->MakeArrayType(element_type(), num_elements_); } +void TypeNode::AddGeneric(const node::TypeNode *dtype, bool nullable) { + generics_.push_back(dtype); + generics_nullable_.push_back(nullable); +} +const hybridse::node::TypeNode *TypeNode::GetGenericType(size_t idx) const { return generics_[idx]; } +const std::string TypeNode::GetName() const { + std::string type_name = DataTypeName(base_); + if (!generics_.empty()) { + for (auto type : generics_) { + type_name.append("_"); + type_name.append(type->GetName()); + } + } + return type_name; +} + +void TypeNode::Print(std::ostream &output, const std::string &org_tab) const { + SqlNode::Print(output, org_tab); + const std::string tab = org_tab + INDENT + SPACE_ED; + + output << "\n"; + PrintValue(output, tab, GetName(), "type", true); +} +bool TypeNode::Equals(const SqlNode *node) const { + if (!SqlNode::Equals(node)) { + return false; + } + + const TypeNode *that = dynamic_cast(node); + return this->base_ == that->base_ && + std::equal( + this->generics_.cbegin(), this->generics_.cend(), that->generics_.cbegin(), + [&](const hybridse::node::TypeNode *a, const hybridse::node::TypeNode *b) { return TypeEquals(a, b); }); +} + +const std::string OpaqueTypeNode::GetName() const { return "opaque<" + std::to_string(bytes_) + ">"; } + +MapType::MapType(const TypeNode *key_ty, const TypeNode *value_ty, bool value_not_null) : TypeNode(node::kMap) { + // map key does not accept null, value is nullable unless extra attributes specified + AddGeneric(key_ty, false); + AddGeneric(value_ty, !value_not_null); +} +MapType::~MapType() {} +const TypeNode *MapType::key_type() const { return GetGenericType(0); } +const TypeNode *MapType::value_type() const { return GetGenericType(1); } +bool MapType::value_nullable() const { return IsGenericNullable(1); } + +// MAP +// 1. ALL KEYs or VALUEs must share a least common type. +// 2. KEY is simple type only: void/bool/numeric/data/timestamp/string +// 3. Resolve to MAP if arguments is empty +absl::StatusOr MapType::InferMapType(NodeManager* nm, absl::Span types) { + if (types.size() % 2 != 0) { + return absl::InvalidArgumentError("map expects a positive even number of arguments"); + } + + const node::TypeNode* key = nm->MakeNode(); // void type + const node::TypeNode* value = nm->MakeNode(); // void type + for (size_t i = 0; i < types.size(); i += 2) { + if (!types[i].type()->IsBaseOrNullType()) { + return absl::FailedPreconditionError( + absl::StrCat("key type for map should be void/bool/numeric/data/timestamp/string only, got ", + types[i].type()->DebugString())); + } + auto key_res = node::ExprNode::CompatibleType(nm, key, types[i].type()); + if (!key_res.ok()) { + return key_res.status(); + } + key = key_res.value(); + auto value_res = node::ExprNode::CompatibleType(nm, value, types[i + 1].type()); + if (!value_res.ok()) { + return value_res.status(); + } + value = value_res.value(); + } + + if (!types.empty() && (key->base() == kVoid || value->base() == kVoid)) { + // only empty map resolved to MAP + return absl::FailedPreconditionError("KEY/VALUE type of non-empty map can't be VOID"); + } + + return nm->MakeNode(key, value); +} + } // namespace node } // namespace hybridse diff --git a/hybridse/src/passes/lambdafy_projects.h b/hybridse/src/passes/lambdafy_projects.h index 3371cd12902..6afed956ee3 100644 --- a/hybridse/src/passes/lambdafy_projects.h +++ b/hybridse/src/passes/lambdafy_projects.h @@ -17,16 +17,12 @@ #ifndef HYBRIDSE_SRC_PASSES_LAMBDAFY_PROJECTS_H_ #define HYBRIDSE_SRC_PASSES_LAMBDAFY_PROJECTS_H_ -#include #include #include #include #include "node/expr_node.h" -#include "node/plan_node.h" #include "node/sql_node.h" -#include "udf/udf_library.h" -#include "vm/schemas_context.h" namespace hybridse { namespace passes { diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index 164dba11f2b..e05e639efb1 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -345,7 +345,8 @@ base::Status Planner::CreateSelectQueryPlan(const node::SelectQueryNode *root, n return base::Status::OK(); } -base::Status Planner::CreateSetOperationPlan(const node::SetOperationNode *root, node::SetOperationPlanNode **plan_tree) { +base::Status Planner::CreateSetOperationPlan(const node::SetOperationNode *root, + node::SetOperationPlanNode **plan_tree) { CHECK_TRUE(nullptr != root, common::kPlanError, "can not create query plan node with null query node") auto list = node_manager_->MakeList(); diff --git a/hybridse/src/plan/planner.h b/hybridse/src/plan/planner.h index 731663ab246..6da3068fdd8 100644 --- a/hybridse/src/plan/planner.h +++ b/hybridse/src/plan/planner.h @@ -49,6 +49,7 @@ class Planner { virtual ~Planner() {} virtual base::Status CreatePlanTree(const NodePointVector &parser_trees, PlanNodeList &plan_trees) = 0; // NOLINT (runtime/references) + static base::Status TransformTableDef(const std::string &table_name, const NodePointVector &column_desc_list, type::TableDef *table); bool MergeWindows(const std::map &map, @@ -132,11 +133,11 @@ class SimplePlanner : public Planner { bool enable_batch_window_parallelization = true, const std::unordered_map* extra_options = nullptr) : Planner(manager, is_batch_mode, is_cluster_optimized, enable_batch_window_parallelization, extra_options) {} - ~SimplePlanner() {} + ~SimplePlanner() override {} protected: base::Status CreatePlanTree(const NodePointVector &parser_trees, - PlanNodeList &plan_trees); // NOLINT + PlanNodeList &plan_trees) override; // NOLINT }; } // namespace plan diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index 5d9eb939113..163de8e53be 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -25,8 +25,10 @@ #include "absl/strings/match.h" #include "absl/types/span.h" #include "base/fe_status.h" +#include "node/sql_node.h" #include "udf/udf.h" #include "zetasql/parser/ast_node_kind.h" +#include "zetasql/parser/parse_tree_manual.h" namespace hybridse { namespace plan { @@ -57,6 +59,10 @@ static base::Status ConvertAlterTableStmt(const zetasql::ASTAlterTableStatement* node::SqlNode** out); static base::Status ConvertSetOperation(const zetasql::ASTSetOperation* stmt, node::NodeManager* nm, node::SetOperationNode** out); +static base::Status ConvertSchemaNode(const zetasql::ASTColumnSchema* stmt, node::NodeManager* nm, + node::ColumnSchemaNode** out); +static base::Status ConvertArrayElement(const zetasql::ASTArrayElement* expr, node::NodeManager* nm, + node::ArrayElementExpr** out); /// Used to convert zetasql ASTExpression Node into our ExprNode base::Status ConvertExprNode(const zetasql::ASTExpression* ast_expression, node::NodeManager* node_manager, @@ -107,6 +113,13 @@ base::Status ConvertExprNode(const zetasql::ASTExpression* ast_expression, node: } return base::Status::OK(); } + case zetasql::AST_ARRAY_ELEMENT: { + node::ArrayElementExpr* expr = nullptr; + CHECK_STATUS( + ConvertGuard(ast_expression, node_manager, &expr, ConvertArrayElement)); + *output = expr; + return base::Status::OK(); + } case zetasql::AST_CASE_VALUE_EXPRESSION: { auto* case_expression = ast_expression->GetAsOrDie(); auto& arguments = case_expression->arguments(); @@ -123,7 +136,7 @@ base::Status ConvertExprNode(const zetasql::ASTExpression* ast_expression, node: node::ExprNode* then_expr = nullptr; CHECK_STATUS(ConvertExprNode(arguments[i], node_manager, &when_expr)) CHECK_STATUS(ConvertExprNode(arguments[i + 1], node_manager, &then_expr)) - when_list_expr->PushBack(node_manager->MakeWhenNode(when_expr, then_expr)); + when_list_expr->AddChild(node_manager->MakeWhenNode(when_expr, then_expr)); i += 2; } else { CHECK_STATUS(ConvertExprNode(arguments[i], node_manager, &else_expr)) @@ -147,7 +160,7 @@ base::Status ConvertExprNode(const zetasql::ASTExpression* ast_expression, node: node::ExprNode* then_expr = nullptr; CHECK_STATUS(ConvertExprNode(arguments[i], node_manager, &when_expr)) CHECK_STATUS(ConvertExprNode(arguments[i + 1], node_manager, &then_expr)) - when_list_expr->PushBack(node_manager->MakeWhenNode(when_expr, then_expr)); + when_list_expr->AddChild(node_manager->MakeWhenNode(when_expr, then_expr)); i += 2; } else { CHECK_STATUS(ConvertExprNode(arguments[i], node_manager, &else_expr)) @@ -1475,9 +1488,7 @@ base::Status ConvertCreateProcedureNode(const zetasql::ASTCreateProcedureStateme } // case element -// ASTColumnDefinition -> case element.schema -// ASSTSimpleColumnSchema -> ColumnDeefNode -// otherwise -> not implemented +// ASTColumnDefinition -> ColumnDefNode // ASTIndexDefinition -> ColumnIndexNode // otherwise -> not implemented base::Status ConvertTableElement(const zetasql::ASTTableElement* element, node::NodeManager* node_manager, @@ -1489,38 +1500,10 @@ base::Status ConvertTableElement(const zetasql::ASTTableElement* element, node:: auto column_def = element->GetAsOrNull(); CHECK_TRUE(column_def != nullptr, common::kSqlAstError, "not an ASTColumnDefinition"); - auto not_null_columns = column_def->schema()->FindAttributes( - zetasql::AST_NOT_NULL_COLUMN_ATTRIBUTE); - bool not_null = !not_null_columns.empty(); - const std::string name = column_def->name()->GetAsString(); - - auto kind = column_def->schema()->node_kind(); - switch (kind) { - case zetasql::AST_SIMPLE_COLUMN_SCHEMA: { - // only simple column schema is supported - auto simple_column_schema = column_def->schema()->GetAsOrNull(); - CHECK_TRUE(simple_column_schema != nullptr, common::kSqlAstError, "not and ASTSimpleColumnSchema"); - - std::string type_name = ""; - CHECK_STATUS(AstPathExpressionToString(simple_column_schema->type_name(), &type_name)) - node::DataType type; - CHECK_STATUS(node::StringToDataType(type_name, &type)); - - node::ExprNode* default_value = nullptr; - if (simple_column_schema->default_expression()) { - CHECK_STATUS( - ConvertExprNode(simple_column_schema->default_expression(), node_manager, &default_value)); - } - - *node = node_manager->MakeColumnDescNode(name, type, not_null, default_value); - return base::Status::OK(); - } - default: { - return base::Status(common::kSqlAstError, absl::StrCat("unsupported column schema type: ", - zetasql::ASTNode::NodeKindToString(kind))); - } - } + node::ColumnSchemaNode* schema = nullptr; + CHECK_STATUS(ConvertSchemaNode(column_def->schema(), node_manager, &schema)); + *node = node_manager->MakeNode(name, schema); break; } case zetasql::AST_INDEX_DEFINITION: { @@ -1528,13 +1511,14 @@ base::Status ConvertTableElement(const zetasql::ASTTableElement* element, node:: node::ColumnIndexNode* index_node = nullptr; CHECK_STATUS(ConvertColumnIndexNode(ast_index_node, node_manager, &index_node)); *node = index_node; - return base::Status::OK(); + break; } default: { return base::Status(common::kSqlAstError, absl::StrCat("unsupported table column elemnt: ", element->GetNodeKindString())); } } + return base::Status::OK(); } // ASTIndexDefinition node @@ -1628,14 +1612,14 @@ base::Status ConvertIndexOption(const zetasql::ASTOptionsEntry* entry, node::Nod node::DataType unit; CHECK_STATUS(ASTIntervalLIteralToNum(entry->value(), &value, &unit)); auto node = node_manager->MakeConstNode(value, unit); - ttl_list->PushBack(node); + ttl_list->AddChild(node); break; } case zetasql::AST_INT_LITERAL: { int64_t value; CHECK_STATUS(ASTIntLiteralToNum(entry->value(), &value)); auto node = node_manager->MakeConstNode(value, node::kLatest); - ttl_list->PushBack(node); + ttl_list->AddChild(node); break; } case zetasql::AST_STRUCT_CONSTRUCTOR_WITH_PARENS: { @@ -1649,11 +1633,11 @@ base::Status ConvertIndexOption(const zetasql::ASTOptionsEntry* entry, node::Nod CHECK_STATUS(ASTIntervalLIteralToNum(struct_parens->field_expression(0), &value, &unit)); auto node = node_manager->MakeConstNode(value, unit); - ttl_list->PushBack(node); + ttl_list->AddChild(node); value = 0; CHECK_STATUS(ASTIntLiteralToNum(struct_parens->field_expression(1), &value)); - ttl_list->PushBack(node_manager->MakeConstNode(value, node::kLatest)); + ttl_list->AddChild(node_manager->MakeConstNode(value, node::kLatest)); break; } default: { @@ -1972,7 +1956,7 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod node::ExprListNode* column_list = node_manager->MakeExprList(); if (nullptr != root->column_list()) { for (auto column : root->column_list()->identifiers()) { - column_list->PushBack(node_manager->MakeColumnRefNode(column->GetAsString(), "")); + column_list->AddChild(node_manager->MakeColumnRefNode(column->GetAsString(), "")); } } @@ -2307,6 +2291,19 @@ base::Status ConvertASTType(const zetasql::ASTType* ast_type, node::NodeManager* }))); break; } + case zetasql::AST_MAP_TYPE: { + CHECK_STATUS((ConvertGuard( + ast_type, nm, output, + [](const zetasql::ASTMapType* map_tp, node::NodeManager* nm, node::TypeNode** out) -> base::Status { + node::TypeNode* key = nullptr; + node::TypeNode* value = nullptr; + CHECK_STATUS(ConvertASTType(map_tp->key_type(), nm, &key)); + CHECK_STATUS(ConvertASTType(map_tp->value_type(), nm, &value)); + *out = nm->MakeNode(key, value); + return base::Status::OK(); + }))); + break; + } default: { return base::Status(common::kSqlAstError, "Un-support type: " + ast_type->GetNodeKindString()); } @@ -2406,5 +2403,82 @@ base::Status ConvertSetOperation(const zetasql::ASTSetOperation* set_op, node::N } } +base::Status ConvertSchemaNode(const zetasql::ASTColumnSchema* stmt, node::NodeManager* nm, + node::ColumnSchemaNode** out) { + auto not_null_columns = + stmt->FindAttributes(zetasql::AST_NOT_NULL_COLUMN_ATTRIBUTE); + bool not_null = !not_null_columns.empty(); + + node::ExprNode* default_value = nullptr; + if (stmt->default_expression()) { + CHECK_STATUS(ConvertExprNode(stmt->default_expression(), nm, &default_value)); + } + + switch (stmt->node_kind()) { + case zetasql::AST_SIMPLE_COLUMN_SCHEMA: { + auto simple_column_schema = stmt->GetAsOrNull(); + CHECK_TRUE(simple_column_schema != nullptr, common::kSqlAstError, "not and ASTSimpleColumnSchema"); + + std::string type_name = ""; + CHECK_STATUS(AstPathExpressionToString(simple_column_schema->type_name(), &type_name)) + node::DataType type; + CHECK_STATUS(node::StringToDataType(type_name, &type)); + + *out = nm->MakeNode(type, not_null, default_value); + break; + } + case zetasql::AST_ARRAY_COLUMN_SCHEMA: { + CHECK_STATUS((ConvertGuard( + stmt, nm, out, + [not_null, default_value](const zetasql::ASTArrayColumnSchema* array_type, node::NodeManager* nm, + node::ColumnSchemaNode** out) -> base::Status { + node::ColumnSchemaNode* element_ty = nullptr; + CHECK_STATUS(ConvertSchemaNode(array_type->element_schema(), nm, &element_ty)); + + *out = nm->MakeNode( + node::DataType::kArray, std::initializer_list{element_ty}, + not_null, default_value); + return base::Status::OK(); + }))); + break; + } + case zetasql::AST_MAP_COLUMN_SCHEMA: { + CHECK_STATUS((ConvertGuard( + stmt, nm, out, + [not_null, default_value](const zetasql::ASTMapColumnSchema* map_type, node::NodeManager* nm, + node::ColumnSchemaNode** out) -> base::Status { + node::ColumnSchemaNode* key = nullptr; + CHECK_STATUS(ConvertSchemaNode(map_type->key_schema(), nm, &key)); + node::ColumnSchemaNode* value = nullptr; + CHECK_STATUS(ConvertSchemaNode(map_type->value_schema(), nm, &value)); + + *out = nm->MakeNode( + node::DataType::kMap, std::initializer_list{key, value}, + not_null, default_value); + return base::Status::OK(); + }))); + break; + } + default: { + return base::Status(common::kSqlAstError, + absl::StrCat("unsupported column schema type: ", stmt->GetNodeKindString())); + } + } + + return base::Status::OK(); +} + +base::Status ConvertArrayElement(const zetasql::ASTArrayElement* expr, node::NodeManager* nm, + node::ArrayElementExpr** out) { + node::ExprNode* array = nullptr; + node::ExprNode* pos = nullptr; + + CHECK_STATUS(ConvertExprNode(expr->array(), nm, &array)); + CHECK_STATUS(ConvertExprNode(expr->position(), nm, &pos)); + + *out = nm->MakeNode(array, pos); + return {}; +} + } // namespace plan } // namespace hybridse diff --git a/hybridse/src/planv2/ast_node_converter_test.cc b/hybridse/src/planv2/ast_node_converter_test.cc index 51447011f78..c89907817e3 100644 --- a/hybridse/src/planv2/ast_node_converter_test.cc +++ b/hybridse/src/planv2/ast_node_converter_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "absl/strings/match.h" #include "case/sql_case.h" @@ -945,20 +946,6 @@ TEST_F(ASTNodeConverterTest, ConvertCreateTableNodeErrorTest) { auto status = ConvertCreateTableNode(create_stmt, &node_manager, &output); EXPECT_EQ(common::kTypeError, status.code); } - { - // not supported schema - const std::string sql = "create table t (a Array) "; - - std::unique_ptr parser_output; - ZETASQL_ASSERT_OK(zetasql::ParseStatement(sql, zetasql::ParserOptions(), &parser_output)); - const auto* statement = parser_output->statement(); - ASSERT_TRUE(statement->Is()); - - const auto create_stmt = statement->GetAsOrDie(); - node::CreateStmt* output = nullptr; - auto status = ConvertCreateTableNode(create_stmt, &node_manager, &output); - EXPECT_EQ(common::kSqlAstError, status.code); - } { // not supported table element const std::string sql = "create table t (a int64, primary key (a)) "; diff --git a/hybridse/src/planv2/plan_api.cc b/hybridse/src/planv2/plan_api.cc index affe2ca80f0..d3f8f7644bf 100644 --- a/hybridse/src/planv2/plan_api.cc +++ b/hybridse/src/planv2/plan_api.cc @@ -16,13 +16,36 @@ #include "plan/plan_api.h" #include "planv2/planner_v2.h" +#include "zetasql/parser/parser.h" #include "zetasql/public/error_helpers.h" #include "zetasql/public/error_location.pb.h" namespace hybridse { namespace plan { -using hybridse::plan::SimplePlannerV2; +base::Status PlanAPI::CreatePlanTreeFromScript(vm::SqlContext *ctx) { + zetasql::ParserOptions parser_opts; + zetasql::LanguageOptions language_opts; + language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE); + parser_opts.set_language_options(&language_opts); + // save parse result into SqlContext so SQL engine can reference fields inside ASTNode during whole compile stage + auto zetasql_status = + zetasql::ParseScript(ctx->sql, parser_opts, zetasql::ERROR_MESSAGE_MULTI_LINE_WITH_CARET, &ctx->ast_node); + zetasql::ErrorLocation location; + if (!zetasql_status.ok()) { + zetasql::ErrorLocation location; + GetErrorLocation(zetasql_status, &location); + return {common::kSyntaxError, zetasql::FormatError(zetasql_status)}; + } + + DLOG(INFO) << "AST Node:\n" << ctx->ast_node->script()->DebugString(); + + const zetasql::ASTScript *script = ctx->ast_node->script(); + auto planner_ptr = + std::make_unique(&ctx->nm, ctx->engine_mode == vm::kBatchMode, ctx->is_cluster_optimized, + ctx->enable_batch_window_parallelization, ctx->options.get()); + return planner_ptr->CreateASTScriptPlan(script, ctx->logical_plan); +} bool PlanAPI::CreatePlanTreeFromScript(const std::string &sql, PlanNodeList &plan_trees, NodeManager *node_manager, Status &status, bool is_batch_mode, bool is_cluster, diff --git a/hybridse/src/planv2/planner_v2.h b/hybridse/src/planv2/planner_v2.h index 46627f10a90..2555ffd66e2 100644 --- a/hybridse/src/planv2/planner_v2.h +++ b/hybridse/src/planv2/planner_v2.h @@ -35,12 +35,12 @@ using node::PlanNodeList; class SimplePlannerV2 : public SimplePlanner { public: - explicit SimplePlannerV2(node::NodeManager *manager) : SimplePlanner(manager, true, false, false) {} SimplePlannerV2(node::NodeManager *manager, bool is_batch_mode, bool is_cluster_optimized = false, bool enable_batch_window_parallelization = false, const std::unordered_map *extra_options = nullptr) : SimplePlanner(manager, is_batch_mode, is_cluster_optimized, enable_batch_window_parallelization, extra_options) {} + base::Status CreateASTScriptPlan(const zetasql::ASTScript *script, PlanNodeList &plan_trees); // NOLINT (runtime/references) }; diff --git a/hybridse/src/sdk/hybridse_interface_core.i b/hybridse/src/sdk/hybridse_interface_core.i index 660f9bac7a1..9c053b69b71 100644 --- a/hybridse/src/sdk/hybridse_interface_core.i +++ b/hybridse/src/sdk/hybridse_interface_core.i @@ -118,6 +118,7 @@ SWIG_JAVABODY_PROXY(public, public, SWIGTYPE) #include "base/iterator.h" #include "vm/catalog.h" #include "vm/engine.h" +#include "vm/sql_ctx.h" #include "vm/engine_context.h" #include "vm/sql_compiler.h" #include "vm/jit_wrapper.h" @@ -140,6 +141,7 @@ using hybridse::vm::WindowOp; using hybridse::vm::EngineMode; using hybridse::vm::EngineOptions; using hybridse::vm::IndexHintHandler; +using hybridse::vm::SqlContext; using hybridse::base::Iterator; using hybridse::base::ConstIterator; using hybridse::base::Trace; diff --git a/hybridse/src/testing/engine_test_base.cc b/hybridse/src/testing/engine_test_base.cc index 7d02528b5ce..3aebea8f2de 100644 --- a/hybridse/src/testing/engine_test_base.cc +++ b/hybridse/src/testing/engine_test_base.cc @@ -409,7 +409,7 @@ Status EngineTestRunner::Compile() { DLOG(INFO) << "Physical plan:\n" << oss.str(); std::ostringstream runner_oss; - std::dynamic_pointer_cast(session_->GetCompileInfo())->GetClusterJob().Print(runner_oss, ""); + std::dynamic_pointer_cast(session_->GetCompileInfo())->GetClusterJob()->Print(runner_oss, ""); DLOG(INFO) << "Runner plan:\n" << runner_oss.str(); } return status; diff --git a/hybridse/src/udf/default_defs/map_defs.cc b/hybridse/src/udf/default_defs/map_defs.cc new file mode 100644 index 00000000000..c1cae3e554c --- /dev/null +++ b/hybridse/src/udf/default_defs/map_defs.cc @@ -0,0 +1,123 @@ +/** + * Copyright (c) 2023 4Paradigm Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "codegen/map_ir_builder.h" +#include "codegen/ir_base_builder.h" +#include "node/expr_node.h" +#include "node/type_node.h" +#include "udf/default_udf_library.h" +#include "udf/udf_registry.h" + +namespace hybridse { +namespace udf { + +void DefaultUdfLibrary::InitMapUdfs() { + RegisterCodeGenUdf("map") + .variadic_args<>( + // infer + [](UdfResolveContext* ctx, const std::vector& arg_attrs, + ExprAttrNode* out) -> base::Status { + auto ret = node::MapType::InferMapType(ctx->node_manager(), arg_attrs); + CHECK_TRUE(ret.ok(), common::kTypeError, ret.status().ToString()); + out->SetType(ret.value()); + out->SetNullable(true); + return {}; + }, + // gen + [](codegen::CodeGenContext* ctx, const std::vector& args, + const ExprAttrNode& return_info, codegen::NativeValue* out) -> base::Status { + CHECK_TRUE(return_info.type()->IsMap(), common::kTypeError, "not a map type output"); + auto* map_type = return_info.type()->GetAsOrNull(); + CHECK_TRUE(map_type != nullptr, common::kTypeError, "can not cast to MapType"); + + ::llvm::Type* key_type = nullptr; + ::llvm::Type* value_type = nullptr; + CHECK_TRUE(codegen::GetLlvmType(ctx->GetModule(), map_type->key_type(), &key_type), + common::kCodegenError); + CHECK_TRUE(codegen::GetLlvmType(ctx->GetModule(), map_type->value_type(), &value_type), + common::kCodegenError); + codegen::MapIRBuilder builder(ctx->GetModule(), key_type, value_type); + auto res = builder.Construct(ctx, args); + if (res.ok()) { + *out = res.value(); + return {}; + } + return {common::kCodegenError, res.status().ToString()}; + }) + .doc(R"( + @brief map(key1, value1, key2, value2, ...) - Creates a map with the given key/value pairs. + + Example: + + @code{.sql} + select map(1, '1', 2, '2'); + -- {1: "1", 2: "2"} + @endcode + + @since 0.9.0 + )"); + + RegisterCodeGenUdf("map_keys") + .args( + [](UdfResolveContext* ctx, const ExprAttrNode& in, ExprAttrNode* out) -> base::Status { + CHECK_TRUE(in.type()->IsMap(), common::kTypeError, "map_keys requires a map data type, got ", + in.type()->DebugString()); + + auto map_type = in.type()->GetAsOrNull(); + CHECK_TRUE(map_type != nullptr, common::kTypeError); + + out->SetType(ctx->node_manager()->MakeNode(node::kArray, map_type->key_type())); + out->SetNullable(true); + return {}; + }, + [](codegen::CodeGenContext* ctx, codegen::NativeValue in, const node::ExprAttrNode& return_info, + codegen::NativeValue* out) -> base::Status { + const node::TypeNode* type = nullptr; + CHECK_TRUE(codegen::GetFullType(ctx->node_manager(), in.GetType(), &type), common::kTypeError); + auto map_type = type->GetAsOrNull(); + CHECK_TRUE(map_type != nullptr, common::kTypeError); + + ::llvm::Type* key_type = nullptr; + ::llvm::Type* value_type = nullptr; + CHECK_TRUE(codegen::GetLlvmType(ctx->GetModule(), map_type->key_type(), &key_type), + common::kCodegenError); + CHECK_TRUE(codegen::GetLlvmType(ctx->GetModule(), map_type->value_type(), &value_type), + common::kCodegenError); + codegen::MapIRBuilder builder(ctx->GetModule(), key_type, value_type); + + auto res = builder.MapKeys(ctx, in); + if (res.ok()) { + *out = res.value(); + return {}; + } + return {common::kCodegenError, res.status().ToString()}; + }) + .doc(R"( + @brief map_keys(map) - Returns an unordered array containing the keys of the map. + + Example: + + @code{.sql} + select map_keys(map(1, '2', 3, '4')); + -- [1, 3] + @endcode + + @since 0.9.0 + )"); +} + +} // namespace udf +} // namespace hybridse diff --git a/hybridse/src/udf/default_udf_library.cc b/hybridse/src/udf/default_udf_library.cc index e6a546095ec..265a1e09250 100644 --- a/hybridse/src/udf/default_udf_library.cc +++ b/hybridse/src/udf/default_udf_library.cc @@ -665,6 +665,7 @@ void DefaultUdfLibrary::Init() { InitFeatureZero(); InitArrayUdfs(); + InitMapUdfs(); InitEarthDistanceUdf(); InitJsonUdfs(); @@ -794,7 +795,7 @@ void DefaultUdfLibrary::InitStringUdf() { RegisterCodeGenUdf("concat").variadic_args<>( /* infer */ [](UdfResolveContext* ctx, - const std::vector& arg_attrs, + const std::vector& arg_attrs, ExprAttrNode* out) { out->SetType(ctx->node_manager()->MakeTypeNode(node::kVarchar)); out->SetNullable(false); @@ -802,7 +803,7 @@ void DefaultUdfLibrary::InitStringUdf() { }, /* gen */ [](CodeGenContext* ctx, const std::vector& args, - NativeValue* out) { + const ExprAttrNode& return_info, NativeValue* out) { codegen::StringIRBuilder string_ir_builder(ctx->GetModule()); return string_ir_builder.Concat(ctx->GetCurrentBlock(), args, out); }) @@ -821,16 +822,16 @@ void DefaultUdfLibrary::InitStringUdf() { RegisterCodeGenUdf("concat_ws") .variadic_args( /* infer */ - [](UdfResolveContext* ctx, const ExprAttrNode* arg, - const std::vector& arg_types, + [](UdfResolveContext* ctx, const ExprAttrNode& arg, + const std::vector& arg_types, ExprAttrNode* out) { out->SetType(ctx->node_manager()->MakeTypeNode(node::kVarchar)); out->SetNullable(false); return Status::OK(); }, /* gen */ - [](CodeGenContext* ctx, NativeValue arg, - const std::vector& args, NativeValue* out) { + [](CodeGenContext* ctx, NativeValue arg, const std::vector& args, + const ExprAttrNode& return_info, NativeValue* out) { codegen::StringIRBuilder string_ir_builder(ctx->GetModule()); return string_ir_builder.ConcatWS(ctx->GetCurrentBlock(), arg, @@ -1651,7 +1652,7 @@ void DefaultUdfLibrary::InitMathUdf() { RegisterExprUdf("round") .variadic_args( - [](UdfResolveContext* ctx, ExprNode* x, const std::vector& other) -> ExprNode* { + [](UdfResolveContext* ctx, ExprNode* x, absl::Span other) -> ExprNode* { if (!x->GetOutputType()->IsArithmetic() || x->GetOutputType()->IsBool()) { ctx->SetError("round do not support first parameter of type " + x->GetOutputType()->GetName()); return nullptr; @@ -2233,18 +2234,15 @@ void DefaultUdfLibrary::InitTimeAndDateUdf() { )"); RegisterCodeGenUdf("year") - .args( - [](CodeGenContext* ctx, NativeValue date, NativeValue* out) { - codegen::DateIRBuilder date_ir_builder(ctx->GetModule()); - ::llvm::Value* ret = nullptr; - Status status; - CHECK_TRUE(date_ir_builder.Year(ctx->GetCurrentBlock(), - date.GetRaw(), &ret, status), - kCodegenError, - "Fail to build udf year(date): ", status.str()); - *out = NativeValue::Create(ret); - return status; - }) + .args([](CodeGenContext* ctx, NativeValue date, const node::ExprAttrNode& return_info, NativeValue* out) { + codegen::DateIRBuilder date_ir_builder(ctx->GetModule()); + ::llvm::Value* ret = nullptr; + Status status; + CHECK_TRUE(date_ir_builder.Year(ctx->GetCurrentBlock(), date.GetRaw(), &ret, status), kCodegenError, + "Fail to build udf year(date): ", status.str()); + *out = NativeValue::Create(ret); + return status; + }) .returns(); RegisterExternal("month") @@ -2264,7 +2262,7 @@ void DefaultUdfLibrary::InitTimeAndDateUdf() { RegisterCodeGenUdf("month") .args( - [](CodeGenContext* ctx, NativeValue date, NativeValue* out) { + [](CodeGenContext* ctx, NativeValue date, const node::ExprAttrNode& ri, NativeValue* out) { codegen::DateIRBuilder date_ir_builder(ctx->GetModule()); ::llvm::Value* ret = nullptr; Status status; @@ -2298,7 +2296,7 @@ void DefaultUdfLibrary::InitTimeAndDateUdf() { )"); RegisterCodeGenUdf("dayofmonth").args( - [](CodeGenContext* ctx, NativeValue date, NativeValue* out) { + [](CodeGenContext* ctx, NativeValue date, const node::ExprAttrNode& ri, NativeValue* out) { codegen::DateIRBuilder date_ir_builder(ctx->GetModule()); ::llvm::Value* ret = nullptr; Status status; @@ -2554,13 +2552,13 @@ void DefaultUdfLibrary::InitTimeAndDateUdf() { .variadic_args<>( /* infer */ [](UdfResolveContext* ctx, - const std::vector& args, + const std::vector& args, ExprAttrNode* out) { auto nm = ctx->node_manager(); auto tuple_type = nm->MakeTypeNode(node::kTuple); for (auto attr : args) { - tuple_type->generics_.push_back(attr->type()); - tuple_type->generics_nullable_.push_back(attr->nullable()); + tuple_type->generics_.push_back(attr.type()); + tuple_type->generics_nullable_.push_back(attr.nullable()); } out->SetType(tuple_type); out->SetNullable(false); @@ -2568,7 +2566,7 @@ void DefaultUdfLibrary::InitTimeAndDateUdf() { }, /* gen */ [](CodeGenContext* ctx, const std::vector& args, - NativeValue* out) { + const ExprAttrNode& return_info, NativeValue* out) { *out = NativeValue::CreateTuple(args); return Status::OK(); }); diff --git a/hybridse/src/udf/default_udf_library.h b/hybridse/src/udf/default_udf_library.h index be5ed6c2414..92152649fa0 100644 --- a/hybridse/src/udf/default_udf_library.h +++ b/hybridse/src/udf/default_udf_library.h @@ -52,6 +52,9 @@ class DefaultUdfLibrary : public UdfLibrary { // Array Udf defines, udfs either accept array as parameter or returns array void InitArrayUdfs(); + // Map functions + void InitMapUdfs(); + // aggregate functions for statistic void InitStatisticsUdafs(); diff --git a/hybridse/src/udf/dynamic_lib_manager.cc b/hybridse/src/udf/dynamic_lib_manager.cc index c6a034247cd..b3b281a0346 100644 --- a/hybridse/src/udf/dynamic_lib_manager.cc +++ b/hybridse/src/udf/dynamic_lib_manager.cc @@ -19,6 +19,8 @@ #include #include +#include "glog/logging.h" + namespace hybridse { namespace udf { diff --git a/hybridse/src/udf/literal_traits.h b/hybridse/src/udf/literal_traits.h index 13c876951e8..2c79c8a365d 100644 --- a/hybridse/src/udf/literal_traits.h +++ b/hybridse/src/udf/literal_traits.h @@ -18,15 +18,12 @@ #define HYBRIDSE_SRC_UDF_LITERAL_TRAITS_H_ #include -#include #include #include #include -#include #include #include -#include "base/fe_status.h" #include "base/string_ref.h" #include "base/type.h" #include "codec/fe_row_codec.h" @@ -139,8 +136,10 @@ static bool operator==(const Nullable& x, const Nullable& y) { // ===================================== // // ArrayRef // ===================================== // -template ::CCallArgType> +template struct ArrayRef { + using CType = typename DataTypeTrait::CCallArgType; + CType* raw; bool* nullables; uint64_t size; diff --git a/hybridse/src/udf/udf_registry.cc b/hybridse/src/udf/udf_registry.cc index 932174d8145..60e93460c24 100644 --- a/hybridse/src/udf/udf_registry.cc +++ b/hybridse/src/udf/udf_registry.cc @@ -206,20 +206,17 @@ Status ExprUdfRegistry::ResolveFunction(UdfResolveContext* ctx, Status LlvmUdfRegistry::ResolveFunction(UdfResolveContext* ctx, node::FnDefNode** result) { std::vector arg_types; - std::vector arg_attrs; + std::vector arg_attrs; for (size_t i = 0; i < ctx->arg_size(); ++i) { auto arg_type = ctx->arg_type(i); bool nullable = ctx->arg_nullable(i); CHECK_TRUE(arg_type != nullptr, kCodegenError, i, "th argument node type is unknown: ", name()); arg_types.push_back(arg_type); - arg_attrs.push_back(new ExprAttrNode(arg_type, nullable)); + arg_attrs.emplace_back(arg_type, nullable); } ExprAttrNode out_attr(nullptr, true); auto status = gen_impl_func_->infer(ctx, arg_attrs, &out_attr); - for (auto ptr : arg_attrs) { - delete const_cast(ptr); - } CHECK_STATUS(status, "Infer llvm output attr failed: ", status.str()); auto return_type = out_attr.type(); diff --git a/hybridse/src/udf/udf_registry.h b/hybridse/src/udf/udf_registry.h index 3ea96d25c13..d9512e581f0 100644 --- a/hybridse/src/udf/udf_registry.h +++ b/hybridse/src/udf/udf_registry.h @@ -28,13 +28,11 @@ #include #include "base/fe_status.h" -#include "codec/list_iterator_codec.h" #include "codegen/context.h" #include "node/node_manager.h" #include "node/sql_node.h" #include "udf/literal_traits.h" #include "udf/udf_library.h" -#include "vm/schemas_context.h" namespace hybridse { namespace udf { @@ -394,10 +392,11 @@ class LlvmUdfGenBase { public: virtual Status gen(codegen::CodeGenContext* ctx, const std::vector& args, + const ExprAttrNode& return_info, codegen::NativeValue* res) = 0; virtual Status infer(UdfResolveContext* ctx, - const std::vector& args, + const std::vector& args, ExprAttrNode*) = 0; node::TypeNode* fixed_ret_type() const { return fixed_ret_type_; } @@ -417,33 +416,36 @@ struct LlvmUdfGen : public LlvmUdfGenBase { using FType = std::function::second_type..., + const ExprAttrNode& return_info, codegen::NativeValue*)>; using InferFType = std::function::second_type..., + typename std::pair::second_type..., ExprAttrNode*)>; Status gen(codegen::CodeGenContext* ctx, const std::vector& args, + const ExprAttrNode& return_info, codegen::NativeValue* result) override { CHECK_TRUE(args.size() == sizeof...(Args), common::kCodegenError, "Fail to invoke LlvmUefGen::gen, args size do not " "match with template args)"); - return gen_internal(ctx, args, result, + return gen_internal(ctx, args, return_info, result, std::index_sequence_for()); } template Status gen_internal(codegen::CodeGenContext* ctx, const std::vector& args, + const ExprAttrNode& return_info, codegen::NativeValue* result, const std::index_sequence&) { - return gen_func(ctx, args[I]..., result); + return gen_func(ctx, args[I]..., return_info, result); } Status infer(UdfResolveContext* ctx, - const std::vector& args, + const std::vector& args, ExprAttrNode* out) override { return infer_internal(ctx, args, out, std::index_sequence_for()); @@ -451,7 +453,7 @@ struct LlvmUdfGen : public LlvmUdfGenBase { template Status infer_internal(UdfResolveContext* ctx, - const std::vector& args, + const std::vector& args, ExprAttrNode* out, const std::index_sequence&) { if (this->infer_func) { return infer_func(ctx, args[I]..., out); @@ -475,39 +477,39 @@ struct LlvmUdfGen : public LlvmUdfGenBase { template struct VariadicLLVMUdfGen : public LlvmUdfGenBase { using FType = std::function::second_type..., - const std::vector&, codegen::NativeValue*)>; + codegen::CodeGenContext*, typename std::pair::second_type..., + const std::vector&, const ExprAttrNode& return_info, codegen::NativeValue*)>; using InferFType = std::function::second_type..., - const std::vector&, ExprAttrNode*)>; + typename std::pair::second_type..., + const std::vector&, ExprAttrNode*)>; Status gen(codegen::CodeGenContext* ctx, const std::vector& args, + const ExprAttrNode& return_info, codegen::NativeValue* result) override { CHECK_TRUE(args.size() >= sizeof...(Args), common::kCodegenError, "Fail to invoke VariadicLLVMUdfGen::gen, " "args size do not match with template args)"); - return gen_internal(ctx, args, result, - std::index_sequence_for()); + return gen_internal(ctx, args, return_info, result, std::index_sequence_for()); }; template Status gen_internal(codegen::CodeGenContext* ctx, const std::vector& args, + const ExprAttrNode& return_info, codegen::NativeValue* result, const std::index_sequence&) { std::vector variadic_args; for (size_t i = sizeof...(I); i < args.size(); ++i) { variadic_args.emplace_back(args[i]); } - return this->gen_func(ctx, args[I]..., variadic_args, result); + return this->gen_func(ctx, args[I]..., variadic_args, return_info, result); } Status infer(UdfResolveContext* ctx, - const std::vector& args, + const std::vector& args, ExprAttrNode* out) override { return infer_internal(ctx, args, out, std::index_sequence_for()); @@ -515,9 +517,9 @@ struct VariadicLLVMUdfGen : public LlvmUdfGenBase { template Status infer_internal(UdfResolveContext* ctx, - const std::vector& args, + const std::vector& args, ExprAttrNode* out, const std::index_sequence&) { - std::vector variadic_args; + std::vector variadic_args; for (size_t i = sizeof...(I); i < args.size(); ++i) { variadic_args.emplace_back(args[i]); } @@ -723,9 +725,8 @@ class CodeGenUdfTemplateRegistryHelper { LlvmUdfRegistryHelper& helper) { // NOLINT helper.args( [](codegen::CodeGenContext* ctx, - typename std::pair< - Args, codegen::NativeValue>::second_type... args, - codegen::NativeValue* result) { + typename std::pair::second_type... args, + const ExprAttrNode& return_info, codegen::NativeValue* result) { return FTemplate()(ctx, args..., result); }); return helper.cur_def(); diff --git a/hybridse/src/udf/udf_registry_test.cc b/hybridse/src/udf/udf_registry_test.cc index 962b367819b..aac28fc8f17 100644 --- a/hybridse/src/udf/udf_registry_test.cc +++ b/hybridse/src/udf/udf_registry_test.cc @@ -384,14 +384,14 @@ TEST_F(UdfRegistryTest, test_codegen_udf_register) { library.RegisterCodeGenUdf("add").args( /* infer */ - [](UdfResolveContext* ctx, const ExprAttrNode* x, const ExprAttrNode* y, + [](UdfResolveContext* ctx, const ExprAttrNode& x, const ExprAttrNode& y, ExprAttrNode* out) { - out->SetType(x->type()); + out->SetType(x.type()); return Status::OK(); }, /* gen */ [](CodeGenContext* ctx, NativeValue x, NativeValue y, - NativeValue* out) { + const ExprAttrNode& ri, NativeValue* out) { *out = x; return Status::OK(); }); @@ -409,14 +409,14 @@ TEST_F(UdfRegistryTest, test_variadic_codegen_udf_register) { library.RegisterCodeGenUdf("concat").variadic_args<>( /* infer */ [](UdfResolveContext* ctx, - const std::vector& arg_attrs, + const std::vector& arg_attrs, ExprAttrNode* out) { - out->SetType(arg_attrs[0]->type()); + out->SetType(arg_attrs[0].type()); return Status::OK(); }, /* gen */ [](CodeGenContext* ctx, const std::vector& args, - NativeValue* out) { + const ExprAttrNode& return_info, NativeValue* out) { *out = args[0]; return Status::OK(); }); diff --git a/hybridse/src/vm/engine.cc b/hybridse/src/vm/engine.cc index c0d9be8c333..0865655f3c1 100644 --- a/hybridse/src/vm/engine.cc +++ b/hybridse/src/vm/engine.cc @@ -160,7 +160,7 @@ bool Engine::Get(const std::string& sql, const std::string& db, RunSession& sess sql_context.enable_expr_optimize = options_.IsEnableExprOptimize(); sql_context.jit_options = options_.jit_options(); sql_context.options = session.GetOptions(); - sql_context.index_hints_ = session.index_hints_; + sql_context.index_hints = session.index_hints_; if (session.engine_mode() == kBatchMode) { sql_context.parameter_types = dynamic_cast(&session)->GetParameterSchema(); } else if (session.engine_mode() == kBatchRequestMode) { @@ -191,7 +191,7 @@ bool Engine::Get(const std::string& sql, const std::string& db, RunSession& sess LOG(INFO) << "physical plan:\n" << plan_oss.str() << std::endl; } std::ostringstream runner_oss; - sql_context.cluster_job.Print(runner_oss, ""); + sql_context.cluster_job->Print(runner_oss, ""); LOG(INFO) << "cluster job:\n" << runner_oss.str() << std::endl; } return true; @@ -377,20 +377,20 @@ bool RunSession::SetCompileInfo(const std::shared_ptr& compile_info int32_t RequestRunSession::Run(const Row& in_row, Row* out_row) { DLOG(INFO) << "Request Row Run with main task"; - return Run(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job.main_task_id(), + return Run(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job->main_task_id(), in_row, out_row); } int32_t RequestRunSession::Run(const uint32_t task_id, const Row& in_row, Row* out_row) { auto task = std::dynamic_pointer_cast(compile_info_) ->get_sql_context() - .cluster_job.GetTask(task_id) + .cluster_job->GetTask(task_id) .GetRoot(); if (nullptr == task) { LOG(WARNING) << "fail to run request plan: taskid" << task_id << " not exist!"; return -2; } DLOG(INFO) << "Request Row Run with task_id " << task_id; - RunnerContext ctx(&std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, in_row, + RunnerContext ctx(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, in_row, sp_name_, is_debug_); auto output = task->RunWithCache(ctx); if (!output) { @@ -405,15 +405,15 @@ int32_t RequestRunSession::Run(const uint32_t task_id, const Row& in_row, Row* o } int32_t BatchRequestRunSession::Run(const std::vector& request_batch, std::vector& output) { - return Run(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job.main_task_id(), + return Run(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job->main_task_id(), request_batch, output); } int32_t BatchRequestRunSession::Run(const uint32_t id, const std::vector& request_batch, std::vector& output) { - RunnerContext ctx(&std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, + RunnerContext ctx(std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job, request_batch, sp_name_, is_debug_); auto task = - std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job.GetTask(id).GetRoot(); + std::dynamic_pointer_cast(compile_info_)->get_sql_context().cluster_job->GetTask(id).GetRoot(); if (nullptr == task) { LOG(WARNING) << "Fail to run request plan: taskid" << id << " not exist!"; return -2; @@ -435,8 +435,8 @@ int32_t BatchRunSession::Run(std::vector& rows, uint64_t limit) { } int32_t BatchRunSession::Run(const Row& parameter_row, std::vector& rows, uint64_t limit) { auto& sql_ctx = std::dynamic_pointer_cast(compile_info_)->get_sql_context(); - RunnerContext ctx(&sql_ctx.cluster_job, parameter_row, is_debug_); - auto output = sql_ctx.cluster_job.GetTask(0).GetRoot()->RunWithCache(ctx); + RunnerContext ctx(sql_ctx.cluster_job, parameter_row, is_debug_); + auto output = sql_ctx.cluster_job->GetTask(0).GetRoot()->RunWithCache(ctx); if (!output) { DLOG(INFO) << "Run batch plan output is empty"; return 0; diff --git a/hybridse/src/vm/runner_ctx.h b/hybridse/src/vm/runner_ctx.h index 0924015450a..350d2372a09 100644 --- a/hybridse/src/vm/runner_ctx.h +++ b/hybridse/src/vm/runner_ctx.h @@ -29,8 +29,7 @@ namespace vm { class RunnerContext { public: - explicit RunnerContext(hybridse::vm::ClusterJob* cluster_job, - const hybridse::codec::Row& parameter, + explicit RunnerContext(std::shared_ptr cluster_job, const hybridse::codec::Row& parameter, const bool is_debug = false) : cluster_job_(cluster_job), sp_name_(""), @@ -39,7 +38,7 @@ class RunnerContext { parameter_(parameter), is_debug_(is_debug), batch_cache_() {} - explicit RunnerContext(hybridse::vm::ClusterJob* cluster_job, + explicit RunnerContext(std::shared_ptr cluster_job, const hybridse::codec::Row& request, const std::string& sp_name = "", const bool is_debug = false) @@ -50,7 +49,7 @@ class RunnerContext { parameter_(), is_debug_(is_debug), batch_cache_() {} - explicit RunnerContext(hybridse::vm::ClusterJob* cluster_job, + explicit RunnerContext(std::shared_ptr cluster_job, const std::vector& request_batch, const std::string& sp_name = "", const bool is_debug = false) @@ -68,7 +67,7 @@ class RunnerContext { return requests_[idx]; } const hybridse::codec::Row& GetParameterRow() const { return parameter_; } - hybridse::vm::ClusterJob* cluster_job() { return cluster_job_; } + std::shared_ptr cluster_job() { return cluster_job_; } void SetRequest(const hybridse::codec::Row& request); void SetRequests(const std::vector& requests); bool is_debug() const { return is_debug_; } @@ -81,7 +80,7 @@ class RunnerContext { void SetBatchCache(int64_t id, std::shared_ptr data); private: - hybridse::vm::ClusterJob* cluster_job_; + std::shared_ptr cluster_job_; const std::string sp_name_; hybridse::codec::Row request_; std::vector requests_; diff --git a/hybridse/src/vm/runner_test.cc b/hybridse/src/vm/runner_test.cc index ea8d9c9643e..bce8c8712d3 100644 --- a/hybridse/src/vm/runner_test.cc +++ b/hybridse/src/vm/runner_test.cc @@ -75,13 +75,13 @@ void RunnerCheck(std::shared_ptr catalog, const std::string sql, ASSERT_TRUE(ok) << compile_status; ASSERT_TRUE(sql_compiler.BuildClusterJob(sql_context, compile_status)); ASSERT_TRUE(nullptr != sql_context.physical_plan); - ASSERT_TRUE(sql_context.cluster_job.IsValid()); + ASSERT_TRUE(sql_context.cluster_job->IsValid()); std::ostringstream oss; sql_context.physical_plan->Print(oss, ""); std::cout << "physical plan:\n" << sql << "\n" << oss.str() << std::endl; std::ostringstream runner_oss; - sql_context.cluster_job.Print(runner_oss, ""); + sql_context.cluster_job->Print(runner_oss, ""); std::cout << "runner: \n" << runner_oss.str() << std::endl; std::ostringstream oss_schema; @@ -349,7 +349,7 @@ TEST_F(RunnerTest, KeyGeneratorTest) { ASSERT_TRUE(sql_context.physical_plan != nullptr); auto root = GetFirstRunnerOfType( - sql_context.cluster_job.GetTask(0).GetRoot(), kRunnerGroup); + sql_context.cluster_job->GetTask(0).GetRoot(), kRunnerGroup); auto group_runner = dynamic_cast(root); std::vector rows; hybridse::type::TableDef temp_table; diff --git a/hybridse/src/vm/sql_compiler.cc b/hybridse/src/vm/sql_compiler.cc index c686e1401b4..ea5626545ee 100644 --- a/hybridse/src/vm/sql_compiler.cc +++ b/hybridse/src/vm/sql_compiler.cc @@ -159,7 +159,7 @@ Status SqlCompiler::BuildBatchModePhysicalPlan(SqlContext* ctx, const ::hybridse vm::BatchModeTransformer transformer(&ctx->nm, ctx->db, cl_, &ctx->parameter_types, llvm_module, library, ctx->is_cluster_optimized, ctx->enable_expr_optimize, ctx->enable_batch_window_parallelization, ctx->enable_window_column_pruning, - ctx->options.get(), ctx->index_hints_); + ctx->options.get(), ctx->index_hints); transformer.AddDefaultPasses(); CHECK_STATUS(transformer.TransformPhysicalPlan(plan_list, output), "Fail to generate physical plan batch mode"); ctx->schema = *(*output)->GetOutputSchema(); @@ -172,7 +172,7 @@ Status SqlCompiler::BuildRequestModePhysicalPlan(SqlContext* ctx, const ::hybrid PhysicalOpNode** output) { vm::RequestModeTransformer transformer(&ctx->nm, ctx->db, cl_, &ctx->parameter_types, llvm_module, library, {}, ctx->is_cluster_optimized, false, ctx->enable_expr_optimize, - enable_request_performance_sensitive, ctx->options.get(), ctx->index_hints_); + enable_request_performance_sensitive, ctx->options.get(), ctx->index_hints); if (ctx->options && ctx->options->count(LONG_WINDOWS)) { transformer.AddPass(passes::kPassSplitAggregationOptimized); transformer.AddPass(passes::kPassLongWindowOptimized); @@ -196,7 +196,7 @@ Status SqlCompiler::BuildBatchRequestModePhysicalPlan(SqlContext* ctx, const ::h vm::RequestModeTransformer transformer(&ctx->nm, ctx->db, cl_, &ctx->parameter_types, llvm_module, library, ctx->batch_request_info.common_column_indices, ctx->is_cluster_optimized, ctx->is_batch_request_optimized, ctx->enable_expr_optimize, true, - ctx->options.get(), ctx->index_hints_); + ctx->options.get(), ctx->index_hints); if (ctx->options && ctx->options->count(LONG_WINDOWS)) { transformer.AddPass(passes::kPassSplitAggregationOptimized); transformer.AddPass(passes::kPassLongWindowOptimized); @@ -297,7 +297,10 @@ bool SqlCompiler::BuildClusterJob(SqlContext& ctx, Status& status) { // NOLINT ctx.is_cluster_optimized && is_request_mode, ctx.batch_request_info.common_column_indices, ctx.batch_request_info.common_node_set); - ctx.cluster_job = runner_builder.BuildClusterJob(ctx.physical_plan, status); + if (ctx.cluster_job == nullptr) { + ctx.cluster_job = std::make_shared(); + } + *ctx.cluster_job = runner_builder.BuildClusterJob(ctx.physical_plan, status); return status.isOK(); } @@ -310,11 +313,8 @@ bool SqlCompiler::BuildClusterJob(SqlContext& ctx, Status& status) { // NOLINT */ bool SqlCompiler::Parse(SqlContext& ctx, ::hybridse::base::Status& status) { // NOLINT - bool is_batch_mode = ctx.engine_mode == kBatchMode; - if (!::hybridse::plan::PlanAPI::CreatePlanTreeFromScript(ctx.sql, ctx.logical_plan, &ctx.nm, status, is_batch_mode, - ctx.is_cluster_optimized, - ctx.enable_batch_window_parallelization, - ctx.options.get())) { + status = hybridse::plan::PlanAPI::CreatePlanTreeFromScript(&ctx); + if (!status.isOK()) { LOG(WARNING) << "Fail create sql plan: " << status; return false; } diff --git a/hybridse/src/vm/sql_compiler.h b/hybridse/src/vm/sql_compiler.h index a70f5275276..a874be405fa 100644 --- a/hybridse/src/vm/sql_compiler.h +++ b/hybridse/src/vm/sql_compiler.h @@ -19,7 +19,7 @@ #include #include -#include + #include "base/fe_status.h" #include "llvm/IR/Module.h" #include "udf/udf_library.h" @@ -30,60 +30,13 @@ #include "vm/physical_op.h" #include "vm/physical_plan_context.h" #include "vm/runner.h" +#include "vm/sql_ctx.h" namespace hybridse { namespace vm { using hybridse::base::Status; -struct SqlContext { - // mode: batch|request|batch request - ::hybridse::vm::EngineMode engine_mode; - bool is_cluster_optimized = false; - bool is_batch_request_optimized = false; - bool enable_expr_optimize = false; - bool enable_batch_window_parallelization = true; - bool enable_window_column_pruning = false; - - // the sql content - std::string sql; - // the database - std::string db; - // the logical plan - ::hybridse::node::PlanNodeList logical_plan; - ::hybridse::vm::PhysicalOpNode* physical_plan = nullptr; - hybridse::vm::ClusterJob cluster_job; - // TODO(wangtaize) add a light jit engine - // eg using bthead to compile ir - hybridse::vm::JitOptions jit_options; - std::shared_ptr jit = nullptr; - Schema schema; - Schema request_schema; - std::string request_db_name; - std::string request_name; - Schema parameter_types; - uint32_t row_size; - uint32_t limit_cnt = 0; - std::string ir; - std::string logical_plan_str; - std::string physical_plan_str; - std::string encoded_schema; - std::string encoded_request_schema; - ::hybridse::node::NodeManager nm; - ::hybridse::udf::UdfLibrary* udf_library = nullptr; - - ::hybridse::vm::BatchRequestInfo batch_request_info; - - std::shared_ptr> options; - - // [ALPHA] SQL diagnostic infos - // not standardized, only index hints, no error, no warning, no other hint/info - std::shared_ptr index_hints_; - - SqlContext() {} - ~SqlContext() {} -}; - class SqlCompileInfo : public CompileInfo { public: SqlCompileInfo() : sql_ctx() {} @@ -111,13 +64,13 @@ class SqlCompileInfo : public CompileInfo { const std::string& GetRequestDbName() const override { return sql_ctx.request_db_name; } const hybridse::vm::BatchRequestInfo& GetBatchRequestInfo() const override { return sql_ctx.batch_request_info; } const hybridse::vm::PhysicalOpNode* GetPhysicalPlan() const override { return sql_ctx.physical_plan; } - hybridse::vm::Runner* GetMainTask() { return sql_ctx.cluster_job.GetMainTask().GetRoot(); } - hybridse::vm::ClusterJob& GetClusterJob() { return sql_ctx.cluster_job; } + hybridse::vm::Runner* GetMainTask() { return sql_ctx.cluster_job->GetMainTask().GetRoot(); } + std::shared_ptr GetClusterJob() { return sql_ctx.cluster_job; } void DumpPhysicalPlan(std::ostream& output, const std::string& tab) override { sql_ctx.physical_plan->Print(output, tab); } void DumpClusterJob(std::ostream& output, const std::string& tab) override { - sql_ctx.cluster_job.Print(output, tab); + sql_ctx.cluster_job->Print(output, tab); } static SqlCompileInfo* CastFrom(CompileInfo* node) { return dynamic_cast(node); } diff --git a/hybridse/src/vm/sql_ctx.cc b/hybridse/src/vm/sql_ctx.cc new file mode 100644 index 00000000000..b328801978c --- /dev/null +++ b/hybridse/src/vm/sql_ctx.cc @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2023 OpenMLDB Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "vm/sql_ctx.h" + +// DONT DELETE: unique_ptr requires full specification for underlying type +#include "zetasql/parser/parser.h" // IWYU pragma: keep + +namespace hybridse { +namespace vm { +SqlContext::SqlContext() {} + +SqlContext::~SqlContext() {} + +} // namespace vm +} // namespace hybridse diff --git a/hybridse/src/vm/transform.cc b/hybridse/src/vm/transform.cc index 82e96b3c094..49a76d95273 100644 --- a/hybridse/src/vm/transform.cc +++ b/hybridse/src/vm/transform.cc @@ -26,6 +26,7 @@ #include "codegen/context.h" #include "codegen/fn_ir_builder.h" #include "codegen/fn_let_ir_builder.h" +#include "codegen/ir_base_builder.h" #include "passes/physical/batch_request_optimize.h" #include "passes/physical/cluster_optimized.h" #include "passes/physical/condition_optimized.h" @@ -39,9 +40,9 @@ #include "passes/physical/window_column_pruning.h" #include "plan/planner.h" #include "proto/fe_common.pb.h" +#include "vm/internal/node_helper.h" #include "vm/physical_op.h" #include "vm/schemas_context.h" -#include "vm/internal/node_helper.h" namespace hybridse { namespace vm { diff --git a/src/sdk/node_adapter.cc b/src/sdk/node_adapter.cc index ef9de07a774..2a7960741a8 100644 --- a/src/sdk/node_adapter.cc +++ b/src/sdk/node_adapter.cc @@ -330,7 +330,7 @@ bool NodeAdapter::TransformToTableDef(::hybridse::node::CreatePlanNode* create_n status->code = hybridse::common::kTypeError; return false; } - auto val = TransformDataType(*dynamic_cast(default_val), + auto val = TransformDataType(*dynamic_cast(default_val), add_column_desc->data_type()); if (!val) { status->msg = "default value type mismatch"; diff --git a/third-party/cmake/FetchZetasql.cmake b/third-party/cmake/FetchZetasql.cmake index b2b1d580593..bfe1e4c94a0 100644 --- a/third-party/cmake/FetchZetasql.cmake +++ b/third-party/cmake/FetchZetasql.cmake @@ -13,10 +13,10 @@ # limitations under the License. set(ZETASQL_HOME https://github.com/4paradigm/zetasql) -set(ZETASQL_VERSION 0.3.1) -set(ZETASQL_HASH_DARWIN 48bfdfe5fa91d414b0bf8383f116bc2a1f558c12fa286e49ea5ceede366dfbcf) -set(ZETASQL_HASH_LINUX_UBUNTU 3847ed7a60aeda1192adf7d702076d2db2bd49258992e2af67515a57b8f6f6a6) -set(ZETASQL_HASH_LINUX_CENTOS e73e6259ab2df3ae7289a9ae78600b69a8fbb6e4890d07a1031ccb1e37fa4281) +set(ZETASQL_VERSION 0.3.3) +set(ZETASQL_HASH_DARWIN f1c6a4f61b4a3f278dd46ace86f8b5e30780e596ef4af22f22cc12a4a7f83664) +set(ZETASQL_HASH_LINUX_UBUNTU bfe6ef8fd8221e5619dbb66b298ad767a4e1a1326b0c4ccfb75aa9ab872d1ce2) +set(ZETASQL_HASH_LINUX_CENTOS 8b63a149abf9d14fed9e63f465e74c2300d6de7404b859c48a94d4b579d080c2) set(ZETASQL_TAG v${ZETASQL_VERSION}) function(init_zetasql_urls)