Skip to content

Commit

Permalink
feat(map): cast null or map() to explicit map type
Browse files Browse the repository at this point in the history
Automatically cast to the target type when necessary, which usually happens with
an insert statement when a default schema is defined by the target table. This
supports constructing null or empty values from the offline engine.
  • Loading branch information
aceforeverd committed Apr 3, 2024
1 parent a348f68 commit 2991987
Show file tree
Hide file tree
Showing 26 changed files with 196 additions and 132 deletions.
15 changes: 15 additions & 0 deletions cases/query/udf_query.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ cases:
# ================================================================
# Map data type
# FIXME: request mode tests are disabled because TestRequestEngineForLastRow causes a segfault
# ================================================================
- id: 13
mode: request-unsupport
Expand Down Expand Up @@ -637,3 +638,17 @@ cases:
1, abc
2, null
- id: 19
mode: request-unsupport
desc: empty or null map
sql: |
select cast (null as map<int, string>)[0] as o1,
cast (null as map<string, int>) ["12"] as o2,
cast (map() as map<string, int64>) ["12"] as o3,
cast (map() as map<int, timestamp>) [7] as o4,
cast (map(7, "9") as map<int, string>) [7] as o5,
cast (map() as map<date, timestamp>) [date("2012-12-12")] as o6,
expect:
columns: ["o1 string", "o2 int", "o3 int64", "o4 timestamp", "o5 string", "o6 timestamp"]
data: |
NULL, NULL, NULL, NULL, 9, NULL
22 changes: 8 additions & 14 deletions hybridse/include/node/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,26 +410,20 @@ class NodeManager {
return node_ptr;
}

private:
void SetNodeUniqueId(ExprNode *node);
void SetNodeUniqueId(TypeNode *node);
void SetNodeUniqueId(PlanNode *node);
void SetNodeUniqueId(vm::PhysicalOpNode *node);
// Overwrites the node id counter with `i`, so that the next id handed out
// starts from `i`.
// `i` must be strictly greater than the current counter value; this is only
// enforced by assert(), i.e. only in builds where assertions are enabled.
void SetIdCounter(size_t i) {
assert(i > id_counter_);
id_counter_ = i;
}

private:
template <typename T>
void SetNodeUniqueId(T *node) {
node->SetNodeId(other_node_idx_counter_++);
node->SetNodeId(id_counter_++);
}

std::list<base::FeBaseObject *> node_list_;

