Skip to content

Commit

Permalink
feat(jit_wrapper): simplify jit initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
aceforeverd committed Feb 7, 2024
1 parent 69a4bb4 commit 6b6de97
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 44 deletions.
3 changes: 0 additions & 3 deletions hybridse/src/codegen/buf_ir_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ void RunEncode(::hybridse::type::TableDef& table, // NOLINT
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(
vm::HybridSeJitWrapper::Create());
jit->Init();
vm::HybridSeJitWrapper::InitJitSymbols(jit.get());
ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx)));
auto load_fn_jit = jit->FindFunction("fn");
void (*decode)(int8_t**) =
Expand Down Expand Up @@ -306,7 +305,6 @@ void LoadValue(T* result, bool* is_null,
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(
vm::HybridSeJitWrapper::Create());
jit->Init();
vm::HybridSeJitWrapper::InitJitSymbols(jit.get());
ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx)));
auto load_fn_jit = jit->FindFunction("fn");

Expand Down Expand Up @@ -437,7 +435,6 @@ void RunColCase(T expected, type::TableDef& table, // NOLINT
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(
vm::HybridSeJitWrapper::Create());
jit->Init();
vm::HybridSeJitWrapper::InitJitSymbols(jit.get());
ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx)));
jit->AddExternalFunction("print_list_i16",
reinterpret_cast<void*>(&PrintListInt16));
Expand Down
1 change: 0 additions & 1 deletion hybridse/src/codegen/fn_ir_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ void CheckResult(node::FnNodeFnDef *fn_def, R exp, V1 a, V2 b) {
m->print(::llvm::errs(), NULL, true, true);
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(vm::HybridSeJitWrapper::Create());
jit->Init();
vm::HybridSeJitWrapper::InitJitSymbols(jit.get());
ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx)));
auto test_fn = (R(*)(V1, V2))jit->FindFunction(fn_def->header_->GeIRFunctionName());
R result = test_fn(a, b);
Expand Down
1 change: 0 additions & 1 deletion hybridse/src/codegen/fn_let_ir_builder_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ void CheckFnLetBuilderWithParameterRow(::hybridse::node::NodeManager* manager, v
auto jit = std::unique_ptr<vm::HybridSeJitWrapper>(
vm::HybridSeJitWrapper::Create());
jit->Init();
vm::HybridSeJitWrapper::InitJitSymbols(jit.get());

ASSERT_TRUE(jit->AddModule(std::move(m), std::move(ctx)));
auto address = jit->FindFunction("test_at_fn");
Expand Down
3 changes: 1 addition & 2 deletions hybridse/src/codegen/insert_row_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ TEST_F(InsertRowBuilderTest, encode) {

auto jit = std::shared_ptr<vm::HybridSeJitWrapper>(vm::HybridSeJitWrapper::Create());
ASSERT_TRUE(jit->Init());
ASSERT_TRUE(vm::HybridSeJitWrapper::InitJitSymbols(jit.get()));

InsertRowBuilder builder(jit, &sc);
InsertRowBuilder builder(jit.get(), &sc);

auto as = builder.ComputeRow(dynamic_cast<node::ExprListNode*>(exprlist));
ASSERT_TRUE(as.ok()) << as.status();
Expand Down
14 changes: 6 additions & 8 deletions hybridse/src/codegen/ir_base_builder_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,12 @@ class ModuleTestFunction {
std::unique_ptr<::llvm::LLVMContext> llvm_ctx) {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
jit = std::unique_ptr<vm::HybridSeJitWrapper>(
vm::HybridSeJitWrapper::Create());
jit->Init();
InitBuiltinJitSymbols(jit.get());
if (library != nullptr) {
library->InitJITSymbols(jit.get());
} else {
udf::DefaultUdfLibrary::get()->InitJITSymbols(jit.get());
base::Status s;
jit =
std::unique_ptr<vm::HybridSeJitWrapper>(vm::HybridSeJitWrapper::CreateWithDefaultSymbols(library, &s, {}));
if (jit == nullptr || !s.isOK()) {
LOG(WARNING) << "create jit failed" << s;
return;

Check warning on line 205 in hybridse/src/codegen/ir_base_builder_test.h

View check run for this annotation

Codecov / codecov/patch

hybridse/src/codegen/ir_base_builder_test.h#L204-L205

Added lines #L204 - L205 were not covered by tests
}

llvm::errs() << *(module.get()) << "\n";
Expand Down
16 changes: 15 additions & 1 deletion hybridse/src/vm/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ bool HybridSeLlvmJitWrapper::Init() {
this->mi_ = std::unique_ptr<::llvm::orc::MangleAndInterner>(
new ::llvm::orc::MangleAndInterner(jit_->getExecutionSession(),
jit_->getDataLayout()));

auto s = InitJitSymbols();
if (!s.isOK()) {
LOG(WARNING) << s;
return false;

Check warning on line 173 in hybridse/src/vm/jit.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/vm/jit.cc#L172-L173

Added lines #L172 - L173 were not covered by tests
}

return true;
}

Expand Down Expand Up @@ -206,7 +213,14 @@ bool HybridSeLlvmJitWrapper::AddExternalFunction(const std::string& name,
}

#ifdef LLVM_EXT_ENABLE
bool HybridSeMcJitWrapper::Init() { return true; }
bool HybridSeMcJitWrapper::Init() {
auto s = InitJitSymbols();
if (!s.isOK()) {
LOG(WARNING) << s;
return false;
}
return true;
}

bool HybridSeMcJitWrapper::OptModule(::llvm::Module* module) {
DLOG(INFO) << "Module before opt:\n" << LlvmToString(*module);
Expand Down
30 changes: 30 additions & 0 deletions hybridse/src/vm/jit_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
namespace hybridse {
namespace vm {

static void InitBuiltinJitSymbols(HybridSeJitWrapper* jit_ptr);

bool HybridSeJitWrapper::AddModuleFromBuffer(const base::RawBuffer& buf) {
std::string buf_str(buf.addr, buf.size);
::llvm::SMDiagnostic diagnostic;
Expand All @@ -61,8 +63,24 @@ bool HybridSeJitWrapper::AddModuleFromBuffer(const base::RawBuffer& buf) {
return this->AddModule(std::move(llvm_module), std::move(llvm_ctx));
}

HybridSeJitWrapper::HybridSeJitWrapper() {}
HybridSeJitWrapper::HybridSeJitWrapper(udf::UdfLibrary* lib) : lib_(lib) {}

Check warning on line 67 in hybridse/src/vm/jit_wrapper.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/vm/jit_wrapper.cc#L67

Added line #L67 was not covered by tests

base::Status HybridSeJitWrapper::InitJitSymbols() {
// default builitin
InitBuiltinJitSymbols(this);
// udf external functions
if (lib_ == nullptr) {
lib_ = udf::DefaultUdfLibrary::get();
}
lib_->InitJITSymbols(this);
return {};
}

bool HybridSeJitWrapper::InitJitSymbols(HybridSeJitWrapper* jit) {
// default builitin
InitBuiltinJitSymbols(jit);
// udf external functions
udf::DefaultUdfLibrary::get()->InitJITSymbols(jit);
return true;
}
Expand All @@ -71,6 +89,18 @@ HybridSeJitWrapper* HybridSeJitWrapper::Create() {
return Create(JitOptions());
}

HybridSeJitWrapper* HybridSeJitWrapper::CreateWithDefaultSymbols(udf::UdfLibrary* lib, base::Status* status,
const JitOptions& options) {
auto jit = vm::HybridSeJitWrapper::Create(options);
jit->SetLib(lib);
if (!jit->Init()) {
LOG(WARNING) << "fail to init jit";
*status = {common::kCodegenError, "fail to init jit"};
return nullptr;

Check warning on line 99 in hybridse/src/vm/jit_wrapper.cc

View check run for this annotation

Codecov / codecov/patch

hybridse/src/vm/jit_wrapper.cc#L97-L99

Added lines #L97 - L99 were not covered by tests
}
return jit;
}

HybridSeJitWrapper* HybridSeJitWrapper::Create(const JitOptions& jit_options) {
if (jit_options.IsEnableMcjit()) {
#ifdef LLVM_EXT_ENABLE
Expand Down
28 changes: 23 additions & 5 deletions hybridse/src/vm/jit_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,22 @@
#include "vm/engine_context.h"

namespace hybridse {

namespace udf {
class UdfLibrary;
}
namespace vm {

class JitOptions;

class HybridSeJitWrapper {
public:
HybridSeJitWrapper() {}
virtual ~HybridSeJitWrapper() {}
HybridSeJitWrapper();
explicit HybridSeJitWrapper(udf::UdfLibrary* lib);
HybridSeJitWrapper(const HybridSeJitWrapper&) = delete;

virtual ~HybridSeJitWrapper() {}

virtual bool Init() = 0;
virtual bool OptModule(::llvm::Module* module) = 0;

Expand All @@ -47,14 +53,26 @@ class HybridSeJitWrapper {

virtual hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) = 0;

// create the JIT wrapper with default builtin symbols imported already
static HybridSeJitWrapper* CreateWithDefaultSymbols(udf::UdfLibrary*, base::Status*,
const JitOptions& jit_options = {});

static HybridSeJitWrapper* Create(const JitOptions& jit_options);
static HybridSeJitWrapper* Create();
static void DeleteJit(HybridSeJitWrapper* jit);

// TODO(someone): remove it, java wrapper should ensure deletion
// deprecated
static void DeleteJit(HybridSeJitWrapper* jit);
// deprecated, use InitJitSymbols()
static bool InitJitSymbols(HybridSeJitWrapper* jit);
};

void InitBuiltinJitSymbols(HybridSeJitWrapper* jit_ptr);
void SetLib(udf::UdfLibrary* lib) { lib_ = lib; }

protected:
base::Status InitJitSymbols();

udf::UdfLibrary* lib_ = nullptr;
};

} // namespace vm
} // namespace hybridse
Expand Down
2 changes: 0 additions & 2 deletions hybridse/src/vm/jit_wrapper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ void simple_test(const EngineOptions &options) {
ASSERT_FALSE(ir_str.empty());
HybridSeJitWrapper *jit = HybridSeJitWrapper::Create();
ASSERT_TRUE(jit->Init());
HybridSeJitWrapper::InitJitSymbols(jit);

base::RawBuffer ir_buf(const_cast<char *>(ir_str.data()), ir_str.size());
ASSERT_TRUE(jit->AddModuleFromBuffer(ir_buf));
Expand Down Expand Up @@ -151,7 +150,6 @@ TEST_F(JitWrapperTest, test_window) {
ASSERT_FALSE(ir_str.empty());
HybridSeJitWrapper *jit = HybridSeJitWrapper::Create();
ASSERT_TRUE(jit->Init());
HybridSeJitWrapper::InitJitSymbols(jit);

base::RawBuffer ir_buf(const_cast<char *>(ir_str.data()), ir_str.size());
ASSERT_TRUE(jit->AddModuleFromBuffer(ir_buf));
Expand Down
8 changes: 2 additions & 6 deletions hybridse/src/vm/sql_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,11 @@ bool SqlCompiler::Compile(SqlContext& ctx, Status& status) { // NOLINT
}
// ::llvm::errs() << *(m.get());
auto jit = std::shared_ptr<HybridSeJitWrapper>(
HybridSeJitWrapper::Create(ctx.jit_options));
if (jit == nullptr || !jit->Init()) {
status.msg = "fail to init jit let";
status.code = common::kJitError;
HybridSeJitWrapper::CreateWithDefaultSymbols(ctx.udf_library, &status, ctx.jit_options));
if (!status.isOK()) {
LOG(WARNING) << status;
return false;
}
InitBuiltinJitSymbols(jit.get());
ctx.udf_library->InitJITSymbols(jit.get());
if (!jit->OptModule(m.get())) {
LOG(WARNING) << "fail to opt ir module for sql " << ctx.sql;
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,11 @@ public static synchronized HybridSeJitWrapper getJit(String tag) {
if(!jit.Init()){
throw new RuntimeException("Fail to init jit");
}
HybridSeJitWrapper.InitJitSymbols(jit);
jits.put(tag, jit);
}
return jits.get(tag);
}


public static synchronized void rmoveJit(String tag) {
// HybridSeJitWrapper is a proxy class to C pointer, Java do not automatic
// lifetime of C pointer, so it must made explicitly
HybridSeJitWrapper e = jits.remove(tag);
if (e != null) {
HybridSeJitWrapper.DeleteJit(e);
}
}

private static JitOptions getJitOptions() {
JitOptions options = new JitOptions();
try (InputStream input = JitManager.class.getClassLoader().getResourceAsStream("jit.properties")) {
Expand Down Expand Up @@ -145,6 +134,8 @@ public static synchronized void initJitModule(String tag, ByteBuffer moduleBuffe
* @param tag module tag
*/
public static synchronized void removeModule(String tag) {
// HybridSeJitWrapper is a proxy class to C pointer, Java do not automatic
// lifetime of C pointer, so it must made explicitly
initializedModuleTags.remove(tag);
HybridSeJitWrapper jit = jits.remove(tag);
if (jit != null) {
Expand Down
5 changes: 1 addition & 4 deletions src/sdk/sql_cluster_router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s
if (!jit->Init()) {
return false;

Check warning on line 481 in src/sdk/sql_cluster_router.cc

View check run for this annotation

Codecov / codecov/patch

src/sdk/sql_cluster_router.cc#L481

Added line #L481 was not covered by tests
}
if (!hybridse::vm::HybridSeJitWrapper::InitJitSymbols(jit.get())) {
return false;
}
::hybridse::codegen::InsertRowBuilder insert_builder(jit, &sc);
::hybridse::codegen::InsertRowBuilder insert_builder(jit.get(), &sc);

size_t total_rows_size = insert_stmt->values_.size();
for (size_t i = 0; i < total_rows_size; i++) {
Expand Down

0 comments on commit 6b6de97

Please sign in to comment.