diff --git a/.github/workflows/build-extensions.yml b/.github/workflows/build-extensions.yml index ab557dc4cb6..12772c6a252 100644 --- a/.github/workflows/build-extensions.yml +++ b/.github/workflows/build-extensions.yml @@ -12,15 +12,13 @@ jobs: run: make extension-release NUM_THREADS=$(nproc) - name: Collect built artifacts - run: | - mkdir -p extension-artifacts - find extension -type f -name "*.kuzu_extension" -exec cp {} extension-artifacts \; + run: /opt/python/cp310-cp310/bin/python scripts/collect-extensions.py - name: Upload built artifacts uses: actions/upload-artifact@v4 with: name: kuzu-extensions_linux-x86_64 - path: extension-artifacts/*.kuzu_extension + path: extension-artifacts build-linux-extensions-aarch64: runs-on: kuzu-self-hosted-linux-building-aarch64 @@ -32,15 +30,13 @@ jobs: run: make extension-release NUM_THREADS=$(nproc) - name: Collect built artifacts - run: | - mkdir -p extension-artifacts - find extension -type f -name "*.kuzu_extension" -exec cp {} extension-artifacts \; + run: /opt/python/cp310-cp310/bin/python scripts/collect-extensions.py - name: Upload built artifacts uses: actions/upload-artifact@v4 with: name: kuzu-extensions_linux-aarch64 - path: extension-artifacts/*.kuzu_extension + path: extension-artifacts build-linux-extensions-x86: runs-on: ubuntu-latest @@ -52,7 +48,7 @@ jobs: run: | docker run -d --name kuzu-x86 \ -v $PWD:/kuzu -w /kuzu \ - -e NUM_THREADS=2 -e CC=gcc -e CXX=g++ \ + -e NUM_THREADS=$(nproc) -e CC=gcc -e CXX=g++ \ i386/debian:stable tail -f /dev/null - name: Install dependencies @@ -66,8 +62,7 @@ jobs: - name: Collect built artifacts run: | - mkdir -p extension-artifacts - find extension -type f -name "*.kuzu_extension" -exec cp {} extension-artifacts \; + python3 scripts/collect-extensions.py sudo chown -R $USER:$USER extension-artifacts sudo chmod -R 755 extension-artifacts @@ -75,7 +70,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: kuzu-extensions_linux-x86 - path: extension-artifacts/*.kuzu_extension + path: extension-artifacts - name: Stop Docker container run: docker stop kuzu-x86 @@ -92,15 +87,13 @@ jobs: CMAKE_OSX_ARCHITECTURES: "arm64" - name: Collect built artifacts - run: | - mkdir -p extension-artifacts - find extension -type f -name "*.kuzu_extension" -exec cp {} extension-artifacts \; + run: python3 scripts/collect-extensions.py - name: Upload built artifacts uses: actions/upload-artifact@v4 with: name: kuzu-extensions_osx-arm64 - path: extension-artifacts/*.kuzu_extension + path: extension-artifacts build-mac-extensions-x86_64: runs-on: self-hosted-mac-x64 @@ -114,15 +107,13 @@ jobs: CMAKE_OSX_ARCHITECTURES: "x86_64" - name: Collect built artifacts - run: | - mkdir -p extension-artifacts - find extension -type f -name "*.kuzu_extension" -exec cp {} extension-artifacts \; + run: python3 scripts/collect-extensions.py - name: Upload built artifacts uses: actions/upload-artifact@v4 with: name: kuzu-extensions_osx-x86_64 - path: extension-artifacts/*.kuzu_extension + path: extension-artifacts build-windows-extensions-x86_64: runs-on: self-hosted-windows @@ -138,15 +129,13 @@ jobs: make extension-release - name: Collect built artifacts - run: | - mkdir -p extension-artifacts - Get-ChildItem -Path .\extension -Recurse -Filter "*.kuzu_extension" | ForEach-Object {Copy-Item $_.FullName -Destination extension-artifacts -Force} + run: python scripts/collect-extensions.py - name: Upload built artifacts uses: actions/upload-artifact@v4 with: name: kuzu-extensions_win-x86_64 - path: extension-artifacts/*.kuzu_extension + path: extension-artifacts update-extensions-repo: runs-on: ubuntu-latest @@ -172,23 +161,9 @@ jobs: repository: kuzudb/extension token: ${{ secrets.DOC_PUSH_TOKEN }} - - name: Ensure extension directories - run: | - mkdir -p releases/$RELEASE_VERSION/linux_amd64 - mkdir -p releases/$RELEASE_VERSION/linux_arm64 - mkdir -p releases/$RELEASE_VERSION/linux_x86 - mkdir -p releases/$RELEASE_VERSION/osx_amd64 - mkdir -p releases/$RELEASE_VERSION/osx_arm64 - mkdir -p releases/$RELEASE_VERSION/win_amd64 - - name: Clear old artifacts run: | - rm -rf releases/$RELEASE_VERSION/linux_amd64/* - rm -rf releases/$RELEASE_VERSION/linux_arm64/* - rm -rf releases/$RELEASE_VERSION/linux_x86/* - rm -rf releases/$RELEASE_VERSION/osx_amd64/* - rm -rf releases/$RELEASE_VERSION/osx_arm64/* - rm -rf releases/$RELEASE_VERSION/win_amd64/* + rm -rf releases/$RELEASE_VERSION/ - name: Create temporary directory for artifacts run: | @@ -238,24 +213,14 @@ jobs: - name: Copy built artifacts run: | - cp extension-artifacts/linux_amd64/*.kuzu_extension releases/$RELEASE_VERSION/linux_amd64 - cp extension-artifacts/linux_arm64/*.kuzu_extension releases/$RELEASE_VERSION/linux_arm64 - cp extension-artifacts/linux_x86/*.kuzu_extension releases/$RELEASE_VERSION/linux_x86 - cp extension-artifacts/osx_amd64/*.kuzu_extension releases/$RELEASE_VERSION/osx_amd64 - cp extension-artifacts/osx_arm64/*.kuzu_extension releases/$RELEASE_VERSION/osx_arm64 - cp extension-artifacts/win_amd64/*.kuzu_extension releases/$RELEASE_VERSION/win_amd64 + rsync -av extension-artifacts/ releases/$RELEASE_VERSION/ - name: Remove temporary directory run: rm -rf extension-artifacts - name: Set artifact permissions run: | - chmod 755 releases/$RELEASE_VERSION/linux_amd64/* - chmod 755 releases/$RELEASE_VERSION/linux_arm64/* - chmod 755 releases/$RELEASE_VERSION/linux_x86/* - chmod 755 releases/$RELEASE_VERSION/osx_amd64/* - chmod 755 releases/$RELEASE_VERSION/osx_arm64/* - chmod 755 releases/$RELEASE_VERSION/win_amd64/* + chmod -R 755 releases/$RELEASE_VERSION - name: List downloaded artifacts run: ls -R releases/$RELEASE_VERSION diff --git a/CMakeLists.txt b/CMakeLists.txt index 0dc86b1157b..fd727bd92ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -254,7 +254,7 @@ add_subdirectory(third_party) if(${BUILD_KUZU}) add_definitions(-DKUZU_ROOT_DIRECTORY="${PROJECT_SOURCE_DIR}") add_definitions(-DKUZU_CMAKE_VERSION="${CMAKE_PROJECT_VERSION}") -add_definitions(-DKUZU_EXTENSION_VERSION="0.5.1.11") +add_definitions(-DKUZU_EXTENSION_VERSION="0.5.2.0") include_directories(src/include) diff --git a/Makefile b/Makefile index 5913317afef..aa2ff22c0ca 100644 --- a/Makefile +++ b/Makefile @@ -171,7 +171,6 @@ extension-test-build: $(call run-cmake-release, \ -DBUILD_EXTENSIONS="httpfs;duckdb;json;postgres;sqlite" \ -DBUILD_EXTENSION_TESTS=TRUE \ - -DENABLE_ADDRESS_SANITIZER=TRUE \ ) extension-json-test-build: diff --git a/extension/duckdb/CMakeLists.txt b/extension/duckdb/CMakeLists.txt index 4961b8bb37b..d90c4672acd 100644 --- a/extension/duckdb/CMakeLists.txt +++ b/extension/duckdb/CMakeLists.txt @@ -1,14 +1,10 @@ -# Address sanitizer may change the ABI, so we use dynamic linking -# if it's enabled. -if(${ENABLE_ADDRESS_SANITIZER}) - set(DuckDB_USE_STATIC_LIBS OFF) -else() - set(DuckDB_USE_STATIC_LIBS ON) -endif() +if (WIN32) + set(DuckDB_USE_STATIC_LIBS ON) +else () + set(DuckDB_USE_STATIC_LIBS OFF) +endif () + find_package(DuckDB REQUIRED) -if(NOT WIN32) - find_library(DuckDB_STATIC REQUIRED NAMES libduckdb_static.a) -endif() include_directories( ${PROJECT_SOURCE_DIR}/src/include @@ -44,6 +40,43 @@ set_target_properties(duckdb_extension PROPERTIES SUFFIX ".kuzu_extension" ) +if (NOT WIN32) + add_library(duckdb_extension_installer + SHARED + src/duckdb_installer.cpp + src/duckdb_install_func.cpp) + + add_library(duckdb_extension_loader + SHARED + src/duckdb_loader.cpp) + + set_target_properties(duckdb_extension_installer + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + ) + + set_target_properties(duckdb_extension_loader + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + ) + + set_target_properties(duckdb_extension_installer PROPERTIES + OUTPUT_NAME duckdb_installer + PREFIX "lib" + SUFFIX ".kuzu_extension" + ) + + set_target_properties(duckdb_extension_loader PROPERTIES + OUTPUT_NAME duckdb_loader + PREFIX "lib" + SUFFIX ".kuzu_extension" + ) +endif () + if (WIN32) # On windows, there is no dynamic lookup available, so it's not # possible to generically look for symbols on library load. There are @@ -66,4 +99,6 @@ endif() if (APPLE) set_target_properties(duckdb_extension PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(duckdb_extension_installer PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(duckdb_extension_loader PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif () diff --git a/extension/duckdb/src/duckdb_install_func.cpp b/extension/duckdb/src/duckdb_install_func.cpp new file mode 100644 index 00000000000..9dd93190942 --- /dev/null +++ b/extension/duckdb/src/duckdb_install_func.cpp @@ -0,0 +1,15 @@ +#include "duckdb_installer.h" + +extern "C" { +// Because we link against the static library on windows, we implicitly inherit KUZU_STATIC_DEFINE, +// which cancels out any exporting, so we can't use KUZU_API. +#if defined(_WIN32) +#define INIT_EXPORT __declspec(dllexport) +#else +#define INIT_EXPORT __attribute__((visibility("default"))) +#endif +INIT_EXPORT void install(kuzu::main::ClientContext* context) { + kuzu::duckdb_extension::DuckDBInstaller installer{"duckdb"}; + installer.install(context); +} +} diff --git a/extension/duckdb/src/duckdb_installer.cpp b/extension/duckdb/src/duckdb_installer.cpp new file mode 100644 index 00000000000..ec4deb8ab2a --- /dev/null +++ b/extension/duckdb/src/duckdb_installer.cpp @@ -0,0 +1,29 @@ +#include "duckdb_installer.h" + +#include "common/file_system/virtual_file_system.h" +#include "duckdb_extension.h" +#include "main/client_context.h" + +namespace kuzu { +namespace duckdb_extension { + +void DuckDBInstaller::install(main::ClientContext* context) { + auto loaderFileRepoInfo = extension::ExtensionUtils::getExtensionLoaderRepoInfo(extensionName); + auto localLoaderFilePath = + extension::ExtensionUtils::getLocalPathForExtensionLoader(context, extensionName); + tryDownloadExtensionFile(context, loaderFileRepoInfo, localLoaderFilePath); + + for (auto& dependencyLib : DuckDBExtension::DEPENDENCY_LIB_FILES) { + auto dependencyLibWithSuffix = extension::ExtensionUtils::appendLibSuffix(dependencyLib); + auto localDependencyLibPath = + extension::ExtensionUtils::getLocalPathForSharedLib(context, dependencyLibWithSuffix); + if (!context->getVFSUnsafe()->fileOrPathExists(localDependencyLibPath)) { + auto dependencyLibRepoInfo = + extension::ExtensionUtils::getSharedLibRepoInfo(dependencyLibWithSuffix); + tryDownloadExtensionFile(context, dependencyLibRepoInfo, localDependencyLibPath); + } + } +} + +} // namespace duckdb_extension +} // namespace kuzu diff --git a/extension/duckdb/src/duckdb_loader.cpp b/extension/duckdb/src/duckdb_loader.cpp new file mode 100644 index 00000000000..91749984927 --- /dev/null +++ b/extension/duckdb/src/duckdb_loader.cpp @@ -0,0 +1,32 @@ +#include "duckdb_loader.h" + +#include "duckdb_extension.h" + +namespace kuzu { +namespace duckdb_extension { + +void DuckDBLoader::loadDependency(main::ClientContext* context) { + for (auto& dependencyLib : DuckDBExtension::DEPENDENCY_LIB_FILES) { + auto dependencyLibWithSuffix = extension::ExtensionUtils::appendLibSuffix(dependencyLib); + auto dependencyLibPath = + extension::ExtensionUtils::getLocalPathForSharedLib(context, dependencyLibWithSuffix); + auto dependencyLoader = extension::ExtensionLibLoader(extensionName, dependencyLibPath); + } +} + +} // namespace duckdb_extension +} // namespace kuzu + +extern "C" { +// Because we link against the static library on windows, we implicitly inherit KUZU_STATIC_DEFINE, +// which cancels out any exporting, so we can't use KUZU_API. +#if defined(_WIN32) +#define INIT_EXPORT __declspec(dllexport) +#else +#define INIT_EXPORT __attribute__((visibility("default"))) +#endif +INIT_EXPORT void load(kuzu::main::ClientContext* context) { + kuzu::duckdb_extension::DuckDBLoader loader{"duckdb"}; + loader.loadDependency(context); +} +} diff --git a/extension/duckdb/src/include/duckdb_extension.h b/extension/duckdb/src/include/duckdb_extension.h index 4c71a3f2cd9..e340b588bb7 100644 --- a/extension/duckdb/src/include/duckdb_extension.h +++ b/extension/duckdb/src/include/duckdb_extension.h @@ -8,6 +8,7 @@ namespace duckdb_extension { class DuckDBExtension final : public extension::Extension { public: static constexpr char EXTENSION_NAME[] = "DUCKDB"; + static constexpr const char* DEPENDENCY_LIB_FILES[] = {"libduckdb"}; public: static void load(main::ClientContext* context); diff --git a/extension/duckdb/src/include/duckdb_installer.h b/extension/duckdb/src/include/duckdb_installer.h new file mode 100644 index 00000000000..bb461e11c36 --- /dev/null +++ b/extension/duckdb/src/include/duckdb_installer.h @@ -0,0 +1,16 @@ +#pragma once +#include "extension/extension_installer.h" + +namespace kuzu { +namespace duckdb_extension { + +class DuckDBInstaller : public extension::ExtensionInstaller { +public: + explicit DuckDBInstaller(const std::string extensionName) + : ExtensionInstaller{std::move(extensionName)} {} + + void install(main::ClientContext* context) override; +}; + +} // namespace duckdb_extension +} // namespace kuzu diff --git a/extension/duckdb/src/include/duckdb_loader.h b/extension/duckdb/src/include/duckdb_loader.h new file mode 100644 index 00000000000..5f2603394b1 --- /dev/null +++ b/extension/duckdb/src/include/duckdb_loader.h @@ -0,0 +1,16 @@ +#pragma once +#include "extension/extension_loader.h" + +namespace kuzu { +namespace duckdb_extension { + +class DuckDBLoader final : public extension::ExtensionLoader { +public: + explicit DuckDBLoader(std::string extensionName) + : extension::ExtensionLoader{std::move(extensionName)} {} + + void loadDependency(main::ClientContext* context) override; +}; + +} // namespace duckdb_extension +} // namespace kuzu diff --git a/extension/postgres/CMakeLists.txt b/extension/postgres/CMakeLists.txt index 81c24cdddbf..0ae45b52ac0 100644 --- a/extension/postgres/CMakeLists.txt +++ b/extension/postgres/CMakeLists.txt @@ -1,14 +1,10 @@ -# Address sanitizer may change the ABI, so we use dynamic linking -# if it's enabled. -if(${ENABLE_ADDRESS_SANITIZER}) - set(DuckDB_USE_STATIC_LIBS OFF) -else() - set(DuckDB_USE_STATIC_LIBS ON) -endif() +if (WIN32) + set(DuckDB_USE_STATIC_LIBS ON) +else () + set(DuckDB_USE_STATIC_LIBS OFF) +endif () + find_package(DuckDB REQUIRED) -if(NOT WIN32) - find_library(DuckDB_STATIC REQUIRED NAMES libduckdb_static.a) -endif() add_library(postgres_extension SHARED @@ -45,6 +41,43 @@ target_link_libraries(postgres_extension PRIVATE ${DuckDB_LIBRARIES}) +if (NOT WIN32) + add_library(postgres_extension_installer + SHARED + src/postgres_install_func.cpp + ../duckdb/src/duckdb_installer.cpp) + + add_library(postgres_extension_loader + SHARED + ../duckdb/src/duckdb_loader.cpp) + + set_target_properties(postgres_extension_loader PROPERTIES + OUTPUT_NAME postgres_loader + PREFIX "lib" + SUFFIX ".kuzu_extension" + ) + + set_target_properties(postgres_extension_installer PROPERTIES + OUTPUT_NAME postgres_installer + PREFIX "lib" + SUFFIX ".kuzu_extension" + ) + + set_target_properties(postgres_extension_installer + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + ) + + set_target_properties(postgres_extension_loader + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + ) +endif () + if (WIN32) # On windows, there is no dynamic lookup available, so it's not # possible to generically look for symbols on library load. There are @@ -67,4 +100,6 @@ endif() if (APPLE) set_target_properties(postgres_extension PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(postgres_extension_loader PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(postgres_extension_installer PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif () diff --git a/extension/postgres/src/postgres_install_func.cpp b/extension/postgres/src/postgres_install_func.cpp new file mode 100644 index 00000000000..3e0635d3e23 --- /dev/null +++ b/extension/postgres/src/postgres_install_func.cpp @@ -0,0 +1,15 @@ +#include "duckdb_installer.h" + +extern "C" { +// Because we link against the static library on windows, we implicitly inherit KUZU_STATIC_DEFINE, +// which cancels out any exporting, so we can't use KUZU_API. +#if defined(_WIN32) +#define INIT_EXPORT __declspec(dllexport) +#else +#define INIT_EXPORT __attribute__((visibility("default"))) +#endif +INIT_EXPORT void install(kuzu::main::ClientContext* context) { + kuzu::duckdb_extension::DuckDBInstaller installer{"postgres"}; + installer.install(context); +} +} diff --git a/extension/sqlite/CMakeLists.txt b/extension/sqlite/CMakeLists.txt index bb4db9e0969..1ee4a8f488f 100644 --- a/extension/sqlite/CMakeLists.txt +++ b/extension/sqlite/CMakeLists.txt @@ -1,14 +1,9 @@ -# Address sanitizer may change the ABI, so we use dynamic linking -# if it's enabled. -if(${ENABLE_ADDRESS_SANITIZER}) - set(DuckDB_USE_STATIC_LIBS OFF) -else() - set(DuckDB_USE_STATIC_LIBS ON) -endif() +if (WIN32) + set(DuckDB_USE_STATIC_LIBS ON) +else () + set(DuckDB_USE_STATIC_LIBS OFF) +endif () find_package(DuckDB REQUIRED) -if(NOT WIN32) - find_library(DuckDB_STATIC REQUIRED NAMES libduckdb_static.a) -endif() include_directories( ${PROJECT_SOURCE_DIR}/src/include @@ -45,6 +40,43 @@ set_target_properties(sqlite_extension PROPERTIES SUFFIX ".kuzu_extension" ) +if (NOT WIN32) + add_library(sqlite_extension_installer + SHARED + src/sqlite_install_func.cpp + ../duckdb/src/duckdb_installer.cpp) + + add_library(sqlite_extension_loader + SHARED + ../duckdb/src/duckdb_loader.cpp) + + set_target_properties(sqlite_extension_loader + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + ) + + set_target_properties(sqlite_extension_installer + PROPERTIES + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build" + ) + + set_target_properties(sqlite_extension_loader PROPERTIES + OUTPUT_NAME sqlite_loader + PREFIX "lib" + SUFFIX ".kuzu_extension" + ) + + set_target_properties(sqlite_extension_installer PROPERTIES + OUTPUT_NAME sqlite_installer + PREFIX "lib" + SUFFIX ".kuzu_extension" + ) +endif () + if (WIN32) # On windows, there is no dynamic lookup available, so it's not # possible to generically look for symbols on library load. There are @@ -67,4 +99,6 @@ endif() if (APPLE) set_target_properties(sqlite_extension PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(sqlite_extension_loader PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + set_target_properties(sqlite_extension_installer PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif () diff --git a/extension/sqlite/src/sqlite_connector.cpp b/extension/sqlite/src/sqlite_connector.cpp index e394f4f7768..95889e45a5a 100644 --- a/extension/sqlite/src/sqlite_connector.cpp +++ b/extension/sqlite/src/sqlite_connector.cpp @@ -8,6 +8,8 @@ void SqliteConnector::connect(const std::string& dbPath, const std::string& cata // Creates an in-memory duckdb instance, then install httpfs and attach SQLITE. instance = std::make_unique(nullptr); connection = std::make_unique(*instance); + executeQuery("install sqlite;"); + executeQuery("load sqlite;"); executeQuery( common::stringFormat("attach '{}' as {} (TYPE sqlite, read_only)", dbPath, catalogName)); } diff --git a/extension/sqlite/src/sqlite_install_func.cpp b/extension/sqlite/src/sqlite_install_func.cpp new file mode 100644 index 00000000000..9ffeaa516cb --- /dev/null +++ b/extension/sqlite/src/sqlite_install_func.cpp @@ -0,0 +1,15 @@ +#include "duckdb_installer.h" + +extern "C" { +// Because we link against the static library on windows, we implicitly inherit KUZU_STATIC_DEFINE, +// which cancels out any exporting, so we can't use KUZU_API. +#if defined(_WIN32) +#define INIT_EXPORT __declspec(dllexport) +#else +#define INIT_EXPORT __attribute__((visibility("default"))) +#endif +INIT_EXPORT void install(kuzu::main::ClientContext* context) { + kuzu::duckdb_extension::DuckDBInstaller installer{"sqlite"}; + installer.install(context); +} +} diff --git a/scripts/collect-extensions.py b/scripts/collect-extensions.py new file mode 100644 index 00000000000..de7b496e4dd --- /dev/null +++ b/scripts/collect-extensions.py @@ -0,0 +1,58 @@ +import os +import shutil +import platform + +FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +DST_DIR = os.path.abspath(os.path.join(FILE_DIR, "..", "extension-artifacts")) +SRC_DIR = os.path.abspath(os.path.join(FILE_DIR, "..", "extension")) + + +def collect_exts(): + for ext in os.listdir(SRC_DIR): + ext_build_path = os.path.abspath(os.path.join(SRC_DIR, ext, "build")) + if not os.path.exists(ext_build_path): + continue + print("Found extension: " + ext) + ext_dst_path = os.path.abspath(os.path.join(DST_DIR, ext)) + os.makedirs(ext_dst_path, exist_ok=True) + for f in os.listdir(ext_build_path): + if not f.endswith(".kuzu_extension"): + continue + ext_file_path = os.path.abspath(os.path.join(ext_build_path, f)) + shutil.copy(ext_file_path, ext_dst_path) + print(" \tCopied: " + f, "=>", ext_dst_path) + + +def find_duckdb(): + if platform.system() == 'Darwin': + candidates = ["/usr/local/lib/libduckdb.dylib", "/opt/homebrew/lib/libduckdb.dylib"] + elif platform.system() == 'Linux': + candidates = ["/usr/local/lib/libduckdb.so", "/usr/lib/libduckdb.so", "/usr/lib64/libduckdb.so"] + elif platform.system() == 'Windows': + candidates = ["C:\\Program Files\\duckdb\\build\\release\\src\\Release\\duckdb.lib"] + for candidate in candidates: + if os.path.exists(candidate): + return os.path.abspath(candidate) + return None + + +def copy_duckdb(): + duckdb_dst_path = os.path.abspath(os.path.join(DST_DIR, "common")) + os.makedirs(duckdb_dst_path, exist_ok=True) + duckdb_path = find_duckdb() + if duckdb_path is None: + print("DuckDB not found, copying is skipped") + return + shutil.copy(duckdb_path, duckdb_dst_path) + print("Copied DuckDB: " + duckdb_path, "=>", duckdb_dst_path) + + +def main(): + shutil.rmtree(DST_DIR, ignore_errors=True) + os.makedirs(DST_DIR, exist_ok=True) + collect_exts() + copy_duckdb() + + +if __name__ == "__main__": + main() diff --git a/src/extension/CMakeLists.txt b/src/extension/CMakeLists.txt index f2072d83694..edd7011b777 100644 --- a/src/extension/CMakeLists.txt +++ b/src/extension/CMakeLists.txt @@ -1,7 +1,8 @@ add_library(kuzu_extension OBJECT catalog_extension.cpp - extension.cpp) + extension.cpp + extension_installer.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/extension/extension.cpp b/src/extension/extension.cpp index 6605fddcc54..301e21d39ac 100644 --- a/src/extension/extension.cpp +++ b/src/extension/extension.cpp @@ -1,11 +1,24 @@ #include "extension/extension.h" #include "catalog/catalog.h" +#include "common/exception/io.h" +#include "common/file_system/virtual_file_system.h" #include "common/string_format.h" #include "common/string_utils.h" +#include "common/system_message.h" #include "function/table_functions.h" +#include "main/client_context.h" #include "main/database.h" #include "transaction/transaction.h" +#ifdef _WIN32 + +#include "windows.h" +#define RTLD_NOW 0 +#define RTLD_LOCAL 0 + +#else +#include +#endif namespace kuzu { namespace extension { @@ -39,25 +52,71 @@ std::string getPlatform() { return getOS() + "_" + getArch(); } -std::string ExtensionUtils::getExtensionPath(const std::string& extensionDir, - const std::string& name) { - return common::stringFormat("{}/lib{}.kuzu_extension", extensionDir, name); -} - bool ExtensionUtils::isFullPath(const std::string& extension) { return extension.find('.') != std::string::npos || extension.find('/') != std::string::npos || extension.find('\\') != std::string::npos; } -ExtensionRepoInfo ExtensionUtils::getExtensionRepoInfo(const std::string& extension) { - auto extensionURL = common::stringFormat(EXTENSION_REPO, KUZU_EXTENSION_VERSION, getPlatform(), - common::StringUtils::getLower(extension)); +static ExtensionRepoInfo getExtensionRepoInfo(std::string& extensionURL) { common::StringUtils::replaceAll(extensionURL, "http://", ""); auto hostNamePos = extensionURL.find('/'); auto hostName = extensionURL.substr(0, hostNamePos); auto hostURL = "http://" + hostName; auto hostPath = extensionURL.substr(hostNamePos); return {hostPath, hostURL, extensionURL}; +}; + +ExtensionRepoInfo ExtensionUtils::getExtensionLibRepoInfo(const std::string& extensionName) { + auto extensionURL = common::stringFormat(EXTENSION_FILE_REPO, KUZU_EXTENSION_VERSION, + getPlatform(), extensionName, getExtensionFileName(extensionName)); + return getExtensionRepoInfo(extensionURL); +} + +ExtensionRepoInfo ExtensionUtils::getExtensionLoaderRepoInfo(const std::string& extensionName) { + auto extensionURL = + common::stringFormat(EXTENSION_FILE_REPO, KUZU_EXTENSION_VERSION, getPlatform(), + extensionName, getExtensionFileName(extensionName + EXTENSION_LOADER_SUFFIX)); + return getExtensionRepoInfo(extensionURL); +} + +ExtensionRepoInfo ExtensionUtils::getExtensionInstallerRepoInfo(const std::string& extensionName) { + auto extensionURL = + common::stringFormat(EXTENSION_FILE_REPO, KUZU_EXTENSION_VERSION, getPlatform(), + extensionName, getExtensionFileName(extensionName + EXTENSION_INSTALLER_SUFFIX)); + return getExtensionRepoInfo(extensionURL); +} + +ExtensionRepoInfo ExtensionUtils::getSharedLibRepoInfo(const std::string& fileName) { + auto extensionURL = + common::stringFormat(SHARED_LIB_REPO, KUZU_EXTENSION_VERSION, getPlatform(), fileName); + return getExtensionRepoInfo(extensionURL); +} + +std::string ExtensionUtils::getExtensionFileName(const std::string& name) { + return common::stringFormat(EXTENSION_FILE_NAME, common::StringUtils::getLower(name)); +} + +std::string ExtensionUtils::getLocalPathForExtensionLib(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}{}/{}", context->getExtensionDir(), extensionName, + getExtensionFileName(extensionName)); +} + +std::string ExtensionUtils::getLocalPathForExtensionLoader(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}{}/{}", context->getExtensionDir(), extensionName, + getExtensionFileName(extensionName + EXTENSION_LOADER_SUFFIX)); +} + +std::string ExtensionUtils::getLocalPathForExtensionInstaller(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}{}/{}", context->getExtensionDir(), extensionName, + getExtensionFileName(extensionName + EXTENSION_INSTALLER_SUFFIX)); +} + +std::string ExtensionUtils::getLocalExtensionDir(main::ClientContext* context, + const std::string& extensionName) { + return common::stringFormat("{}{}", context->getExtensionDir(), extensionName); } void ExtensionUtils::registerTableFunction(main::Database& database, @@ -74,6 +133,28 @@ void ExtensionUtils::registerTableFunction(main::Database& database, catalog::CatalogEntryType::TABLE_FUNCTION_ENTRY, std::move(name), std::move(functionSet)); } +std::string ExtensionUtils::appendLibSuffix(const std::string& libName) { + auto os = getOS(); + std::string suffix; + if (os == "linux" || os == "linux_old") { + suffix = "so"; + } else if (os == "osx") { + suffix = "dylib"; + } else { + KU_UNREACHABLE; + } + return common::stringFormat("{}.{}", libName, suffix); +} + +std::string ExtensionUtils::getLocalPathForSharedLib(main::ClientContext* context, + const std::string& libName) { + return common::stringFormat("{}common/{}", context->getExtensionDir(), libName); +} + +std::string ExtensionUtils::getLocalPathForSharedLib(main::ClientContext* context) { + return common::stringFormat("{}common/", context->getExtensionDir()); +} + void ExtensionUtils::registerFunctionSet(main::Database& database, std::string name, function::function_set functionSet) { auto catalog = database.getCatalog(); @@ -95,6 +176,38 @@ bool ExtensionUtils::isOfficialExtension(const std::string& extension) { return false; } +ExtensionLibLoader::ExtensionLibLoader(const std::string& extensionName, const std::string& path) + : extensionName{extensionName} { + libHdl = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (libHdl == nullptr) { + throw common::IOException(common::stringFormat( + "Failed to load library: {} which is needed by extension: {}.\nError: {}.", path, + extensionName, common::dlErrMessage())); + } +} + +ext_load_func_t ExtensionLibLoader::getLoadFunc() { + return (ext_load_func_t)getDynamicLibFunc(EXTENSION_LOAD_FUNC_NAME); +} + +ext_init_func_t ExtensionLibLoader::getInitFunc() { + return (ext_init_func_t)getDynamicLibFunc(EXTENSION_INIT_FUNC_NAME); +} + +ext_install_func_t ExtensionLibLoader::getInstallFunc() { + return (ext_install_func_t)getDynamicLibFunc(EXTENSION_INSTALL_FUNC_NAME); +} + +void* ExtensionLibLoader::getDynamicLibFunc(const std::string& funcName) { + auto sym = dlsym(libHdl, funcName.c_str()); + if (sym == nullptr) { + throw common::IOException( + common::stringFormat("Failed to load {} function in extension {}.\nError: {}", funcName, + extensionName, common::dlErrMessage())); + } + return sym; +} + void ExtensionOptions::addExtensionOption(std::string name, common::LogicalTypeID type, common::Value defaultValue) { common::StringUtils::toLower(name); @@ -106,5 +219,33 @@ main::ExtensionOption* ExtensionOptions::getExtensionOption(std::string name) { return extensionOptions.contains(name) ? &extensionOptions.at(name) : nullptr; } +#ifdef _WIN32 +std::wstring utf8ToUnicode(const char* input) { + uint32_t result; + + result = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); + if (result == 0) { + throw common::IOException("Failure in MultiByteToWideChar"); + } + auto buffer = std::make_unique(result); + result = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result); + if (result == 0) { + throw common::IOException("Failure in MultiByteToWideChar"); + } + return std::wstring(buffer.get(), result); +} + +void* dlopen(const char* file, int /*mode*/) { + KU_ASSERT(file); + auto fpath = utf8ToUnicode(file); + return (void*)LoadLibraryW(fpath.c_str()); +} + +void* dlsym(void* handle, const char* name) { + KU_ASSERT(handle); + return (void*)GetProcAddress((HINSTANCE)handle, name); +} +#endif + } // namespace extension } // namespace kuzu diff --git a/src/extension/extension_installer.cpp b/src/extension/extension_installer.cpp new file mode 100644 index 00000000000..8693a93ca3d --- /dev/null +++ b/src/extension/extension_installer.cpp @@ -0,0 +1,60 @@ +#include "extension/extension_installer.h" + +#include "common/exception/io.h" +#include "common/file_system/virtual_file_system.h" +#include "httplib.h" +#include "main/client_context.h" + +namespace kuzu { +namespace extension { + +void ExtensionInstaller::tryDownloadExtensionFile(main::ClientContext* context, + const ExtensionRepoInfo& info, const std::string& localFilePath) { + httplib::Client cli(info.hostURL.c_str()); + httplib::Headers headers = { + {"User-Agent", common::stringFormat("kuzu/v{}", KUZU_EXTENSION_VERSION)}}; + auto res = cli.Get(info.hostPath.c_str(), headers); + if (!res || res->status != 200) { + if (res.error() == httplib::Error::Success) { + // LCOV_EXCL_START + throw common::IOException(common::stringFormat( + "HTTP Returns: {}, Failed to download extension: \"{}\" from {}.", + res.value().status, extensionName, info.repoURL)); + // LCOC_EXCL_STOP + } else { + throw common::IOException( + common::stringFormat("Failed to download extension: {} at URL {} (ERROR: {})", + extensionName, info.repoURL, to_string(res.error()))); + } + } + + auto vfs = context->getVFSUnsafe(); + auto fileInfo = vfs->openFile(localFilePath, O_WRONLY | O_CREAT); + fileInfo->writeFile(reinterpret_cast(res->body.c_str()), res->body.size(), + 0 /* offset */); + fileInfo->syncFile(); +} + +void ExtensionInstaller::install(main::ClientContext* context) { + auto vfs = context->getVFSUnsafe(); + auto localExtensionDir = context->getExtensionDir(); + if (!vfs->fileOrPathExists(localExtensionDir, context)) { + vfs->createDir(localExtensionDir); + } + auto localDirForExtension = + extension::ExtensionUtils::getLocalExtensionDir(context, extensionName); + if (!vfs->fileOrPathExists(localDirForExtension)) { + vfs->createDir(localDirForExtension); + } + auto localDirForSharedLib = extension::ExtensionUtils::getLocalPathForSharedLib(context); + if (!vfs->fileOrPathExists(localDirForSharedLib)) { + vfs->createDir(localDirForSharedLib); + } + auto libFileRepoInfo = extension::ExtensionUtils::getExtensionLibRepoInfo(extensionName); + auto localLibFilePath = + extension::ExtensionUtils::getLocalPathForExtensionLib(context, extensionName); + tryDownloadExtensionFile(context, libFileRepoInfo, localLibFilePath); +} + +} // namespace extension +} // namespace kuzu diff --git a/src/include/common/system_message.h b/src/include/common/system_message.h index fde046d1710..43ba8df7ee3 100644 --- a/src/include/common/system_message.h +++ b/src/include/common/system_message.h @@ -4,6 +4,8 @@ #include #include +#include "common/api.h" + namespace kuzu { namespace common { @@ -21,7 +23,7 @@ inline std::string posixErrMessage() { // LCOV_EXCL_STOP } -std::string dlErrMessage(); +KUZU_API std::string dlErrMessage(); } // namespace common } // namespace kuzu diff --git a/src/include/extension/extension.h b/src/include/extension/extension.h index 0d4603502dc..43ab4a418ea 100644 --- a/src/include/extension/extension.h +++ b/src/include/extension/extension.h @@ -21,6 +21,10 @@ class Database; namespace extension { +typedef void (*ext_init_func_t)(kuzu::main::ClientContext*); +using ext_load_func_t = ext_init_func_t; +using ext_install_func_t = ext_init_func_t; + std::string getPlatform(); class KUZU_API Extension { @@ -35,17 +39,50 @@ struct ExtensionRepoInfo { }; struct ExtensionUtils { - static constexpr const char* EXTENSION_REPO = - "http://extension.kuzudb.com/v{}/{}/lib{}.kuzu_extension"; + static constexpr const char* EXTENSION_FILE_REPO = "http://extension.kuzudb.com/v{}/{}/{}/{}"; + + static constexpr const char* SHARED_LIB_REPO = "http://extension.kuzudb.com/v{}/{}/common/{}"; + + static constexpr const char* EXTENSION_FILE_NAME = "lib{}.kuzu_extension"; static constexpr const char* OFFICIAL_EXTENSION[] = {"HTTPFS", "POSTGRES", "DUCKDB", "JSON", "SQLITE"}; - static std::string getExtensionPath(const std::string& extensionDir, const std::string& name); + static constexpr const char* EXTENSION_LOADER_SUFFIX = "_loader"; + + static constexpr const char* EXTENSION_INSTALLER_SUFFIX = "_installer"; static bool isFullPath(const std::string& extension); - static ExtensionRepoInfo getExtensionRepoInfo(const std::string& extension); + static KUZU_API ExtensionRepoInfo getExtensionLibRepoInfo(const std::string& extensionName); + + static KUZU_API ExtensionRepoInfo getExtensionLoaderRepoInfo(const std::string& extensionName); + + static KUZU_API ExtensionRepoInfo getExtensionInstallerRepoInfo( + const std::string& extensionName); + + static KUZU_API ExtensionRepoInfo getSharedLibRepoInfo(const std::string& fileName); + + static std::string getExtensionFileName(const std::string& name); + + KUZU_API static std::string getLocalPathForExtensionLib(main::ClientContext* context, + const std::string& extensionName); + + KUZU_API static std::string getLocalPathForExtensionLoader(main::ClientContext* context, + const std::string& extensionName); + + KUZU_API static std::string getLocalPathForExtensionInstaller(main::ClientContext* context, + const std::string& extensionName); + + KUZU_API static std::string getLocalExtensionDir(main::ClientContext* context, + const std::string& extensionName); + + KUZU_API static std::string appendLibSuffix(const std::string& libName); + + KUZU_API static std::string getLocalPathForSharedLib(main::ClientContext* context, + const std::string& libName); + + KUZU_API static std::string getLocalPathForSharedLib(main::ClientContext* context); KUZU_API static void registerTableFunction(main::Database& database, std::unique_ptr function); @@ -56,6 +93,31 @@ struct ExtensionUtils { static bool isOfficialExtension(const std::string& extension); }; +class KUZU_API ExtensionLibLoader { +public: + static constexpr const char* EXTENSION_LOAD_FUNC_NAME = "load"; + + static constexpr const char* EXTENSION_INIT_FUNC_NAME = "init"; + + static constexpr const char* EXTENSION_INSTALL_FUNC_NAME = "install"; + +public: + ExtensionLibLoader(const std::string& extensionName, const std::string& path); + + ext_load_func_t getLoadFunc(); + + ext_init_func_t getInitFunc(); + + ext_install_func_t getInstallFunc(); + +private: + void* getDynamicLibFunc(const std::string& funcName); + +private: + std::string extensionName; + void* libHdl; +}; + struct ExtensionOptions { std::unordered_map extensionOptions; @@ -65,5 +127,13 @@ struct ExtensionOptions { main::ExtensionOption* getExtensionOption(std::string name); }; +#ifdef _WIN32 +std::wstring utf8ToUnicode(const char* input); + +void* dlopen(const char* file, int /*mode*/); + +void* dlsym(void* handle, const char* name); +#endif + } // namespace extension } // namespace kuzu diff --git a/src/include/extension/extension_installer.h b/src/include/extension/extension_installer.h new file mode 100644 index 00000000000..79f15d337db --- /dev/null +++ b/src/include/extension/extension_installer.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include "common/api.h" +#include "extension.h" + +namespace kuzu { +namespace main { +class ClientContext; +} +namespace extension { + +class KUZU_API ExtensionInstaller { +public: + explicit ExtensionInstaller(const std::string extensionName) + : extensionName{std::move(extensionName)} {} + + virtual ~ExtensionInstaller() = default; + + virtual void install(main::ClientContext* context); + +protected: + void tryDownloadExtensionFile(main::ClientContext* context, const ExtensionRepoInfo& info, + const std::string& localFilePath); + +protected: + std::string extensionName; +}; + +} // namespace extension +} // namespace kuzu diff --git a/src/include/extension/extension_loader.h b/src/include/extension/extension_loader.h new file mode 100644 index 00000000000..33d9ba215b5 --- /dev/null +++ b/src/include/extension/extension_loader.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "common/api.h" + +namespace kuzu { +namespace main { +class ClientContext; +} +namespace extension { + +class KUZU_API ExtensionLoader { +public: + explicit ExtensionLoader(const std::string extensionName) + : extensionName{std::move(extensionName)} {} + + virtual ~ExtensionLoader() = default; + + virtual void loadDependency(main::ClientContext* context) = 0; + +protected: + std::string extensionName; +}; + +} // namespace extension +} // namespace kuzu diff --git a/src/include/processor/operator/simple/install_extension.h b/src/include/processor/operator/simple/install_extension.h index 2e44f5217d3..9bc64317e76 100644 --- a/src/include/processor/operator/simple/install_extension.h +++ b/src/include/processor/operator/simple/install_extension.h @@ -38,10 +38,6 @@ class InstallExtension final : public Simple { } private: - std::string tryDownloadExtension(); - - void saveExtensionToLocalFile(const std::string& extensionData, main::ClientContext* context); - void installExtension(main::ClientContext* context); private: diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 7f11ad23a1b..37acaba0d88 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -163,7 +163,7 @@ extension::ExtensionOptions* ClientContext::getExtensionOptions() const { } std::string ClientContext::getExtensionDir() const { - return stringFormat("{}/.kuzu/extension/{}/{}", clientConfig.homeDirectory, + return stringFormat("{}/.kuzu/extension/{}/{}/", clientConfig.homeDirectory, KUZU_EXTENSION_VERSION, kuzu::extension::getPlatform()); } diff --git a/src/processor/operator/simple/install_extension.cpp b/src/processor/operator/simple/install_extension.cpp index 5a8fe22873e..8738c2901ea 100644 --- a/src/processor/operator/simple/install_extension.cpp +++ b/src/processor/operator/simple/install_extension.cpp @@ -1,9 +1,9 @@ #include "processor/operator/simple/install_extension.h" -#include "common/exception/io.h" #include "common/file_system/virtual_file_system.h" #include "common/string_format.h" #include "extension/extension.h" +#include "extension/extension_installer.h" #include "httplib.h" #include "main/database.h" @@ -25,45 +25,35 @@ std::string InstallExtension::getOutputMsg() { return common::stringFormat("Extension: {} has been installed.", name); } -std::string InstallExtension::tryDownloadExtension() { - auto extensionRepoInfo = ExtensionUtils::getExtensionRepoInfo(name); +static void saveExtensionToLocalFile(const std::string& extensionData, const std::string& name, + main::ClientContext* context) { + auto extensionPath = ExtensionUtils::getLocalPathForExtensionInstaller(context, name); + auto fileInfo = context->getVFSUnsafe()->openFile(extensionPath, O_WRONLY | O_CREAT); + fileInfo->writeFile(reinterpret_cast(extensionData.c_str()), + extensionData.size(), 0 /* offset */); + fileInfo->syncFile(); +} + +static void installDependencies(const std::string& name, main::ClientContext* context) { + auto extensionRepoInfo = ExtensionUtils::getExtensionInstallerRepoInfo(name); httplib::Client cli(extensionRepoInfo.hostURL.c_str()); httplib::Headers headers = { {"User-Agent", common::stringFormat("kuzu/v{}", KUZU_EXTENSION_VERSION)}}; auto res = cli.Get(extensionRepoInfo.hostPath.c_str(), headers); if (!res || res->status != 200) { - if (res.error() == httplib::Error::Success) { - // LCOV_EXCL_START - throw IOException(common::stringFormat( - "HTTP Returns: {}, Failed to download extension: \"{}\" from {}.", - res.value().status, name, extensionRepoInfo.repoURL)); - // LCOC_EXCL_STOP - } else { - throw IOException( - common::stringFormat("Failed to download extension: {} at URL {} (ERROR: {})", name, - extensionRepoInfo.repoURL, to_string(res.error()))); - } + return; } - return res->body; -} - -void InstallExtension::saveExtensionToLocalFile(const std::string& extensionData, - main::ClientContext* context) { - auto extensionDir = context->getExtensionDir(); - auto extensionPath = ExtensionUtils::getExtensionPath(extensionDir, name); - auto vfs = context->getVFSUnsafe(); - if (!vfs->fileOrPathExists(extensionDir, context)) { - vfs->createDir(extensionDir); - } - auto fileInfo = vfs->openFile(extensionPath, O_WRONLY | O_CREAT); - fileInfo->writeFile(reinterpret_cast(extensionData.c_str()), - extensionData.size(), 0 /* offset */); - fileInfo->syncFile(); + saveExtensionToLocalFile(res->body, name, context); + auto extensionInstallerPath = ExtensionUtils::getLocalPathForExtensionInstaller(context, name); + auto libLoader = ExtensionLibLoader(name, extensionInstallerPath.c_str()); + auto install = libLoader.getInstallFunc(); + (*install)(context); } void InstallExtension::installExtension(main::ClientContext* context) { - auto extensionData = tryDownloadExtension(); - saveExtensionToLocalFile(extensionData, context); + extension::ExtensionInstaller installer{name}; + installer.install(context); + installDependencies(name, context); } } // namespace processor diff --git a/src/processor/operator/simple/load_extension.cpp b/src/processor/operator/simple/load_extension.cpp index 4ff2ca14901..420069996c2 100644 --- a/src/processor/operator/simple/load_extension.cpp +++ b/src/processor/operator/simple/load_extension.cpp @@ -1,25 +1,11 @@ #include "processor/operator/simple/load_extension.h" -#include "common/exception/io.h" - -#ifdef _WIN32 - -#include "windows.h" -#define RTLD_NOW 0 -#define RTLD_LOCAL 0 - -#else -#include -#endif - -#include "common/system_message.h" +#include "common/file_system/virtual_file_system.h" #include "extension/extension.h" #include "main/database.h" using namespace kuzu::common; -typedef void (*ext_init_func_t)(kuzu::main::ClientContext*); - namespace kuzu { namespace processor { @@ -29,50 +15,30 @@ std::string LoadExtensionPrintInfo::toString() const { return "Load " + extensionName; } -#ifdef _WIN32 -std::wstring utf8ToUnicode(const char* input) { - uint32_t result; - - result = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); - if (result == 0) { - throw IOException("Failure in MultiByteToWideChar"); - } - auto buffer = std::make_unique(result); - result = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result); - if (result == 0) { - throw IOException("Failure in MultiByteToWideChar"); +static void executeExtensionLoader(main::ClientContext* context, const std::string& extensionName) { + auto loaderPath = ExtensionUtils::getLocalPathForExtensionLoader(context, extensionName); + if (context->getVFSUnsafe()->fileOrPathExists(loaderPath)) { + auto libLoader = ExtensionLibLoader(extensionName, loaderPath); + auto load = libLoader.getLoadFunc(); + (*load)(context); } - return std::wstring(buffer.get(), result); } -void* dlopen(const char* file, int /*mode*/) { - KU_ASSERT(file); - auto fpath = utf8ToUnicode(file); - return (void*)LoadLibraryW(fpath.c_str()); -} - -void* dlsym(void* handle, const char* name) { - KU_ASSERT(handle); - return (void*)GetProcAddress((HINSTANCE)handle, name); -} -#endif - void LoadExtension::executeInternal(kuzu::processor::ExecutionContext* context) { + auto fullPath = path; if (!extension::ExtensionUtils::isFullPath(path)) { - path = ExtensionUtils::getExtensionPath(context->clientContext->getExtensionDir(), path); + auto localPathForSharedLib = + ExtensionUtils::getLocalPathForSharedLib(context->clientContext); + if (!context->clientContext->getVFSUnsafe()->fileOrPathExists(localPathForSharedLib)) { + context->clientContext->getVFSUnsafe()->createDir(localPathForSharedLib); + } + executeExtensionLoader(context->clientContext, path); + fullPath = ExtensionUtils::getLocalPathForExtensionLib(context->clientContext, path); } - auto libHdl = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); - if (libHdl == nullptr) { - throw common::IOException( - stringFormat("Extension \"{}\" could not be loaded.\nError: {}", path, dlErrMessage())); - } - auto load = (ext_init_func_t)(dlsym(libHdl, "init")); - if (load == nullptr) { - throw common::IOException( - stringFormat("Extension \"{}\" does not have a valid init function.\nError: {}", path, - dlErrMessage())); - } - (*load)(context->clientContext); + + auto libLoader = ExtensionLibLoader(path, fullPath); + auto init = libLoader.getInitFunc(); + (*init)(context->clientContext); } std::string LoadExtension::getOutputMsg() { diff --git a/test/c_api/database_test.cpp b/test/c_api/database_test.cpp index 6a8a29afece..f14f10140ee 100644 --- a/test/c_api/database_test.cpp +++ b/test/c_api/database_test.cpp @@ -94,11 +94,3 @@ TEST_F(CApiDatabaseTest, CreationHomeDir) { kuzu_database_destroy(&database); std::filesystem::remove_all(homePath + "/ku_test.db"); } - -TEST_F(CApiDatabaseTest, dadsa) { - createDBAndConn(); - printf("%s", - conn->query("create node table test1(ID INT64, description STRUCT, PRIMARY KEY(ID))") - ->toString() - .c_str()); -}