// unique id counter for various types of node
size_t expr_idx_counter_ = 1;
size_t type_idx_counter_ = 1;
size_t plan_idx_counter_ = 1;
size_t physical_plan_idx_counter_ = 1;
size_t other_node_idx_counter_ = 1;
size_t exprid_idx_counter_ = 0;
size_t id_counter_ = 0;
size_t expr_id_counter_ = 0;
};

} // namespace node
Expand Down
18 changes: 12 additions & 6 deletions hybridse/include/node/sql_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -1520,22 +1520,28 @@ class FnDefNode : public SqlNode {

class CastExprNode : public ExprNode {
public:
explicit CastExprNode(const node::DataType cast_type, node::ExprNode *expr)
explicit CastExprNode(const node::TypeNode *cast_type, node::ExprNode *expr)
: ExprNode(kExprCast), cast_type_(cast_type) {
this->AddChild(expr);
}

~CastExprNode() {}
void Print(std::ostream &output, const std::string &org_tab) const;
const std::string GetExprString() const;
virtual bool Equals(const ExprNode *that) const;
void Print(std::ostream &output, const std::string &org_tab) const override;
const std::string GetExprString() const override;
bool Equals(const ExprNode *that) const override;
CastExprNode *ShadowCopy(NodeManager *) const override;
static CastExprNode *CastFrom(ExprNode *node);

ExprNode *expr() const { return GetChild(0); }
const DataType cast_type_;
const TypeNode *cast_type() const { return cast_type_; }

// Legacy interface, required by offline batch.
// Please prefer cast_type() wherever possible.
node::DataType base_cast_type() const;

Status InferAttr(ExprAnalysisContext *ctx) override;

private:
const TypeNode *cast_type_;
};

class WhenExprNode : public ExprNode {
Expand Down
27 changes: 25 additions & 2 deletions hybridse/src/codegen/cast_expr_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ bool CastExprIRBuilder::IsSafeCast(::llvm::Type* lhs, ::llvm::Type* rhs) {
}
Status CastExprIRBuilder::Cast(const NativeValue& value,
::llvm::Type* cast_type, NativeValue* output) {
CHECK_STATUS(TypeIRBuilder::BinaryOpTypeInfer(node::ExprNode::IsCastAccept,
value.GetType(), cast_type));
if (value.GetType() == cast_type) {
*output = value;
return {};
}
if (IsSafeCast(value.GetType(), cast_type)) {
CHECK_STATUS(SafeCast(value, cast_type, output));
} else {
Expand All @@ -81,6 +83,7 @@ Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_t
CHECK_TRUE(IsSafeCast(value.GetType(), dst_type), kCodegenError, "Safe cast fail: unsafe cast");
Status status;
if (value.IsConstNull()) {
// VOID type
auto res = CreateSafeNull(block_, dst_type);
CHECK_TRUE(res.ok(), kCodegenError, res.status().ToString());
*output = res.value();
Expand Down Expand Up @@ -114,6 +117,12 @@ Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_t

Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output) {
::llvm::IRBuilder<> builder(block_);
node::NodeManager nm;
const node::TypeNode* src_node = nullptr;
const node::TypeNode* dst_node = nullptr;
CHECK_TRUE(GetFullType(&nm, value.GetType(), &src_node), kCodegenError);
CHECK_TRUE(GetFullType(&nm, dst_type, &dst_node), kCodegenError);

if (value.IsConstNull() || (TypeIRBuilder::IsNumber(dst_type) && TypeIRBuilder::IsDatePtr(value.GetType()))) {
// input is const null or (cast date to number)
auto res = CreateSafeNull(block_, dst_type);
Expand All @@ -135,6 +144,20 @@ Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst
StringIRBuilder string_ir_builder(block_->getModule());
CHECK_STATUS(string_ir_builder.CastToNumber(block_, value, dst_type, output));
return Status::OK();
} else if (src_node->IsMap() && dst_node->IsMap()) {
auto src_map_node = src_node->GetAsOrNull<node::MapType>();
assert(src_map_node != nullptr && "logic error: map type empty");
if (src_map_node->GetGenericType(0)->IsNull() && src_map_node->GetGenericType(1)->IsNull()) {
auto s = StructTypeIRBuilder::CreateStructTypeIRBuilder(block_->getModule(), dst_type);
CHECK_TRUE(s.ok(), kCodegenError, s.status().ToString());
llvm::Value* val = nullptr;
CHECK_TRUE(s.value()->CreateDefault(block_, &val), kCodegenError);
*output = NativeValue::Create(val);
return Status::OK();
} else {
CHECK_TRUE(false, kCodegenError, "unimplemented: casting ", src_node->DebugString(), " to ",
dst_node->DebugString());
}
} else {
Status status;
::llvm::Value* output_value = nullptr;
Expand Down
12 changes: 4 additions & 8 deletions hybridse/src/codegen/expr_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Status ExprIRBuilder::BuildConstExpr(
::llvm::IRBuilder<> builder(ctx_->GetCurrentBlock());
switch (const_node->GetDataType()) {
case ::hybridse::node::kNull: {
*output = NativeValue(nullptr, nullptr, llvm::Type::getTokenTy(builder.getContext()));
*output = NativeValue(nullptr, nullptr, llvm::Type::getVoidTy(builder.getContext()));
break;
}
case ::hybridse::node::kBool: {
Expand Down Expand Up @@ -649,14 +649,10 @@ Status ExprIRBuilder::BuildCastExpr(const ::hybridse::node::CastExprNode* node,

CastExprIRBuilder cast_builder(ctx_->GetCurrentBlock());
::llvm::Type* cast_type = NULL;
CHECK_TRUE(GetLlvmType(ctx_->GetModule(), node->cast_type_, &cast_type),
kCodegenError, "Fail to cast expr: dist type invalid");
CHECK_TRUE(GetLlvmType(ctx_->GetModule(), node->cast_type(), &cast_type), kCodegenError,
"Fail to cast expr: dist type invalid");

if (cast_builder.IsSafeCast(left.GetType(), cast_type)) {
return cast_builder.SafeCast(left, cast_type, output);
} else {
return cast_builder.UnSafeCast(left, cast_type, output);
}
return cast_builder.Cast(left, cast_type, output);
}

Status ExprIRBuilder::BuildBinaryExpr(const ::hybridse::node::BinaryExpr* node,
Expand Down
27 changes: 26 additions & 1 deletion hybridse/src/codegen/insert_row_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "codegen/insert_row_builder.h"

#include <algorithm>
#include <map>
#include <string>
#include <utility>
Expand All @@ -28,14 +29,18 @@
#include "codegen/buf_ir_builder.h"
#include "codegen/context.h"
#include "codegen/expr_ir_builder.h"
#include "codegen/ir_base_builder.h"
#include "node/node_manager.h"
#include "node/sql_node.h"
#include "passes/resolve_fn_and_attrs.h"
#include "udf/default_udf_library.h"
#include "vm/jit_wrapper.h"

namespace hybridse {
namespace codegen {

static size_t MaxExprId(absl::Span<node::ExprNode* const>);

InsertRowBuilder::InsertRowBuilder(vm::HybridSeJitWrapper* jit, const codec::Schema* schema)
: schema_(schema), jit_(jit) {}

Expand Down Expand Up @@ -63,6 +68,9 @@ absl::StatusOr<int8_t*> InsertRowBuilder::ComputeRowUnsafe(absl::Span<node::Expr
llvm::make_unique<llvm::Module>(absl::StrCat("insert_row_builder_", fn_counter_.load()), *llvm_ctx);
vm::SchemasContext empty_sc;
node::NodeManager nm;
// WORKAROUND: set the id counter to one past the max id of all input expr nodes,
// so there will be no node id conflicts during codegen
nm.SetIdCounter(MaxExprId(values) + 1);
codec::Schema empty_param_types;
CodeGenContext dump_ctx(llvm_module.get(), &empty_sc, &empty_param_types, &nm);

Expand All @@ -71,9 +79,16 @@ absl::StatusOr<int8_t*> InsertRowBuilder::ComputeRowUnsafe(absl::Span<node::Expr
passes::ResolveFnAndAttrs resolver(&expr_ctx);

std::vector<node::ExprNode*> transformed;
for (auto& expr : values) {
for (size_t i = 0; i < values.size(); i++) {
auto expr = values[i];
node::ExprNode* out = nullptr;
CHECK_STATUS_TO_ABSL(resolver.VisitExpr(expr, &out));
auto tgt_type = ColumnSchema2Type(schema_->Get(i).schema(), &nm);
CHECK_ABSL_STATUSOR(tgt_type);
if (!tgt_type.value()->Equals(out->GetOutputType())) {
auto cast = nm.MakeNode<node::CastExprNode>(tgt_type.value(), out);
CHECK_STATUS_TO_ABSL(resolver.VisitExpr(cast, &out));
}
transformed.push_back(out);
}

Expand Down Expand Up @@ -140,5 +155,15 @@ absl::StatusOr<llvm::Function*> InsertRowBuilder::BuildFn(CodeGenContext* ctx, l
return fn;
}

// Returns the largest node id found among `exprs` and, recursively, among all
// of their descendants (walked through `children_`).
// Returns 0 when `exprs` is empty or contains only null entries.
size_t MaxExprId(absl::Span<node::ExprNode* const> exprs) {
    size_t max_id = 0;

    for (node::ExprNode* const expr : exprs) {
        if (expr == nullptr) {
            // A missing (null) entry or child carries no id; skip it rather
            // than dereferencing a null pointer.
            continue;
        }
        max_id = std::max(max_id, expr->node_id());
        max_id = std::max(max_id, MaxExprId(expr->children_));
    }

    return max_id;
}

} // namespace codegen
} // namespace hybridse
1 change: 0 additions & 1 deletion hybridse/src/codegen/insert_row_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class InsertRowBuilder {
absl::StatusOr<llvm::Function*> BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name,
absl::Span<node::ExprNode* const>);

// CodeGenContextBase* ctx_;
const codec::Schema* schema_;
vm::HybridSeJitWrapper* jit_;
std::atomic<uint32_t> fn_counter_ = 0;
Expand Down
14 changes: 13 additions & 1 deletion hybridse/src/codegen/insert_row_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace codegen {
class InsertRowBuilderTest : public ::testing::Test {};

TEST_F(InsertRowBuilderTest, encode) {
std::string sql = "insert into t1 values (1, map (1, '12'))";
std::string sql = "insert into t1 values (1, map (1, '12'), null, map())";
vm::SqlContext ctx;
ctx.sql = sql;
auto s = plan::PlanAPI::CreatePlanTreeFromScript(&ctx);
Expand All @@ -51,6 +51,18 @@ TEST_F(InsertRowBuilderTest, encode) {
map_ty->mutable_key_type()->set_base_type(type::kInt32);
map_ty->mutable_value_type()->set_base_type(type::kVarchar);
}
{
auto col = sc.Add();
auto map_ty = col->mutable_schema()->mutable_map_type();
map_ty->mutable_key_type()->set_base_type(type::kFloat);
map_ty->mutable_value_type()->set_base_type(type::kTimestamp);
}
{
auto col = sc.Add();
auto map_ty = col->mutable_schema()->mutable_map_type();
map_ty->mutable_key_type()->set_base_type(type::kDate);
map_ty->mutable_value_type()->set_base_type(type::kVarchar);
}

auto jit = std::shared_ptr<vm::HybridSeJitWrapper>(vm::HybridSeJitWrapper::Create());
ASSERT_TRUE(jit->Init());
Expand Down
4 changes: 2 additions & 2 deletions hybridse/src/codegen/ir_base_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,8 @@ bool GetBaseType(::llvm::Type* type, ::hybridse::node::DataType* output) {
return false;
}
switch (type->getTypeID()) {
case ::llvm::Type::TokenTyID: {
*output = ::hybridse::node::kNull;
case ::llvm::Type::VoidTyID: {
*output = ::hybridse::node::kVoid;
return true;
}
case ::llvm::Type::FloatTyID: {
Expand Down
6 changes: 2 additions & 4 deletions hybridse/src/codegen/native_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,8 @@ bool NativeValue::IsNullable() const { return IsConstNull() || HasFlag(); }

// NativeValue is null if:
// - raw_ is null
// - type_ is of token type.
// Currently there is no elsewhere using token type, so assert token type should be safe.
// token type represents SQL NULL may not appropriate, more work refer #926
bool NativeValue::IsConstNull() const { return raw_ == nullptr || (type_ != nullptr && type_->isTokenTy()); }
// - type_ is of void type.
// A value is a constant NULL when it carries no raw LLVM value, or when its
// type is set and that type is LLVM's void type.
bool NativeValue::IsConstNull() const {
    if (raw_ == nullptr) {
        return true;
    }
    return type_ != nullptr && type_->isVoidTy();
}

void NativeValue::SetName(const std::string& name) {
if (raw_ == nullptr) {
Expand Down
38 changes: 33 additions & 5 deletions hybridse/src/codegen/struct_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
#include "codegen/context.h"
#include "codegen/date_ir_builder.h"
#include "codegen/ir_base_builder.h"
#include "codegen/map_ir_builder.h"
#include "codegen/string_ir_builder.h"
#include "codegen/timestamp_ir_builder.h"
#include "node/node_manager.h"

namespace hybridse {
namespace codegen {
Expand All @@ -40,19 +42,34 @@ bool StructTypeIRBuilder::StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Valu

absl::StatusOr<std::unique_ptr<StructTypeIRBuilder>> StructTypeIRBuilder::CreateStructTypeIRBuilder(
::llvm::Module* m, ::llvm::Type* type) {
node::DataType base_type;
if (!GetBaseType(type, &base_type)) {
return absl::UnimplementedError(
absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type)));
node::NodeManager nm;
const node::TypeNode* ctype = nullptr;
if (!GetFullType(&nm, type, &ctype)) {
return absl::InvalidArgumentError(absl::StrCat("can't get full type for: ", GetLlvmObjectString(type)));
}

switch (base_type) {
switch (ctype->base()) {
case node::kTimestamp:
return std::make_unique<TimestampIRBuilder>(m);
case node::kDate:
return std::make_unique<DateIRBuilder>(m);
case node::kVarchar:
return std::make_unique<StringIRBuilder>(m);
case node::DataType::kMap: {
assert(ctype->IsMap() && "logic error: not a map type");
auto map_type = ctype->GetAsOrNull<node::MapType>();
assert(map_type != nullptr && "logic error: map type empty");
::llvm::Type* key_type = nullptr;
::llvm::Type* value_type = nullptr;
if (codegen::GetLlvmType(m, map_type->key_type(), &key_type) &&
codegen::GetLlvmType(m, map_type->value_type(), &value_type)) {
return std::make_unique<MapIRBuilder>(m, key_type, value_type);
} else {
return absl::InvalidArgumentError(
absl::Substitute("not able to casting map type: $0", GetLlvmObjectString(type)));
}
break;
}
default: {
break;
}
Expand Down Expand Up @@ -224,5 +241,16 @@ absl::StatusOr<std::vector<llvm::Value*>> StructTypeIRBuilder::Load(CodeGenConte

return res;
}

// Builds a NULL value for the given LLVM type.
//  - If `type` is not a struct pointer (per TypeIRBuilder::IsStructPtr), the
//    result is a NativeValue constructed from (nullptr, nullptr, type).
//  - Otherwise the matching StructTypeIRBuilder is created for `type` in the
//    block's module, checked with CHECK_ABSL_STATUSOR, and its CreateNull(block)
//    result is returned.
absl::StatusOr<NativeValue> CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type) {
    if (!TypeIRBuilder::IsStructPtr(type)) {
        return NativeValue(nullptr, nullptr, type);
    }

    auto struct_builder = StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), type);
    CHECK_ABSL_STATUSOR(struct_builder);
    return struct_builder.value()->CreateNull(block);
}

} // namespace codegen
} // namespace hybridse
7 changes: 7 additions & 0 deletions hybridse/src/codegen/struct_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class StructTypeIRBuilder : public TypeIRBuilder {
explicit StructTypeIRBuilder(::llvm::Module*);
~StructTypeIRBuilder();

// construct corresponding struct ir builder if exists for input type,
// otherwise, error status returned
static absl::StatusOr<std::unique_ptr<StructTypeIRBuilder>> CreateStructTypeIRBuilder(::llvm::Module*,
::llvm::Type*);
static bool StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist);
Expand Down Expand Up @@ -93,6 +95,11 @@ class StructTypeIRBuilder : public TypeIRBuilder {
::llvm::Module* m_;
::llvm::StructType* struct_type_;
};

// construct a safe null value for type
// returns NativeValue{raw, is_null=true} on success, raw is ensured to be not nullptr
absl::StatusOr<NativeValue> CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type);

} // namespace codegen
} // namespace hybridse
#endif // HYBRIDSE_SRC_CODEGEN_STRUCT_IR_BUILDER_H_
Loading

0 comments on commit 2991987

Please sign in to comment.