diff --git a/hybridse/include/base/fe_slice.h b/hybridse/include/base/fe_slice.h index 39bb6fa5b0c..2ae3a61bb5c 100644 --- a/hybridse/include/base/fe_slice.h +++ b/hybridse/include/base/fe_slice.h @@ -16,14 +16,15 @@ #ifndef HYBRIDSE_INCLUDE_BASE_FE_SLICE_H_ #define HYBRIDSE_INCLUDE_BASE_FE_SLICE_H_ + #include #include #include #include -#include + #include + #include "base/raw_buffer.h" -#include "boost/smart_ptr/local_shared_ptr.hpp" namespace hybridse { namespace base { diff --git a/hybridse/include/base/raw_buffer.h b/hybridse/include/base/raw_buffer.h index 294378560a9..66d36b16b48 100644 --- a/hybridse/include/base/raw_buffer.h +++ b/hybridse/include/base/raw_buffer.h @@ -16,10 +16,9 @@ #ifndef HYBRIDSE_INCLUDE_BASE_RAW_BUFFER_H_ #define HYBRIDSE_INCLUDE_BASE_RAW_BUFFER_H_ -#include + #include #include -#include #include "glog/logging.h" diff --git a/hybridse/include/codec/fe_row_codec.h b/hybridse/include/codec/fe_row_codec.h index 64a9a3be3fc..93ea3da2c78 100644 --- a/hybridse/include/codec/fe_row_codec.h +++ b/hybridse/include/codec/fe_row_codec.h @@ -18,15 +18,30 @@ #define HYBRIDSE_INCLUDE_CODEC_FE_ROW_CODEC_H_ #include +#include #include #include #include #include "absl/status/statusor.h" +#include "base/fe_status.h" +#include "base/fe_slice.h" #include "base/raw_buffer.h" +#include "codec/row.h" #include "proto/fe_type.pb.h" namespace hybridse { + +namespace node { +class ExprNode; +} +namespace codegen { +class InsertRowBuilder; +} +namespace vm { +class HybridSeJitWrapper; +} + namespace codec { const uint32_t BitMapSize(uint32_t size); @@ -74,6 +89,54 @@ inline uint32_t GetStartOffset(int32_t column_count) { void FillNullStringOffset(int8_t* buf, uint32_t start, uint32_t addr_length, uint32_t str_idx, uint32_t str_offset); + +// single slice builder from pure codegen +class SliceBuilder { + public: + SliceBuilder(vm::HybridSeJitWrapper*, const hybridse::codec::Schema* schema); + virtual ~SliceBuilder() {} + + base::Status Build(const std::vector&, base::RefCountedSlice*); + + base::Status Build(absl::Span, base::RefCountedSlice*); + + private: + void EnsureInitialized() { + assert(row_builder_ != nullptr && "must initialize the row builder before encoding"); + } + + const Schema* schema_; + std::shared_ptr row_builder_ = nullptr; +}; + +// new row builder from pure codegen +class RowBuilder2 { + public: + RowBuilder2(vm::HybridSeJitWrapper*, int sliceSize); + RowBuilder2(vm::HybridSeJitWrapper*, const std::vector& schemas); + RowBuilder2(vm::HybridSeJitWrapper*, const std::vector>& schemas); + ~RowBuilder2() {} + + base::Status Init(); + + base::Status InitSchema(int idx, const codec::Schema& sc); + + base::Status Build(const std::vector&, codec::Row*); + + private: + void EnsureInitialized() { + assert(initialized_ && "RowBuild not initialized"); + } + + vm::HybridSeJitWrapper* jit_ = nullptr; + std::vector schemas_; + std::vector> builders_; + + bool initialized_ = false; +}; + +// Old row builder in C +// limited data type support, no map, no array. U should upgrade to RowBuilder2 class RowBuilder { public: explicit RowBuilder(const hybridse::codec::Schema& schema); diff --git a/hybridse/include/codec/row.h b/hybridse/include/codec/row.h index 09ba4741090..ee421dbe2e8 100644 --- a/hybridse/include/codec/row.h +++ b/hybridse/include/codec/row.h @@ -18,14 +18,9 @@ #define HYBRIDSE_INCLUDE_CODEC_ROW_H_ #include -#include #include -#include -#include #include #include "base/fe_slice.h" -#include "base/raw_buffer.h" -#include "proto/fe_type.pb.h" namespace hybridse { namespace codec { diff --git a/hybridse/src/codec/fe_row_codec.cc b/hybridse/src/codec/fe_row_codec.cc index e67456e2cdb..e92c66bbd83 100644 --- a/hybridse/src/codec/fe_row_codec.cc +++ b/hybridse/src/codec/fe_row_codec.cc @@ -21,7 +21,10 @@ #include "codec/type_codec.h" #include "gflags/gflags.h" +#include "codegen/insert_row_builder.h" #include "glog/logging.h" +#include "proto/fe_common.pb.h" +#include "vm/engine.h" DECLARE_bool(enable_spark_unsaferow_format); @@ -1033,5 +1036,100 @@ absl::StatusOr SliceFormat::GetStringColumnInfo(size_t idx) const str_field_start_offset_); } +SliceBuilder::SliceBuilder(vm::HybridSeJitWrapper* jit, const hybridse::codec::Schema* schema) + : schema_(schema) { + row_builder_ = std::make_shared(jit, schema_); +} + +base::Status SliceBuilder::Build(const std::vector& values, base::RefCountedSlice* slice) { + return Build(absl::MakeSpan(values), slice); +} + +base::Status SliceBuilder::Build(absl::Span values, base::RefCountedSlice* slice) { + EnsureInitialized(); + + auto rs = row_builder_->ComputeRowUnsafe(values); + if (!rs.ok()) { + return {common::kCodegenEncodeError, rs.status().ToString()}; + } + + auto buf = rs.value(); + if (buf == nullptr) { + return {common::kCodegenEncodeError, "internal error: encoded buf is null"}; + } + + *slice = base::RefCountedSlice::CreateManaged(buf, RowView::GetSize(buf)); + + return {}; +} + +RowBuilder2::RowBuilder2(vm::HybridSeJitWrapper* jit, int sliceSize) : jit_(jit) { + schemas_.resize(sliceSize); + builders_.resize(sliceSize); +} +RowBuilder2::RowBuilder2(vm::HybridSeJitWrapper* jit, const std::vector& schemas) + : jit_(jit), schemas_(schemas) { + builders_.resize(schemas_.size()); +} +RowBuilder2::RowBuilder2(vm::HybridSeJitWrapper* jit, + const std::vector>& schemas) + : jit_(jit) { + for (auto& sc : schemas) { + schemas_.push_back(Schema()); + auto& ref = schemas_.back(); + for (auto& col : sc) { + ref.Add()->CopyFrom(col); + } + } + builders_.resize(schemas_.size()); +} + +base::Status RowBuilder2::Init() { + CHECK_TRUE(jit_ != nullptr, common::kCodegenEncodeError, "jit is null"); + for (size_t i = 0; i < schemas_.size(); ++i) { + CHECK_TRUE(!schemas_[i].empty(), common::kCodegenEncodeError, absl::StrCat(i, "th schema un-initialized")); + if (builders_[i] == nullptr) { + builders_[i] = std::make_shared(jit_, &schemas_[i]); + } + } + + initialized_ = true; + return {}; +} + +base::Status RowBuilder2::Build(const std::vector& values, codec::Row* out) { + EnsureInitialized(); + + auto expect_cols = + std::accumulate(schemas_.begin(), schemas_.end(), 0, [](int val, const auto& e) { return val + e.size(); }); + CHECK_TRUE(values.size() == expect_cols, common::kCodegenEncodeError, "pass in expr number do not match, expect ", + expect_cols, " but got ", values.size()); + + int col_idx = 0; + Row row; + auto values_ref = absl::MakeSpan(values); + for (size_t i = 0; i < schemas_.size(); ++i) { + RefCountedSlice slice; + CHECK_STATUS(builders_[i]->Build(values_ref.subspan(col_idx, schemas_[i].size()), &slice)); + if (i == 0) { + row.Reset(slice); + } else { + row.Append(slice); + } + + col_idx += schemas_[i].size(); + } + + *out = row; + + return {}; +} +base::Status RowBuilder2::InitSchema(int idx, const codec::Schema& sc) { + if (idx >= schemas_.size()) { + return {common::kCodegenEncodeError, "idx out of bound"}; + } + schemas_[idx] = sc; + return {}; +} } // namespace codec } // namespace hybridse diff --git a/hybridse/src/codegen/aggregate_ir_builder.cc b/hybridse/src/codegen/aggregate_ir_builder.cc index 156d1cf91eb..b927f373b2d 100644 --- a/hybridse/src/codegen/aggregate_ir_builder.cc +++ b/hybridse/src/codegen/aggregate_ir_builder.cc @@ -736,8 +736,7 @@ base::Status AggregateIRBuilder::BuildMulti(const std::string& base_funcname, std::vector> outputs; agg_generator.GenOutputs(&builder, &outputs); for (auto pair : outputs) { - output_encoder.BuildEncodePrimaryField(output_arg, pair.first, - pair.second); + CHECK_STATUS(output_encoder.BuildEncodePrimaryField(output_arg, pair.first, pair.second)); } } builder.CreateRetVoid(); diff --git a/hybridse/src/codegen/buf_ir_builder.h b/hybridse/src/codegen/buf_ir_builder.h index 0ec9e664baf..fbe91074c8b 100644 --- a/hybridse/src/codegen/buf_ir_builder.h +++ b/hybridse/src/codegen/buf_ir_builder.h @@ -36,11 +36,14 @@ class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { ~BufNativeEncoderIRBuilder() override; + ABSL_MUST_USE_RESULT base::Status Init() noexcept; // the output_ptr like int8_t** + ABSL_MUST_USE_RESULT base::Status BuildEncode(::llvm::Value* output_ptr) override; + ABSL_MUST_USE_RESULT base::Status BuildEncodePrimaryField(::llvm::Value* buf, size_t idx, const NativeValue& val); private: diff --git a/hybridse/src/codegen/buf_ir_builder_test.cc b/hybridse/src/codegen/buf_ir_builder_test.cc index 8f6c1149081..1dd6990076c 100644 --- a/hybridse/src/codegen/buf_ir_builder_test.cc +++ b/hybridse/src/codegen/buf_ir_builder_test.cc @@ -187,7 +187,6 @@ void RunEncode(::hybridse::type::TableDef& table, // NOLINT auto jit = std::unique_ptr( vm::HybridSeJitWrapper::Create()); jit->Init(); - vm::HybridSeJitWrapper::InitJitSymbols(jit.get()); ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx))); auto load_fn_jit = jit->FindFunction("fn"); void (*decode)(int8_t**) = @@ -307,7 +306,6 @@ void LoadValue(T* result, bool* is_null, auto jit = std::unique_ptr( vm::HybridSeJitWrapper::Create()); jit->Init(); - vm::HybridSeJitWrapper::InitJitSymbols(jit.get()); ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx))); auto load_fn_jit = jit->FindFunction("fn"); @@ -438,7 +436,6 @@ void RunColCase(T expected, type::TableDef& table, // NOLINT auto jit = std::unique_ptr( vm::HybridSeJitWrapper::Create()); jit->Init(); - vm::HybridSeJitWrapper::InitJitSymbols(jit.get()); ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx))); jit->AddExternalFunction("print_list_i16", reinterpret_cast(&PrintListInt16)); diff --git a/hybridse/src/codegen/fn_ir_builder_test.cc b/hybridse/src/codegen/fn_ir_builder_test.cc index 9a3da3bee97..ee04bd5ab93 100644 --- a/hybridse/src/codegen/fn_ir_builder_test.cc +++ b/hybridse/src/codegen/fn_ir_builder_test.cc @@ -82,7 +82,6 @@ void CheckResult(node::FnNodeFnDef *fn_def, R exp, V1 a, V2 b) { m->print(::llvm::errs(), NULL, true, true); auto jit = std::unique_ptr(vm::HybridSeJitWrapper::Create()); jit->Init(); - vm::HybridSeJitWrapper::InitJitSymbols(jit.get()); ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx))); auto test_fn = (R(*)(V1, V2))jit->FindFunction(fn_def->header_->GeIRFunctionName()); R result = test_fn(a, b); diff --git a/hybridse/src/codegen/fn_let_ir_builder_test.h b/hybridse/src/codegen/fn_let_ir_builder_test.h index c4a534fbe72..b98ce27a3dd 100644 --- a/hybridse/src/codegen/fn_let_ir_builder_test.h +++ b/hybridse/src/codegen/fn_let_ir_builder_test.h @@ -134,7 +134,6 @@ void CheckFnLetBuilderWithParameterRow(::hybridse::node::NodeManager* manager, v auto jit = std::unique_ptr( vm::HybridSeJitWrapper::Create()); jit->Init(); - vm::HybridSeJitWrapper::InitJitSymbols(jit.get()); ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx))); auto address = jit->FindFunction("test_at_fn"); diff --git a/hybridse/src/codegen/insert_row_builder.cc b/hybridse/src/codegen/insert_row_builder.cc index c52eec6a1d8..bea754485c6 100644 --- a/hybridse/src/codegen/insert_row_builder.cc +++ b/hybridse/src/codegen/insert_row_builder.cc @@ -21,7 +21,9 @@ #include #include +#include "absl/cleanup/cleanup.h" #include "absl/status/status.h" +#include "absl/strings/str_join.h" #include "base/fe_status.h" #include "codegen/buf_ir_builder.h" #include "codegen/context.h" @@ -29,39 +31,36 @@ #include "node/node_manager.h" #include "passes/resolve_fn_and_attrs.h" #include "udf/default_udf_library.h" -#include "vm/engine.h" #include "vm/jit_wrapper.h" namespace hybridse { namespace codegen { -InsertRowBuilder::InsertRowBuilder(const codec::Schema* schema) : schema_(schema) {} - -absl::Status InsertRowBuilder::Init() { - ::hybridse::vm::Engine::InitializeGlobalLLVM(); - - jit_ = std::unique_ptr(vm::HybridSeJitWrapper::Create()); - if (!jit_->Init()) { - jit_ = nullptr; - return absl::InternalError("fail to init jit"); - } - if (!vm::HybridSeJitWrapper::InitJitSymbols(jit_.get())) { - jit_ = nullptr; - return absl::InternalError("fail to init jit symbols"); - } - return absl::OkStatus(); -} +InsertRowBuilder::InsertRowBuilder(vm::HybridSeJitWrapper* jit, const codec::Schema* schema) + : schema_(schema), jit_(jit) {} absl::StatusOr> InsertRowBuilder::ComputeRow(const node::ExprListNode* values) { - EnsureInitialized(); return ComputeRow(values->children_); } absl::StatusOr> InsertRowBuilder::ComputeRow(absl::Span values) { - EnsureInitialized(); + auto rs = ComputeRowUnsafe(values); + CHECK_ABSL_STATUSOR(rs); + + auto managed_row = std::shared_ptr(rs.value(), std::free); + return managed_row; +} + +absl::StatusOr InsertRowBuilder::ComputeRowUnsafe(absl::Span values) { + absl::Cleanup clean = [&]() { fn_counter_++; }; + + DLOG(INFO) << absl::StrJoin(values, ", ", [](std::string* str, const node::ExprNode* const expr) { + absl::StrAppend(str, expr->GetExprString()); + }); std::unique_ptr llvm_ctx = llvm::make_unique(); - std::unique_ptr llvm_module = llvm::make_unique("insert_row_builder", *llvm_ctx); + std::unique_ptr llvm_module = + llvm::make_unique(absl::StrCat("insert_row_builder_", fn_counter_.load()), *llvm_ctx); vm::SchemasContext empty_sc; node::NodeManager nm; codec::Schema empty_param_types; @@ -78,7 +77,7 @@ absl::StatusOr> InsertRowBuilder::ComputeRow(absl::Span< transformed.push_back(out); } - std::string fn_name = absl::StrCat("gen_insert_row_", fn_counter_++); + std::string fn_name = absl::StrCat("gen_insert_row_", fn_counter_.load()); auto fs = BuildFn(&dump_ctx, fn_name, transformed); CHECK_ABSL_STATUSOR(fs); @@ -98,9 +97,7 @@ absl::StatusOr> InsertRowBuilder::ComputeRow(absl::Span< int8_t* insert_row = nullptr; encode(&insert_row); - auto managed_row = std::shared_ptr(insert_row, std::free); - - return managed_row; + return insert_row; } absl::StatusOr InsertRowBuilder::BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, @@ -135,7 +132,7 @@ absl::StatusOr InsertRowBuilder::BuildFn(CodeGenContext* ctx, l BufNativeEncoderIRBuilder encode_builder(ctx, &columns, schema_); CHECK_STATUS_TO_ABSL(encode_builder.Init()); - encode_builder.BuildEncode(row_ptr_ptr); + CHECK_STATUS_TO_ABSL(encode_builder.BuildEncode(row_ptr_ptr)); builder->CreateRetVoid(); } @@ -143,7 +140,5 @@ absl::StatusOr InsertRowBuilder::BuildFn(CodeGenContext* ctx, l return fn; } -// build the function that transform a single insert row values into encoded row -absl::StatusOr InsertRowBuilder::BuildEncodeFn() { return absl::OkStatus(); } } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/insert_row_builder.h b/hybridse/src/codegen/insert_row_builder.h index 83e8c1c2126..3391bbe9e57 100644 --- a/hybridse/src/codegen/insert_row_builder.h +++ b/hybridse/src/codegen/insert_row_builder.h @@ -29,11 +29,15 @@ namespace hybridse { namespace codegen { +// Row builder that can output encode row from ExprNodes: +// schema + [ExprNode] -> row buf +// +// Common usage: +// - by insert statement that encode insert values into row bufs +// - by any row builder who want to build the row buf manually class InsertRowBuilder { public: - explicit InsertRowBuilder(const codec::Schema* schema); - - absl::Status Init(); + InsertRowBuilder(vm::HybridSeJitWrapper*, const codec::Schema*) ABSL_ATTRIBUTE_NONNULL(); // compute the encoded row result for insert statement's single values expression list // @@ -43,9 +47,12 @@ class InsertRowBuilder { absl::StatusOr> ComputeRow(const node::ExprListNode* values); - private: - void EnsureInitialized() { assert(jit_ && "InsertRowBuilder not initialized"); } + // compute the encoded row result for insert statement's single values expression list + // + // returns a pointer to encoded buf, you must manually delete the pointer after use + absl::StatusOr ComputeRowUnsafe(absl::Span values); + private: // build the function the will output the row from single insert values // // the function is just equivalent to C: `void fn(int8_t**)`. @@ -53,14 +60,10 @@ class InsertRowBuilder { absl::StatusOr BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, absl::Span); - // build the function that transform a single insert row values into encoded row - absl::StatusOr BuildEncodeFn(); - // CodeGenContextBase* ctx_; const codec::Schema* schema_; + vm::HybridSeJitWrapper* jit_; std::atomic fn_counter_ = 0; - - std::unique_ptr jit_; }; } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/insert_row_builder_test.cc b/hybridse/src/codegen/insert_row_builder_test.cc index 4924c175957..900fc9b2e04 100644 --- a/hybridse/src/codegen/insert_row_builder_test.cc +++ b/hybridse/src/codegen/insert_row_builder_test.cc @@ -22,6 +22,7 @@ #include "node/sql_node.h" #include "plan/plan_api.h" #include "vm/sql_ctx.h" +#include "vm/engine.h" namespace hybridse { namespace codegen { @@ -51,11 +52,10 @@ TEST_F(InsertRowBuilderTest, encode) { map_ty->mutable_value_type()->set_base_type(type::kVarchar); } - InsertRowBuilder builder(&sc); - { - auto s = builder.Init(); - ASSERT_TRUE(s.ok()) << s; - } + auto jit = std::shared_ptr(vm::HybridSeJitWrapper::Create()); + ASSERT_TRUE(jit->Init()); + + InsertRowBuilder builder(jit.get(), &sc); auto as = builder.ComputeRow(dynamic_cast(exprlist)); ASSERT_TRUE(as.ok()) << as.status(); @@ -67,5 +67,6 @@ TEST_F(InsertRowBuilderTest, encode) { // int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); + ::hybridse::vm::Engine::InitializeGlobalLLVM(); return RUN_ALL_TESTS(); } diff --git a/hybridse/src/codegen/ir_base_builder_test.h b/hybridse/src/codegen/ir_base_builder_test.h index 494cfdb0818..5f4ed8d9d2a 100644 --- a/hybridse/src/codegen/ir_base_builder_test.h +++ b/hybridse/src/codegen/ir_base_builder_test.h @@ -197,14 +197,12 @@ class ModuleTestFunction { std::unique_ptr<::llvm::LLVMContext> llvm_ctx) { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); - jit = std::unique_ptr( - vm::HybridSeJitWrapper::Create()); - jit->Init(); - InitBuiltinJitSymbols(jit.get()); - if (library != nullptr) { - library->InitJITSymbols(jit.get()); - } else { - udf::DefaultUdfLibrary::get()->InitJITSymbols(jit.get()); + base::Status s; + jit = + std::unique_ptr(vm::HybridSeJitWrapper::CreateWithDefaultSymbols(library, &s, {})); + if (jit == nullptr || !s.isOK()) { + LOG(WARNING) << "create jit failed" << s; + return; } llvm::errs() << *(module.get()) << "\n"; diff --git a/hybridse/src/codegen/map_ir_builder.cc b/hybridse/src/codegen/map_ir_builder.cc index 27e6944c102..d617b76e419 100644 --- a/hybridse/src/codegen/map_ir_builder.cc +++ b/hybridse/src/codegen/map_ir_builder.cc @@ -81,6 +81,13 @@ absl::StatusOr MapIRBuilder::Construct(CodeGenContextBase* ctx, abs } auto builder = ctx->GetBuilder(); + if (args.empty()) { + if (!Set(ctx->GetCurrentBlock(), map_alloca, SZ_IDX, builder->getInt32(0))) { + return absl::InternalError("setting map size=0 for map failed"); + } + return NativeValue::Create(map_alloca); + } + auto* original_size = builder->getInt32(args.size() / 2); auto* key_vec = builder->CreateAlloca(key_type_, original_size, "key_vec"); auto* value_vec = builder->CreateAlloca(value_type_, original_size, "value_vec"); @@ -210,16 +217,19 @@ absl::StatusOr MapIRBuilder::ExtractElement(CodeGenContextBase* ctx auto builder = ctx->GetBuilder(); - auto s = ctx->CreateBranch( + builder->CreateStore(builder->getInt1(true), out_null_alloca_param); + + auto s = ctx->CreateBranchNot( builder->CreateOr(arr_is_null_param, key_is_null_param), - [&]() -> base::Status { - builder->CreateStore(builder->getInt1(true), out_null_alloca_param); - return {}; - }, [&]() -> base::Status { ::llvm::Value* sz = nullptr; CHECK_TRUE(Load(ctx->GetCurrentBlock(), map_ptr_param, SZ_IDX, &sz), common::kCodegenError); + CHECK_STATUS(ctx->CreateBranch(builder->CreateICmpSLE(sz, builder->getInt32(0)), [&]() -> base::Status { + builder->CreateRetVoid(); + return {}; + })); + ::llvm::Value* keys = nullptr; CHECK_TRUE(Load(ctx->GetCurrentBlock(), map_ptr_param, KEY_VEC_IDX, &keys), common::kCodegenError); @@ -394,6 +404,12 @@ absl::StatusOr MapIRBuilder::BuildEncodeByteSizeFn(CodeGenConte auto& value_vec = elements_vec[2]; auto& value_null_vec = elements_vec[3]; + CHECK_STATUS_TO_ABSL( + ctx->CreateBranch(builder->CreateICmpSLE(map_size, builder->getInt32(0)), [&]() -> base::Status { + builder->CreateRet(CodecSizeForPrimitive(builder, builder->getInt32Ty())); + return {}; + })); + auto keys_sz = CalEncodeSizeForArray(ctx, key_vec, map_size); CHECK_ABSL_STATUSOR(keys_sz); auto values_sz = CalEncodeSizeForArray(ctx, value_vec, map_size); @@ -712,61 +728,71 @@ absl::StatusOr MapIRBuilder::EncodeBaseValue(CodeGenContextBase* c absl::StatusOr MapIRBuilder::Decode(CodeGenContextBase* ctx, llvm::Value* row_ptr) const { auto* builder = ctx->GetBuilder(); - auto* map_sz = builder->CreateLoad(builder->getInt32Ty(), row_ptr); + auto* map_sz = builder->CreateLoad(builder->getInt32Ty(), + builder->CreatePointerCast(row_ptr, builder->getInt32Ty()->getPointerTo())); llvm::Value* map_alloca = nullptr; if (!Allocate(ctx->GetCurrentBlock(), &map_alloca)) { return absl::InternalError("fail to allocate map"); } + if (!Set(ctx->GetCurrentBlock(), map_alloca, SZ_IDX, builder->getInt32(0))) { + return absl::InternalError("fail to set default size for map"); + } - // only work if allocation happens in the top function, otherwise vectors will be cleared - llvm::Value* key_vec = builder->CreateAlloca(key_type_, map_sz, "map_key_vec"); - llvm::Value* value_vec = builder->CreateAlloca(value_type_, map_sz, "map_value_vec"); - llvm::Value* value_nulls_vec = builder->CreateAlloca(builder->getInt1Ty(), map_sz, "map_nulls_vec"); - - llvm::Value* idx0_alloca = builder->CreateAlloca(builder->getInt32Ty()); - builder->CreateStore(builder->getInt32(0), idx0_alloca); - // also allocate space for pointer type that points to - CHECK_STATUS_TO_ABSL(ctx->CreateWhile( - [&](llvm::Value** cond) -> base::Status { - *cond = ctx->GetBuilder()->CreateICmpSLT(builder->CreateLoad(idx0_alloca), map_sz); - return {}; - }, - [&]() -> base::Status { - llvm::Value* idx = builder->CreateLoad(idx0_alloca); - if (key_type_->isPointerTy()) { - auto* ele_val = builder->CreateAlloca(key_type_->getPointerElementType()); - builder->CreateStore(ele_val, builder->CreateGEP(key_type_, key_vec, idx)); - } - if (value_type_->isPointerTy()) { - auto* ele_val = builder->CreateAlloca(value_type_->getPointerElementType()); - builder->CreateStore(ele_val, builder->CreateGEP(value_type_, value_vec, idx)); - } - - builder->CreateStore(builder->CreateAdd(builder->getInt32(1), idx), idx0_alloca); - return {}; - })); - - auto s0 = BuildGetPtrOffset(builder, row_ptr, builder->getInt32(4), builder->getInt8Ty()->getPointerTo()); - CHECK_ABSL_STATUSOR(s0); - auto key_vec_res = DecodeArrayValue(ctx, s0.value(), map_sz, key_vec, key_type_); - CHECK_ABSL_STATUSOR(key_vec_res); + CHECK_STATUS_TO_ABSL(ctx->CreateBranch(builder->CreateICmpSGT(map_sz, builder->getInt32(0)), [&]() -> base::Status { + // only work if allocation happens in the top function, otherwise vectors will be cleared + llvm::Value* key_vec = builder->CreateAlloca(key_type_, map_sz, "map_key_vec"); + llvm::Value* value_vec = builder->CreateAlloca(value_type_, map_sz, "map_value_vec"); + llvm::Value* value_nulls_vec = builder->CreateAlloca(builder->getInt1Ty(), map_sz, "map_nulls_vec"); - auto s1 = BuildGetPtrOffset(builder, row_ptr, builder->CreateAdd(builder->getInt32(4), key_vec_res.value()), - builder->getInt8Ty()->getPointerTo()); - CHECK_ABSL_STATUSOR(s1); - auto value_vec_res = DecodeArrayValue(ctx, s1.value(), map_sz, value_vec, value_type_); - CHECK_ABSL_STATUSOR(value_vec_res); + llvm::Value* idx0_alloca = builder->CreateAlloca(builder->getInt32Ty()); + builder->CreateStore(builder->getInt32(0), idx0_alloca); + // also allocate space for pointer type that points to + CHECK_STATUS(ctx->CreateWhile( + [&](llvm::Value** cond) -> base::Status { + *cond = ctx->GetBuilder()->CreateICmpSLT(builder->CreateLoad(idx0_alloca), map_sz); + return {}; + }, + [&]() -> base::Status { + llvm::Value* idx = builder->CreateLoad(idx0_alloca); + if (key_type_->isPointerTy()) { + auto* ele_val = builder->CreateAlloca(key_type_->getPointerElementType()); + builder->CreateStore(ele_val, builder->CreateGEP(key_type_, key_vec, idx)); + } + if (value_type_->isPointerTy()) { + auto* ele_val = builder->CreateAlloca(value_type_->getPointerElementType()); + builder->CreateStore(ele_val, builder->CreateGEP(value_type_, value_vec, idx)); + } - auto s2 = BuildGetPtrOffset( - builder, row_ptr, - builder->CreateAdd(builder->getInt32(4), builder->CreateAdd(key_vec_res.value(), value_vec_res.value())), - builder->getInt8Ty()->getPointerTo()); - CHECK_ABSL_STATUSOR(s2); - auto value_null_vec_res = DecodeArrayValue(ctx, s2.value(), map_sz, value_nulls_vec, builder->getInt1Ty()); - CHECK_ABSL_STATUSOR(value_null_vec_res); + builder->CreateStore(builder->CreateAdd(builder->getInt32(1), idx), idx0_alloca); + return {}; + })); + + auto s0 = BuildGetPtrOffset(builder, row_ptr, builder->getInt32(4), builder->getInt8Ty()->getPointerTo()); + CHECK_TRUE(s0.ok(), common::kCodegenError, s0.status()); + auto key_vec_res = DecodeArrayValue(ctx, s0.value(), map_sz, key_vec, key_type_); + CHECK_TRUE(key_vec_res.ok(), common::kCodegenError, key_vec_res.status()); + + auto s1 = BuildGetPtrOffset(builder, row_ptr, builder->CreateAdd(builder->getInt32(4), key_vec_res.value()), + builder->getInt8Ty()->getPointerTo()); + CHECK_TRUE(s1.ok(), common::kCodegenError, s1.status()); + auto value_vec_res = DecodeArrayValue(ctx, s1.value(), map_sz, value_vec, value_type_); + CHECK_TRUE(value_vec_res.ok(), common::kCodegenError, value_vec_res.status()); + + auto s2 = BuildGetPtrOffset( + builder, row_ptr, + builder->CreateAdd(builder->getInt32(4), builder->CreateAdd(key_vec_res.value(), value_vec_res.value())), + builder->getInt8Ty()->getPointerTo()); + CHECK_TRUE(s2.ok(), common::kCodegenError, s2.status()); + auto value_null_vec_res = DecodeArrayValue(ctx, s2.value(), map_sz, value_nulls_vec, builder->getInt1Ty()); + CHECK_TRUE(value_null_vec_res.ok(), common::kCodegenError, value_null_vec_res.status()); - CHECK_ABSL_STATUS(Set(ctx, map_alloca, {map_sz, key_vec, value_vec, value_nulls_vec})); + { + auto s = Set(ctx, map_alloca, {map_sz, key_vec, value_vec, value_nulls_vec}); + CHECK_TRUE(s.ok(), common::kCodegenError, s); + } + return {}; + })); return map_alloca; } @@ -837,7 +863,8 @@ absl::StatusOr MapIRBuilder::DecodeBaseValue(CodeGenContextBase* c llvm::Value* base_ptr, llvm::Type* type) const { auto builder = ctx->GetBuilder(); if (type->isIntegerTy() || type->isFloatTy() || type->isDoubleTy()) { - builder->CreateStore(builder->CreateLoad(type, ptr), base_ptr); + builder->CreateStore(builder->CreateLoad(type, builder->CreatePointerCast(ptr, type->getPointerTo())), + base_ptr); return CodecSizeForPrimitive(builder, type); } // struct pointer diff --git a/hybridse/src/sdk/hybridse_interface_core.i b/hybridse/src/sdk/hybridse_interface_core.i index 9c053b69b71..71cd2aab39a 100644 --- a/hybridse/src/sdk/hybridse_interface_core.i +++ b/hybridse/src/sdk/hybridse_interface_core.i @@ -196,6 +196,7 @@ using hybridse::node::DataType; %include "node/node_enum.h" %include "node/plan_node.h" %include "node/sql_node.h" +%include "node/node_manager.h" %include "vm/catalog.h" %include "vm/simple_catalog.h" %include "vm/schemas_context.h" @@ -208,3 +209,4 @@ using hybridse::node::DataType; %include "vm/mem_catalog.h" %template(VectorDataType) std::vector; +%template(ExprNodeVector) std::vector; diff --git a/hybridse/src/udf/default_udf_library.cc b/hybridse/src/udf/default_udf_library.cc index 265a1e09250..ea568f3f9f3 100644 --- a/hybridse/src/udf/default_udf_library.cc +++ b/hybridse/src/udf/default_udf_library.cc @@ -17,23 +17,24 @@ #include "udf/default_udf_library.h" #include +#include #include +#include #include #include #include #include #include -#include -#include +#include "absl/cleanup/cleanup.h" #include "codegen/date_ir_builder.h" #include "codegen/string_ir_builder.h" #include "codegen/timestamp_ir_builder.h" #include "udf/containers.h" +#include "udf/default_defs/date_and_time_def.h" +#include "udf/default_defs/expr_def.h" #include "udf/udf.h" #include "udf/udf_registry.h" -#include "udf/default_defs/expr_def.h" -#include "udf/default_defs/date_and_time_def.h" using openmldb::base::Date; using openmldb::base::StringRef; @@ -46,7 +47,8 @@ namespace hybridse { namespace udf { DefaultUdfLibrary* DefaultUdfLibrary::MakeDefaultUdf() { - LOG(INFO) << "Creating DefaultUdfLibrary"; + absl::Time begin = absl::Now(); + absl::Cleanup clean = [&]() { LOG(INFO) << "Created DefaultUdfLibrary in " << absl::Now() - begin; }; return new DefaultUdfLibrary(); } diff --git a/hybridse/src/vm/jit.cc b/hybridse/src/vm/jit.cc index f714c66db8a..9053a76431f 100644 --- a/hybridse/src/vm/jit.cc +++ b/hybridse/src/vm/jit.cc @@ -15,23 +15,23 @@ */ #include "vm/jit.h" + #include #include extern "C" { #include #include } + +#include "absl/cleanup/cleanup.h" +#include "absl/time/clock.h" #include "glog/logging.h" -#include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" @@ -149,7 +149,13 @@ bool HybridSeJit::AddSymbol(::llvm::orc::JITDylib& jd, } bool HybridSeLlvmJitWrapper::Init() { - DLOG(INFO) << "Start to initialize hybridse jit"; + absl::Time begin = absl::Now(); + absl::Cleanup clean = [&]() { DLOG(INFO) << "LLVM JIT initialize takes " << absl::Now() - begin; }; + + if (initialized_) { + return true; + } + auto jit = ::llvm::Expected>( HybridSeJitBuilder().create()); { @@ -166,16 +172,27 @@ bool HybridSeLlvmJitWrapper::Init() { this->mi_ = std::unique_ptr<::llvm::orc::MangleAndInterner>( new ::llvm::orc::MangleAndInterner(jit_->getExecutionSession(), jit_->getDataLayout())); + + auto s = InitJitSymbols(); + if (!s.isOK()) { + LOG(WARNING) << s; + return false; + } + + initialized_ = true; return true; } bool HybridSeLlvmJitWrapper::OptModule(::llvm::Module* module) { + EnsureInitialized(); return jit_->OptModule(module); } bool HybridSeLlvmJitWrapper::AddModule( std::unique_ptr module, std::unique_ptr llvm_ctx) { + EnsureInitialized(); + ::llvm::Error e = jit_->addIRModule( ::llvm::orc::ThreadSafeModule(std::move(module), std::move(llvm_ctx))); if (e) { @@ -206,9 +223,18 @@ bool HybridSeLlvmJitWrapper::AddExternalFunction(const std::string& name, } #ifdef LLVM_EXT_ENABLE -bool HybridSeMcJitWrapper::Init() { return true; } +bool HybridSeMcJitWrapper::Init() { + auto s = InitJitSymbols(); + if (!s.isOK()) { + LOG(WARNING) << s; + return false; + } + return true; +} bool HybridSeMcJitWrapper::OptModule(::llvm::Module* module) { + EnsureInitialized(); + DLOG(INFO) << "Module before opt:\n" << LlvmToString(*module); RunDefaultOptPasses(module); DLOG(INFO) << "Module after opt:\n" << LlvmToString(*module); @@ -218,6 +244,8 @@ bool HybridSeMcJitWrapper::OptModule(::llvm::Module* module) { bool HybridSeMcJitWrapper::AddModule( std::unique_ptr module, std::unique_ptr llvm_ctx) { + EnsureInitialized(); + if (llvm::verifyModule(*module, &llvm::errs(), nullptr)) { // note: destruct module before ctx module = nullptr; diff --git a/hybridse/src/vm/jit.h b/hybridse/src/vm/jit.h index 7af5f17ac0d..3ae60fccf35 100644 --- a/hybridse/src/vm/jit.h +++ b/hybridse/src/vm/jit.h @@ -17,10 +17,8 @@ #ifndef HYBRIDSE_SRC_VM_JIT_H_ #define HYBRIDSE_SRC_VM_JIT_H_ -#include #include #include -#include "llvm/ExecutionEngine/GenericValue.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" #include "vm/jit_wrapper.h" @@ -100,10 +98,6 @@ class HybridSeLlvmJitWrapper : public HybridSeJitWrapper { hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) override; - // llvm::Module* GetModule() { - // } - // llvm::LLVMContext* GetLlvmContext(); - private: std::unique_ptr jit_; std::unique_ptr<::llvm::orc::MangleAndInterner> mi_; diff --git a/hybridse/src/vm/jit_wrapper.cc b/hybridse/src/vm/jit_wrapper.cc index 2f2e75d1502..876df037ef0 100644 --- a/hybridse/src/vm/jit_wrapper.cc +++ b/hybridse/src/vm/jit_wrapper.cc @@ -43,6 +43,8 @@ namespace hybridse { namespace vm { +static void InitBuiltinJitSymbols(HybridSeJitWrapper* jit_ptr); + bool HybridSeJitWrapper::AddModuleFromBuffer(const base::RawBuffer& buf) { std::string buf_str(buf.addr, buf.size); ::llvm::SMDiagnostic diagnostic; @@ -50,17 +52,34 @@ bool HybridSeJitWrapper::AddModuleFromBuffer(const base::RawBuffer& buf) { auto mem_buf = ::llvm::MemoryBuffer::getMemBuffer(buf_str); auto llvm_module = parseIR(*mem_buf, diagnostic, *llvm_ctx); if (llvm_module == nullptr) { - LOG(WARNING) << "Parse module failed: module string is\n" << buf_str; + LOG(WARNING) << "Parse module failed: module string is:\n" << buf_str; std::string err_msg; llvm::raw_string_ostream err_msg_stream(err_msg); - diagnostic.print("", err_msg_stream); + diagnostic.print("llvm_module_from_buf", err_msg_stream); + err_msg_stream.flush(); + LOG(WARNING) << err_msg; return false; } return this->AddModule(std::move(llvm_module), std::move(llvm_ctx)); } +HybridSeJitWrapper::HybridSeJitWrapper() {} + +base::Status HybridSeJitWrapper::InitJitSymbols() { + // default builitin + InitBuiltinJitSymbols(this); + // udf external functions + if (lib_ == nullptr) { + lib_ = udf::DefaultUdfLibrary::get(); + } + lib_->InitJITSymbols(this); + return {}; +} + bool HybridSeJitWrapper::InitJitSymbols(HybridSeJitWrapper* jit) { + // default builitin InitBuiltinJitSymbols(jit); + // udf external functions udf::DefaultUdfLibrary::get()->InitJITSymbols(jit); return true; } @@ -69,6 +88,18 @@ HybridSeJitWrapper* HybridSeJitWrapper::Create() { return Create(JitOptions()); } +HybridSeJitWrapper* HybridSeJitWrapper::CreateWithDefaultSymbols(udf::UdfLibrary* lib, base::Status* status, + const JitOptions& options) { + auto jit = vm::HybridSeJitWrapper::Create(options); + jit->SetLib(lib); + if (!jit->Init()) { + LOG(WARNING) << "fail to init jit"; + *status = {common::kCodegenError, "fail to init jit"}; + return nullptr; + } + return jit; +} + HybridSeJitWrapper* HybridSeJitWrapper::Create(const JitOptions& jit_options) { if (jit_options.IsEnableMcjit()) { #ifdef LLVM_EXT_ENABLE diff --git a/hybridse/src/vm/jit_wrapper.h b/hybridse/src/vm/jit_wrapper.h index b0bbb70c6ec..f7a578a4306 100644 --- a/hybridse/src/vm/jit_wrapper.h +++ b/hybridse/src/vm/jit_wrapper.h @@ -25,16 +25,21 @@ #include "vm/engine_context.h" namespace hybridse { + +namespace udf { +class UdfLibrary; +} namespace vm { class JitOptions; class HybridSeJitWrapper { public: - HybridSeJitWrapper() {} - virtual ~HybridSeJitWrapper() {} + HybridSeJitWrapper(); HybridSeJitWrapper(const HybridSeJitWrapper&) = delete; + virtual ~HybridSeJitWrapper() {} + virtual bool Init() = 0; virtual bool OptModule(::llvm::Module* module) = 0; @@ -47,14 +52,32 @@ class HybridSeJitWrapper { virtual hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) = 0; + // create the JIT wrapper with default builtin symbols imported already + static HybridSeJitWrapper* CreateWithDefaultSymbols(udf::UdfLibrary*, base::Status*, + const JitOptions& jit_options = {}); + static HybridSeJitWrapper* Create(const JitOptions& jit_options); static HybridSeJitWrapper* Create(); - static void DeleteJit(HybridSeJitWrapper* jit); + // TODO(someone): remove it, java wrapper should ensure deletion + // deprecated + static void DeleteJit(HybridSeJitWrapper* jit); + // deprecated, use InitJitSymbols() static bool InitJitSymbols(HybridSeJitWrapper* jit); -}; -void InitBuiltinJitSymbols(HybridSeJitWrapper* jit_ptr); + void SetLib(udf::UdfLibrary* lib) { lib_ = lib; } + + protected: + void EnsureInitialized() { assert(initialized_ && "JitWrapper must initialize explicitly"); } + + base::Status InitJitSymbols(); + + // lib_ is determined during Init, or you should explicitly + // set lib via SetLib before Init + udf::UdfLibrary* lib_ = nullptr; + + bool initialized_ = false; +}; } // namespace vm } // namespace hybridse diff --git a/hybridse/src/vm/jit_wrapper_test.cc b/hybridse/src/vm/jit_wrapper_test.cc index f783496e0ff..89d521c0dde 100644 --- a/hybridse/src/vm/jit_wrapper_test.cc +++ b/hybridse/src/vm/jit_wrapper_test.cc @@ -83,7 +83,6 @@ void simple_test(const EngineOptions &options) { ASSERT_FALSE(ir_str.empty()); HybridSeJitWrapper *jit = HybridSeJitWrapper::Create(); ASSERT_TRUE(jit->Init()); - HybridSeJitWrapper::InitJitSymbols(jit); base::RawBuffer ir_buf(const_cast(ir_str.data()), ir_str.size()); ASSERT_TRUE(jit->AddModuleFromBuffer(ir_buf)); @@ -151,7 +150,6 @@ TEST_F(JitWrapperTest, test_window) { ASSERT_FALSE(ir_str.empty()); HybridSeJitWrapper *jit = HybridSeJitWrapper::Create(); ASSERT_TRUE(jit->Init()); - HybridSeJitWrapper::InitJitSymbols(jit); base::RawBuffer ir_buf(const_cast(ir_str.data()), ir_str.size()); ASSERT_TRUE(jit->AddModuleFromBuffer(ir_buf)); diff --git a/hybridse/src/vm/sql_compiler.cc b/hybridse/src/vm/sql_compiler.cc index ea5626545ee..7b53c50143c 100644 --- a/hybridse/src/vm/sql_compiler.cc +++ b/hybridse/src/vm/sql_compiler.cc @@ -108,17 +108,12 @@ bool SqlCompiler::Compile(SqlContext& ctx, Status& status) { // NOLINT m->print(::llvm::errs(), NULL, true, true); return false; } - // ::llvm::errs() << *(m.get()); auto jit = std::shared_ptr( - HybridSeJitWrapper::Create(ctx.jit_options)); - if (jit == nullptr || !jit->Init()) { - status.msg = "fail to init jit let"; - status.code = common::kJitError; + HybridSeJitWrapper::CreateWithDefaultSymbols(ctx.udf_library, &status, ctx.jit_options)); + if (!status.isOK()) { LOG(WARNING) << status; return false; } - InitBuiltinJitSymbols(jit.get()); - ctx.udf_library->InitJITSymbols(jit.get()); if (!jit->OptModule(m.get())) { LOG(WARNING) << "fail to opt ir module for sql " << ctx.sql; return false; diff --git a/java/hybridse-sdk/src/main/java/com/_4paradigm/hybridse/sdk/JitManager.java b/java/hybridse-sdk/src/main/java/com/_4paradigm/hybridse/sdk/JitManager.java index 4dc90a5870d..a98ddf3d6fb 100644 --- a/java/hybridse-sdk/src/main/java/com/_4paradigm/hybridse/sdk/JitManager.java +++ b/java/hybridse-sdk/src/main/java/com/_4paradigm/hybridse/sdk/JitManager.java @@ -54,7 +54,6 @@ public static synchronized HybridSeJitWrapper getJit(String tag) { if(!jit.Init()){ throw new RuntimeException("Fail to init jit"); } - HybridSeJitWrapper.InitJitSymbols(jit); jits.put(tag, jit); } return jits.get(tag); @@ -135,6 +134,8 @@ public static synchronized void initJitModule(String tag, ByteBuffer moduleBuffe * @param tag module tag */ public static synchronized void removeModule(String tag) { + // HybridSeJitWrapper is a proxy class to C pointer, Java do not automatic + // lifetime of C pointer, so it must made explicitly initializedModuleTags.remove(tag); HybridSeJitWrapper jit = jits.remove(tag); if (jit != null) { diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala index 0de5bcab239..f89a97d2d4d 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala @@ -18,17 +18,24 @@ package com._4paradigm.openmldb.batch import java.sql.{Date, Timestamp} -import com._4paradigm.hybridse.codec.{RowBuilder, RowView, Row => NativeRow} +import com._4paradigm.hybridse.codec.{RowBuilder, RowView, Row => NativeRow, RowBuilder2} +import com._4paradigm.std.{ExprNodeVector} import com._4paradigm.hybridse.sdk.HybridSeException +import com._4paradigm.hybridse.sdk.JitManager import com._4paradigm.hybridse.vm.CoreAPI +import com._4paradigm.hybridse.node.{NodeManager, ExprNode} import com._4paradigm.openmldb.batch.utils.HybridseUtil import org.apache.spark.sql.Row import org.apache.spark.sql.types.{BooleanType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, - StringType, StructType, TimestampType} + StringType, StructType, TimestampType, MapType, ArrayType, DataType} import org.slf4j.LoggerFactory import java.util.Calendar +import java.util.concurrent.atomic.AtomicInteger +import java.text.SimpleDateFormat +import scala.collection.JavaConverters.asScalaBufferConverter import scala.collection.mutable +import com._4paradigm.hybridse.`type`.TypeOuterClass.ColumnDef class SparkRowCodec(sliceSchemas: Array[StructType]) { @@ -37,6 +44,9 @@ class SparkRowCodec(sliceSchemas: Array[StructType]) { private val sliceNum = sliceSchemas.length private val columnDefSegmentList = sliceSchemas.map(HybridseUtil.getHybridseSchema) + private val newEncoder = + columnDefSegmentList.foldLeft(false)((res, sc) => res || requireNewEncoder(sc.asScala.toArray)) + // for encode private var rowBuilders = columnDefSegmentList.map(cols => new RowBuilder(cols)) @@ -48,8 +58,43 @@ class SparkRowCodec(sliceSchemas: Array[StructType]) { private val sliceFieldOffsets = sliceSchemas.scanLeft(0)( (cur, schema) => cur + schema.size) + private val jitCounter = new AtomicInteger(0) def encode(row: Row): NativeRow = { + if (newEncoder) { + val cnt = jitCounter.getAndIncrement(); + // FIXME(#3748): native codegen row builder have issue adding multiple modules to + // the same jit instance: duplicated symbol error. + // this is the work-around that ensure every row encoded with distinct jit instance + val tag = "rowbuilder2_" + Thread.currentThread().getId + "_" + cnt; + try { + val jit = JitManager.getJit(tag) + var newRowBuilder = new RowBuilder2(jit, sliceSchemas.size) + for (i <- 0 until columnDefSegmentList.size) { + val s = newRowBuilder.InitSchema(i, columnDefSegmentList(i)) + if (!s.isOK()) { + throw new HybridSeException(s.str()) + } + } + var s = newRowBuilder.Init() + if (!s.isOK()) { + throw new HybridSeException(s.str()) + } + + var nm = new NodeManager(); + var vec = sparkRowToNativeExprVec(row, nm) + + var nativeRow = new NativeRow() + s = newRowBuilder.Build(vec, nativeRow) + if (!s.isOK()) { + throw new HybridSeException(s.str()) + } + return nativeRow + } finally { + JitManager.removeModule(tag) + } + } + var result: NativeRow = null // collect slice size and string raw bytes val sliceSizes = Array.fill(sliceNum)(0) @@ -153,7 +198,7 @@ class SparkRowCodec(sliceSchemas: Array[StructType]) { } } - def encodeSingle(row: Row, outBuf: Long, outSize: Int, + private def encodeSingle(row: Row, outBuf: Long, outSize: Int, sliceStrings: Seq[Array[Byte]], sliceIndex: Int): Unit = { val rowBuilder = rowBuilders(sliceIndex) val schema = sliceSchemas(sliceIndex) @@ -216,7 +261,7 @@ class SparkRowCodec(sliceSchemas: Array[StructType]) { } - def decodeSingle(nativeRow: NativeRow, output: Array[Any], sliceIndex: Int): Unit = { + private def decodeSingle(nativeRow: NativeRow, output: Array[Any], sliceIndex: Int): Unit = { val rowView = rowViews(sliceIndex) val schema = sliceSchemas(sliceIndex) @@ -276,6 +321,111 @@ class SparkRowCodec(sliceSchemas: Array[StructType]) { }) } + private def requireNewEncoder(schema: Array[ColumnDef]): Boolean = { + schema.foldLeft(false)((res, col) => { + res || col.getSchema().hasArrayType() || col.getSchema().hasMapType() + }) + } + + private def sparkRowToNativeExprVec(row: Row, nm: NodeManager): ExprNodeVector = { + var vec = new ExprNodeVector() + + var idx = 0 + sliceSchemas.foreach(st => { + st.foreach(col => { + var expr = sparkColToNativeExpr(row, idx, col.dataType, nm) + vec.add(expr) + idx += 1 + }) + }) + + vec + } + + private def sparkColToNativeExpr(row: Row, idx: Int, dataType: DataType, nm: NodeManager): ExprNode = { + if (row.isNullAt(idx)) { + nm.MakeConstNode() + } else { + dataType match { + case ShortType => + nm.MakeConstNode(row.getShort(idx)) + case IntegerType => + nm.MakeConstNode(row.getInt(idx)) + case LongType => + nm.MakeConstNode(row.getLong(idx)) + case FloatType => + nm.MakeConstNode(row.getFloat(idx)) + case DoubleType => + nm.MakeConstNode(row.getDouble(idx)) + case BooleanType => + nm.MakeConstNode(row.getBoolean(idx)) + case StringType => { + // generally safe, native ConstNode copyed the string + val str = row.getString(idx) + nm.MakeConstNode(str) + } + case TimestampType => { + var args = nm.MakeExprList() + args.AddChild(nm.MakeConstNode(row.getTimestamp(idx).getTime)) + nm.MakeFuncNode("timestamp", args, null) + } + case DateType => { + val date = row.getDate(idx) + val fmt = new SimpleDateFormat("yyyy-MM-dd") + var str = fmt.format(date) + var args = nm.MakeExprList() + args.AddChild(nm.MakeConstNode(str)) + nm.MakeFuncNode("date", args, null) + } + case MapType(keyType, valType, valContainsNull) => { + var mapVal = row.getMap[Any, Any](idx) + // logger.warn(s"map val ${mapVal}") + var args = nm.MakeExprList() + mapVal.foreach(kv => { + args.AddChild(valToNativeExpr(kv._1, keyType, nm)) + args.AddChild(valToNativeExpr(kv._2, valType, nm)) + }) + nm.MakeFuncNode("map", args, null) + } + case _ => throw new IllegalArgumentException( + s"Spark type ${dataType} not supported") + } + } + } + + private def valToNativeExpr(v: Any, dataType: DataType, nm: NodeManager): ExprNode = { + dataType match { + case ShortType => nm.MakeConstNode(v.asInstanceOf[Short]) + case IntegerType => nm.MakeConstNode(v.asInstanceOf[Int]) + case LongType => nm.MakeConstNode(v.asInstanceOf[Long]) + case FloatType => nm.MakeConstNode(v.asInstanceOf[Float]) + case DoubleType => nm.MakeConstNode(v.asInstanceOf[Double]) + case BooleanType => nm.MakeConstNode(v.asInstanceOf[Boolean]) + case StringType => nm.MakeConstNode(v.asInstanceOf[String]) + case TimestampType => + var args = nm.MakeExprList() + args.AddChild(nm.MakeConstNode(v.asInstanceOf[Long])) + nm.MakeFuncNode("timestamp", args, null) + case DateType => + // date from string literal + var args = nm.MakeExprList() + args.AddChild(nm.MakeConstNode(v.asInstanceOf[String])) + nm.MakeFuncNode("date", args, null) + case MapType(keyType, valType, _) => { + var mapVal = v.asInstanceOf[Map[Any, Any]] + var args = nm.MakeExprList() + mapVal.foreach(kv => { + args.AddChild(valToNativeExpr(kv._1, keyType, nm)) + args.AddChild(valToNativeExpr(kv._2, valType, nm)) + }) + // TODO(someone): support empty map, since 'map()' inferred as map + // we need construst a extra cast operation to hint the true type from schema + nm.MakeFuncNode("map", args, null) + } + case _ => throw new IllegalArgumentException( + s"Spark type ${dataType} not supported") + } + } def delete(): Unit = { if (rowViews != null) { diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala index 1f9d34dcb8b..2e1d1c429c4 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala @@ -96,7 +96,11 @@ object SimpleProjectPlan { val sparkCol = SparkColumnUtil.getColumnFromIndex(inputDf, colIndex) val sparkType = inputDf.schema(colIndex).dataType val schemaType = DataTypeUtil.sparkTypeToHybridseProtoType(sparkType) - val innerType = DataTypeUtil.hybridseProtoTypeToOpenmldbType(schemaType) + if (!schemaType.hasBaseType()) { + throw new UnsupportedHybridSeException( + s"expression output type does not expect to be ${schemaType} for simple project") + } + val innerType = DataTypeUtil.hybridseProtoTypeToOpenmldbType(schemaType.getBaseType()) sparkCol -> innerType case ExprType.kExprPrimary => diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/DataTypeUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/DataTypeUtil.scala index d1199fc27ea..5bdc16ef6f6 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/DataTypeUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/DataTypeUtil.scala @@ -16,27 +16,50 @@ package com._4paradigm.openmldb.batch.utils -import com._4paradigm.hybridse.`type`.TypeOuterClass.{Type => HybridseProtoType} +import com._4paradigm.hybridse.`type`.TypeOuterClass.{Type => HybridseProtoType, ColumnSchema} import org.apache.spark.sql.types.{BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, - ShortType, StringType, TimestampType} + ShortType, StringType, TimestampType, MapType, ArrayType, StructField} import com._4paradigm.hybridse.node.{DataType => OpenmldbDataType} +import com._4paradigm.hybridse.`type`.TypeOuterClass.ColumnSchema.TypeCase.BASE_TYPE +import com._4paradigm.hybridse.`type`.TypeOuterClass.ColumnSchema.TypeCase.ARRAY_TYPE +import com._4paradigm.hybridse.`type`.TypeOuterClass.ColumnSchema.TypeCase.MAP_TYPE +import com._4paradigm.hybridse.`type`.TypeOuterClass.{ArrayType => HSArrayType, MapType => HSMapType} object DataTypeUtil { - def sparkTypeToHybridseProtoType(dtype: DataType): HybridseProtoType = { + def sparkTypeToHybridseProtoType(dtype: DataType): ColumnSchema = { + var builder = ColumnSchema.newBuilder() dtype match { - case ShortType => HybridseProtoType.kInt16 - case IntegerType => HybridseProtoType.kInt32 - case LongType => HybridseProtoType.kInt64 - case FloatType => HybridseProtoType.kFloat - case DoubleType => HybridseProtoType.kDouble - case BooleanType => HybridseProtoType.kBool - case StringType => HybridseProtoType.kVarchar - case DateType => HybridseProtoType.kDate - case TimestampType => HybridseProtoType.kTimestamp + case ShortType => builder.setBaseType(HybridseProtoType.kInt16) + case IntegerType => builder.setBaseType(HybridseProtoType.kInt32) + case LongType => builder.setBaseType(HybridseProtoType.kInt64) + case FloatType => builder.setBaseType(HybridseProtoType.kFloat) + case DoubleType => builder.setBaseType(HybridseProtoType.kDouble) + case BooleanType => builder.setBaseType(HybridseProtoType.kBool) + case StringType => builder.setBaseType(HybridseProtoType.kVarchar) + case DateType => builder.setBaseType(HybridseProtoType.kDate) + case TimestampType => builder.setBaseType(HybridseProtoType.kTimestamp) + case ArrayType(eleType, containsNull) => { + var hsArrType = HSArrayType.newBuilder() + .setEleType(sparkTypeToHybridseProtoType(eleType) + .toBuilder().setIsNotNull(!containsNull)) + .build() + builder.setArrayType(hsArrType) + } + case MapType(keyType, valueType, valContainsNull) => { + var hsKeyType = sparkTypeToHybridseProtoType(keyType) + var hsValType = sparkTypeToHybridseProtoType(valueType) + .toBuilder().setIsNotNull(!valContainsNull) + builder.setMapType(HSMapType.newBuilder() + .setKeyType(hsKeyType) + .setValueType(hsValType) + .build()) + } case _ => throw new IllegalArgumentException( s"Spark type $dtype not supported") } + + builder.build() } def openmldbTypeToProtoType(dtype: OpenmldbDataType): HybridseProtoType = { @@ -55,17 +78,33 @@ object DataTypeUtil { } } - def hybridseProtoTypeToSparkType(dtype: HybridseProtoType): DataType = { - dtype match { - case HybridseProtoType.kInt16 => ShortType - case HybridseProtoType.kInt32 => IntegerType - case HybridseProtoType.kInt64 => LongType - case HybridseProtoType.kFloat => FloatType - case HybridseProtoType.kDouble => DoubleType - case HybridseProtoType.kBool => BooleanType - case HybridseProtoType.kVarchar => StringType - case HybridseProtoType.kDate => DateType - case HybridseProtoType.kTimestamp => TimestampType + def hybridseProtoTypeToSparkType(dtype: ColumnSchema): DataType = { + dtype.getTypeCase() match { + case BASE_TYPE => { + dtype.getBaseType() match { + case HybridseProtoType.kInt16 => ShortType + case HybridseProtoType.kInt32 => IntegerType + case HybridseProtoType.kInt64 => LongType + case HybridseProtoType.kFloat => FloatType + case HybridseProtoType.kDouble => DoubleType + case HybridseProtoType.kBool => BooleanType + case HybridseProtoType.kVarchar => StringType + case HybridseProtoType.kDate => DateType + case HybridseProtoType.kTimestamp => TimestampType + case _ => throw new IllegalArgumentException( + s"HybridSE type $dtype not supported") + } + } + case ARRAY_TYPE => { + var ele = hybridseProtoTypeToSparkType(dtype.getArrayType().getEleType()) + ArrayType(ele, !dtype.getIsNotNull()) + } + case MAP_TYPE => { + var keyType = hybridseProtoTypeToSparkType(dtype.getMapType().getKeyType()) + var valueType = hybridseProtoTypeToSparkType(dtype.getMapType().getValueType()) + var valNull = !dtype.getMapType().getValueType().getIsNotNull() + MapType(keyType, valueType, valNull) + } case _ => throw new IllegalArgumentException( s"HybridSE type $dtype not supported") } @@ -148,6 +187,10 @@ object DataTypeUtil { case StringType => "string" case DateType => "date" case TimestampType => "timestamp" + case MapType(keyType, valueType, valueContainsNull) => { + var valueAttr = if (valueContainsNull) { "" } else { "NOT NULL" } + s"MAP<${sparkTypeToString(keyType)}, ${sparkTypeToString(valueType)} $valueAttr>" + } case _ => throw new IllegalArgumentException( s"Spark type $dataType not supported") } diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala index ee588ab677d..291b3c83c09 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala @@ -17,7 +17,7 @@ package com._4paradigm.openmldb.batch.utils import java.util -import com._4paradigm.hybridse.`type`.TypeOuterClass.{ColumnDef, Database, TableDef} +import com._4paradigm.hybridse.`type`.TypeOuterClass.{ColumnDef, Database, TableDef, Type => HybridseProtoType} import com._4paradigm.hybridse.node.ConstNode import com._4paradigm.hybridse.sdk.UnsupportedHybridSeException import com._4paradigm.hybridse.vm.{PhysicalLoadDataNode, PhysicalOpNode, PhysicalSelectIntoNode} @@ -70,10 +70,13 @@ object HybridseUtil { def getTableDef(tableName: String, dataFrame: DataFrame): TableDef = { val tblBulder = TableDef.newBuilder() dataFrame.schema.foreach(field => { - tblBulder.addColumns(ColumnDef.newBuilder() + var sc = DataTypeUtil.sparkTypeToHybridseProtoType(field.dataType) + tblBulder.addColumns( + ColumnDef.newBuilder() .setName(field.name) .setIsNotNull(!field.nullable) - .setType(DataTypeUtil.sparkTypeToHybridseProtoType(field.dataType)) + .setSchema(sc) + .setType(if (sc.hasBaseType()) {sc.getBaseType()} else {HybridseProtoType.kNull}) .build() ) }) @@ -84,17 +87,20 @@ object HybridseUtil { def getHybridseSchema(structType: StructType): java.util.List[ColumnDef] = { val list = new util.ArrayList[ColumnDef]() structType.foreach(field => { + var sc = DataTypeUtil.sparkTypeToHybridseProtoType(field.dataType) list.add(ColumnDef.newBuilder() .setName(field.name) .setIsNotNull(!field.nullable) - .setType(DataTypeUtil.sparkTypeToHybridseProtoType(field.dataType)).build()) + .setSchema(sc) + .setType(if (sc.hasBaseType()) {sc.getBaseType()} else {HybridseProtoType.kNull}) + .build()) }) list } def getSparkSchema(columns: java.util.List[ColumnDef]): StructType = { StructType(columns.asScala.map(col => { - StructField(col.getName, DataTypeUtil.hybridseProtoTypeToSparkType(col.getType), !col.getIsNotNull) + StructField(col.getName, DataTypeUtil.hybridseProtoTypeToSparkType(col.getSchema), !col.getIsNotNull) })) } diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala index 0cd852f3268..09990296c9b 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala @@ -20,7 +20,7 @@ import com._4paradigm.openmldb.batch.SparkTestSuite import com._4paradigm.openmldb.batch.api.OpenmldbSession import com._4paradigm.openmldb.batch.utils.SparkUtil import org.apache.spark.sql.Row -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, MapType} class TestProject extends SparkTestSuite { @@ -43,4 +43,27 @@ class TestProject extends SparkTestSuite { } + test("Test end2end row project with map values") { + val spark = getSparkSession + val sess = new OpenmldbSession(spark) + + val data = Seq( + Row(1, Map.apply(1 -> "11", 12 -> "99")), + Row(2, Map.apply(13 -> "99"))) + // Row(2, Map.empty[Int, String])) + val schema = StructType(List( + StructField("id", IntegerType), + StructField("val", MapType(IntegerType, StringType)))) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(data), schema) + + sess.registerTable("t1", df) + df.createOrReplaceTempView("t1") + + val sqlText = "select id, val[12] as ele from t1" + val outputDf = sess.sql(sqlText) + outputDf.show() + + val sparksqlOutputDf = sess.sparksql(sqlText) + assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false)) + } } diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestUnsafeRowProject.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestUnsafeRowProject.scala index d34af44b521..7068081a9f8 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestUnsafeRowProject.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestUnsafeRowProject.scala @@ -58,5 +58,4 @@ class TestUnsafeRowProject extends SparkTestSuite { val sparksqlOutputDf = sess.sparksql(sqlText) assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false)) } - } diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 4a4f815ab0a..ac5b8ec5153 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -475,14 +475,12 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s // TODO(someone): // 1. default value from table definition // 2. parameters - ::hybridse::codegen::InsertRowBuilder insert_builder(&sc); - { - auto s = insert_builder.Init(); - if (!s.ok()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, s.ToString()); - return false; - } + ::hybridse::vm::Engine::InitializeGlobalLLVM(); + auto jit = std::shared_ptr(hybridse::vm::HybridSeJitWrapper::Create()); + if (!jit->Init()) { + return false; } + ::hybridse::codegen::InsertRowBuilder insert_builder(jit.get(), &sc); size_t total_rows_size = insert_stmt->values_.size(); for (size_t i = 0; i < total_rows_size; i++) {