diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml new file mode 100644 index 0000000..7704583 --- /dev/null +++ b/.github/workflows/wheels.yml @@ -0,0 +1,58 @@ +name: Build + +on: [push, pull_request] + +jobs: + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] # , windows-latest, macos-13, macos-latest + + steps: + - uses: actions/checkout@v4 + + # Used to host cibuildwheel + - uses: actions/setup-python@v5 + + - name: Install cibuildwheel + run: python -m pip install cibuildwheel==2.21.3 + + + - name: Install cibuildwheel + run: python -m pip install cibuildwheel + + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + + - name: Display vcpkg logs on failure + if: failure() + run: | + echo "vcpkg bootstrap log:" + cat ${{ github.workspace }}/.vcpkg/buildtrees/vcpkg/bootstrap-out.log || echo "Bootstrap log not found" + echo "vcpkg install log:" + find /tmp -name vcpkg-bootstrap.log -exec cat {} \; || echo "Install log not found" + shell: bash + + - uses: actions/upload-artifact@v4 + with: + name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} + path: ./wheelhouse/*.whl + + publish: + needs: build_wheels + runs-on: ubuntu-latest + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + steps: + - uses: actions/download-artifact@v4 + with: + pattern: cibw-wheels-* + path: dist + merge-multiple: true + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@v1.8.14 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file diff --git a/jsp/CMakeLists.txt b/jsp/CMakeLists.txt index 256a78e..5e9f210 100644 --- a/jsp/CMakeLists.txt +++ b/jsp/CMakeLists.txt @@ -10,22 +10,33 @@ find_package(glfw3 REQUIRED) find_package(CURL REQUIRED) find_package(fmt REQUIRED) + + + + + +find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) + # Check if nanobind_DIR is set if(DEFINED nanobind_DIR) list(APPEND CMAKE_MODULE_PATH "${nanobind_DIR}") message(STATUS "nanobind_DIR is set to: ${nanobind_DIR}") else() message(WARNING "nanobind_DIR is not set") + # site-packages/nanobind/cmake + list(APPEND CMAKE_MODULE_PATH "${Python_SITEARCH}/nanobind/cmake") + + + get_cmake_property(_variableNames VARIABLES) + list (SORT _variableNames) + foreach (_variableName ${_variableNames}) + message(STATUS "${_variableName}=${${_variableName}}") + endforeach() + message( "${Python_SITEARCH}/nanobind/cmake") endif() # Try to find nanobind -find_package(Python COMPONENTS Interpreter Development REQUIRED) - -# print path to Python executable -message(STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}") -set(nanobind_DIR "${Python_SITELIB}/nanobind/cmake") find_package(nanobind CONFIG REQUIRED) - if(NOT nanobind_FOUND) message(FATAL_ERROR "nanobind not found. Please set nanobind_DIR to the directory containing nanobind-config.cmake") endif() @@ -71,8 +82,8 @@ if(NOT DEFINED BUILD_EXECUTABLE OR BUILD_EXECUTABLE) endif() # Create the Python module -nanobind_add_module(jobshop bindings/jobshop_bindings.cpp ${SOURCES} ${HEADERS}) -target_include_directories(jobshop PRIVATE include) +nanobind_add_module(jobshop bindings/jobshop_bindings.cpp bindings/multi_dim_array_bind.cpp ${SOURCES} ${HEADERS}) +target_include_directories(jobshop PRIVATE include ${Python_INCLUDE_DIRS}) target_link_libraries(jobshop PRIVATE ${OPENGL_LIBRARIES} GLEW::GLEW @@ -80,6 +91,7 @@ target_link_libraries(jobshop PRIVATE ${CURL_LIBRARIES} fmt::fmt imgui::imgui + ${Python_LIBRARIES} ) # Installation diff --git a/jsp/CMakePresets.json b/jsp/CMakePresets.json index 130b479..5439c5d 100644 --- a/jsp/CMakePresets.json +++ b/jsp/CMakePresets.json @@ -23,7 +23,9 @@ "type": "FILEPATH" } }, - "environment": {} + "environment": { + + } }, { "name": "debug", @@ -63,7 +65,22 @@ "CMAKE_CXX_FLAGS": "-O3 -DNDEBUG -march=native -flto -ffast-math -fno-omit-frame-pointer", "BUILD_EXECUTABLE": "OFF" } + }, + { + "name": "debug-dev", + "inherits": "debug", + "displayName": "Debug Dev", + "cacheVariables": { + "nanobind_DIR": "../venv/lib/python3.12/site-packages/nanobind/cmake" + }, + "environment": { + "VCPKG_ROOT": "../vcpkg" + } + } + + + ], "workflowPresets": [] } \ No newline at end of file diff --git a/jsp/bindings/jobshop_bindings.cpp b/jsp/bindings/jobshop_bindings.cpp index 7078a2a..4dc928a 100644 --- a/jsp/bindings/jobshop_bindings.cpp +++ b/jsp/bindings/jobshop_bindings.cpp @@ -1,204 +1,289 @@ -#include -#include -#include -#include -#include -#include - -#include "job_shop_environment.h" -#include "job_shop_qlearning.h" -#include "job_shop_actor_critic.h" -#include "job_shop_taillard_generator.h" -#include "job_shop_plotter.h" -#include "job_shop_algorithms.h" - -namespace nb = nanobind; - -NB_MODULE(jobshop, m) { - // Bind Operation struct - nb::class_(m, "Operation") - .def(nb::init<>()) - .def_rw("duration", &Operation::duration) - .def_rw("machine", &Operation::machine) - .def_rw("eligibleMachines", &Operation::eligibleMachines) - .def_rw("dependentOperations", &Operation::dependentOperations); - - // Bind Job struct - nb::class_(m, "Job") - .def(nb::init<>()) - .def_rw("operations", &Job::operations) - .def_rw("dependentJobs", &Job::dependentJobs); - - // Bind Action struct - nb::class_(m, "Action") - .def(nb::init<>()) - .def(nb::init()) - .def_rw("job", &Action::job) - .def_rw("machine", &Action::machine) - .def_rw("operation", &Action::operation); - - // Bind State struct - nb::class_(m, "State") - .def(nb::init()) - .def_rw("jobProgress", &State::jobProgress) - .def_rw("machineAvailability", &State::machineAvailability) - .def_rw("nextOperationForJob", &State::nextOperationForJob) - .def_rw("completedJobs", &State::completedJobs) - .def_rw("jobStartTimes", &State::jobStartTimes); - - // Bind ScheduleEntry struct - nb::class_(m, "ScheduleEntry") - .def(nb::init<>()) - .def(nb::init()) - .def_rw("job", &ScheduleEntry::job) - .def_rw("operation", &ScheduleEntry::operation) - .def_rw("start", &ScheduleEntry::start) - .def_rw("duration", &ScheduleEntry::duration) - .def("__getstate__", [](const ScheduleEntry &se) { - return std::make_tuple(se.job, se.operation, se.start, se.duration); - }) - .def("__setstate__", [](ScheduleEntry &se, const std::tuple &state) { - new (&se) ScheduleEntry{ - std::get<0>(state), - std::get<1>(state), - std::get<2>(state), - std::get<3>(state) - }; - }); - - // Bind JobShopEnvironment class - nb::class_(m, "JobShopEnvironment") - .def(nb::init>()) - .def("step", &JobShopEnvironment::step) - .def("reset", &JobShopEnvironment::reset) - .def("getState", &JobShopEnvironment::getState) - .def("getPossibleActions", &JobShopEnvironment::getPossibleActions) - .def("isDone", &JobShopEnvironment::isDone) - .def("getTotalTime", &JobShopEnvironment::getTotalTime) - .def("getJobs", &JobShopEnvironment::getJobs) - .def("getNumMachines", &JobShopEnvironment::getNumMachines) - .def("getScheduleData", &JobShopEnvironment::getScheduleData) - .def("printSchedule", &JobShopEnvironment::printSchedule) - .def("generateOperationGraph", &JobShopEnvironment::generateOperationGraph); - - // Bind JobShopAlgorithm abstract class - nb::class_(m, "JobShopAlgorithm") - .def("train", &JobShopAlgorithm::train) - .def("printBestSchedule", &JobShopAlgorithm::printBestSchedule) - .def("saveBestScheduleToFile", &JobShopAlgorithm::saveBestScheduleToFile); - - // Bind JobShopQLearning class - nb::class_(m, "JobShopQLearning") - .def(nb::init()) - .def("runEpisode", &JobShopQLearning::runEpisode) - .def("train", &JobShopQLearning::train) - .def("printBestSchedule", &JobShopQLearning::printBestSchedule) - .def("saveBestScheduleToFile", &JobShopQLearning::saveBestScheduleToFile) - .def("applyAndPrintSchedule", &JobShopQLearning::applyAndPrintSchedule); - - // Bind JobShopActorCritic class - nb::class_(m, "JobShopActorCritic") - .def(nb::init()) - .def("runEpisode", &JobShopActorCritic::runEpisode) - .def("train", &JobShopActorCritic::train) - .def("printBestSchedule", &JobShopActorCritic::printBestSchedule) - .def("saveBestScheduleToFile", &JobShopActorCritic::saveBestScheduleToFile) - .def("applyAndPrintSchedule", &JobShopActorCritic::applyAndPrintSchedule); - - // Bind TaillardInstance enum - nb::enum_(m, "TaillardInstance") - .value("TA01", TaillardInstance::TA01) - .value("TA02", TaillardInstance::TA02) - .value("TA03", TaillardInstance::TA03) - .value("TA04", TaillardInstance::TA04) - .value("TA05", TaillardInstance::TA05) - .value("TA06", TaillardInstance::TA06) - .value("TA07", TaillardInstance::TA07) - .value("TA08", TaillardInstance::TA08) - .value("TA09", TaillardInstance::TA09) - .value("TA10", TaillardInstance::TA10) - .value("TA11", TaillardInstance::TA11) - .value("TA12", TaillardInstance::TA12) - .value("TA13", TaillardInstance::TA13) - .value("TA14", TaillardInstance::TA14) - .value("TA15", TaillardInstance::TA15) - .value("TA16", TaillardInstance::TA16) - .value("TA17", TaillardInstance::TA17) - .value("TA18", TaillardInstance::TA18) - .value("TA19", TaillardInstance::TA19) - .value("TA20", TaillardInstance::TA20) - .value("TA21", TaillardInstance::TA21) - .value("TA22", TaillardInstance::TA22) - .value("TA23", TaillardInstance::TA23) - .value("TA24", TaillardInstance::TA24) - .value("TA25", TaillardInstance::TA25) - .value("TA26", TaillardInstance::TA26) - .value("TA27", TaillardInstance::TA27) - .value("TA28", TaillardInstance::TA28) - .value("TA29", TaillardInstance::TA29) - .value("TA30", TaillardInstance::TA30) - .value("TA31", TaillardInstance::TA31) - .value("TA32", TaillardInstance::TA32) - .value("TA33", TaillardInstance::TA33) - .value("TA34", TaillardInstance::TA34) - .value("TA35", TaillardInstance::TA35) - .value("TA36", TaillardInstance::TA36) - .value("TA37", TaillardInstance::TA37) - .value("TA38", TaillardInstance::TA38) - .value("TA39", TaillardInstance::TA39) - .value("TA40", TaillardInstance::TA40) - .value("TA41", TaillardInstance::TA41) - .value("TA42", TaillardInstance::TA42) - .value("TA43", TaillardInstance::TA43) - .value("TA44", TaillardInstance::TA44) - .value("TA45", TaillardInstance::TA45) - .value("TA46", TaillardInstance::TA46) - .value("TA47", TaillardInstance::TA47) - .value("TA48", TaillardInstance::TA48) - .value("TA49", TaillardInstance::TA49) - .value("TA50", TaillardInstance::TA50) - .value("TA51", TaillardInstance::TA51) - .value("TA52", TaillardInstance::TA52) - .value("TA53", TaillardInstance::TA53) - .value("TA54", TaillardInstance::TA54) - .value("TA55", TaillardInstance::TA55) - .value("TA56", TaillardInstance::TA56) - .value("TA57", TaillardInstance::TA57) - .value("TA58", TaillardInstance::TA58) - .value("TA59", TaillardInstance::TA59) - .value("TA60", TaillardInstance::TA60) - .value("TA61", TaillardInstance::TA61) - .value("TA62", TaillardInstance::TA62) - .value("TA63", TaillardInstance::TA63) - .value("TA64", TaillardInstance::TA64) - .value("TA65", TaillardInstance::TA65) - .value("TA66", TaillardInstance::TA66) - .value("TA67", TaillardInstance::TA67) - .value("TA68", TaillardInstance::TA68) - .value("TA69", TaillardInstance::TA69) - .value("TA70", TaillardInstance::TA70) - .value("TA71", TaillardInstance::TA71) - .value("TA72", TaillardInstance::TA72) - .value("TA73", TaillardInstance::TA73) - .value("TA74", TaillardInstance::TA74) - .value("TA75", TaillardInstance::TA75) - .value("TA76", TaillardInstance::TA76) - .value("TA77", TaillardInstance::TA77) - .value("TA78", TaillardInstance::TA78) - .value("TA79", TaillardInstance::TA79) - .value("TA80", TaillardInstance::TA80); - - // Bind TaillardJobShopGenerator class - nb::class_(m, "TaillardJobShopGenerator") - .def_static("loadProblem", &TaillardJobShopGenerator::loadProblem) - .def_static("verifyJobsData", &TaillardJobShopGenerator::verifyJobsData) - .def_static("verifyOptimalSolution", &TaillardJobShopGenerator::verifyOptimalSolution) - .def_static("runAllVerifications", &TaillardJobShopGenerator::runAllVerifications); - - // Bind LivePlotter class - nb::class_(m, "LivePlotter") - .def(nb::init()) - .def("render", &LivePlotter::render) - .def("updateSchedule", &LivePlotter::updateSchedule) - .def("shouldClose", &LivePlotter::shouldClose); +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "job_shop_environment.h" +#include "job_shop_qlearning.h" +#include "job_shop_actor_critic.h" +#include "job_shop_taillard_generator.h" +#include "job_shop_plotter.h" +#include "job_shop_algorithms.h" +#include + +namespace nb = nanobind; + + + +template +void bind_multi_dim_array(nb::module_ &m, const char* name) { + nb::class_>(m, name) + .def(nb::init&>()) + .def("__array__", [](MultiDimensionalArray& arr, nb::kwargs kwargs) { + bool copy = kwargs.contains("copy") && nb::cast(kwargs["copy"]); + + // throw if copy is requested + if (copy) { + throw std::runtime_error("copy is not supported"); + } + + std::vector shape(arr.shape().begin(), arr.shape().end()); + + nb::handle owner = nb::cast(&arr, nb::rv_policy::reference_internal); + + return nb::ndarray>( + arr.data_ptr(), + shape.size(), + shape.data(), + owner, + nullptr,//strides.data(), + nb::dtype(), + 0, // device_type (CPU) + 0 // device_id + ); + }, nb::rv_policy::reference_internal) + .def("shape", &MultiDimensionalArray::shape) + .def("size", &MultiDimensionalArray::size) + .def("fill", &MultiDimensionalArray::fill); +} + + + +NB_MODULE(jobshop, m) { + + + bind_multi_dim_array(m, "MultiDimArray1f"); + bind_multi_dim_array(m, "MultiDimArray2f"); + bind_multi_dim_array(m, "MultiDimArray3f"); + bind_multi_dim_array(m, "MultiDimArray1d"); + bind_multi_dim_array(m, "MultiDimArray2d"); + bind_multi_dim_array(m, "MultiDimArray3d"); + bind_multi_dim_array(m, "MultiDimArray1i"); + bind_multi_dim_array(m, "MultiDimArray2i"); + bind_multi_dim_array(m, "MultiDimArray3i"); + + + // Bind Operation struct + nb::class_(m, "Operation") + .def(nb::init<>()) + .def_rw("duration", &Operation::duration) + .def_rw("machine", &Operation::machine) + .def_rw("eligibleMachines", &Operation::eligibleMachines) + .def_rw("dependentOperations", &Operation::dependentOperations) + .def("__getstate__", [](const Operation &op) { + return nb::make_tuple( + op.duration, + op.machine, + op.eligibleMachines.to_string(), // Convert bitset to string + op.dependentOperations + ); + }) + .def("__setstate__", [](Operation &op, nb::tuple state) { + if (state.size() != 4) { + throw std::runtime_error("Invalid state!"); + } + op.duration = nb::cast(state[0]); + op.machine = nb::cast(state[1]); + op.eligibleMachines = std::bitset(nb::cast(state[2])); // Convert string back to bitset + op.dependentOperations = nb::cast>>(state[3]); + }); + + // Bind Job struct + nb::class_(m, "Job") + .def(nb::init<>()) + .def_rw("operations", &Job::operations) + .def_rw("dependentJobs", &Job::dependentJobs) + .def("__getstate__", [](const Job &job) { + return std::make_tuple(job.operations, job.dependentJobs); + }) + .def("__setstate__", [](Job &job, const std::tuple, std::vector> &state) { + new (&job) Job{ + std::get<0>(state), + std::get<1>(state) + }; + }); + + + // Bind Action struct + nb::class_(m, "Action") + .def(nb::init<>()) + .def(nb::init()) + .def_rw("job", &Action::job) + .def_rw("machine", &Action::machine) + .def_rw("operation", &Action::operation); + + // Bind State struct + nb::class_(m, "State") + .def(nb::init()) + .def_rw("jobProgress", &State::jobProgress) + .def_rw("machineAvailability", &State::machineAvailability) + .def_rw("nextOperationForJob", &State::nextOperationForJob) + .def_rw("completedJobs", &State::completedJobs) + .def_rw("jobStartTimes", &State::jobStartTimes); + + // Bind ScheduleEntry struct + nb::class_(m, "ScheduleEntry") + .def(nb::init<>()) + .def(nb::init()) + .def_rw("job", &ScheduleEntry::job) + .def_rw("operation", &ScheduleEntry::operation) + .def_rw("start", &ScheduleEntry::start) + .def_rw("duration", &ScheduleEntry::duration) + .def("__getstate__", [](const ScheduleEntry &se) { + return std::make_tuple(se.job, se.operation, se.start, se.duration); + }) + .def("__setstate__", [](ScheduleEntry &se, const std::tuple &state) { + new (&se) ScheduleEntry{ + std::get<0>(state), + std::get<1>(state), + std::get<2>(state), + std::get<3>(state) + }; + }); + + // Bind JobShopEnvironment class + nb::class_(m, "JobShopEnvironment") + .def(nb::init>()) + .def("step", &JobShopEnvironment::step) + .def("reset", [](JobShopEnvironment &env, nb::kwargs kwargs) { + std::optional seed; + if (kwargs.contains("seed")) { + seed = nb::cast(kwargs["seed"]); + } + return env.reset(); + }, nb::arg("seed") = nb::none(), "Reset the environment. Optionally provide a seed for randomization.") + .def("getState", &JobShopEnvironment::getState) + .def("getPossibleActions", &JobShopEnvironment::getPossibleActions) + .def("isDone", &JobShopEnvironment::isDone) + .def("getTotalTime", &JobShopEnvironment::getTotalTime) + .def("getJobs", &JobShopEnvironment::getJobs) + .def("getNumMachines", &JobShopEnvironment::getNumMachines) + .def("getScheduleData", &JobShopEnvironment::getScheduleData) + .def("printSchedule", &JobShopEnvironment::printSchedule) + .def("generateOperationGraph", &JobShopEnvironment::generateOperationGraph); + + // Bind JobShopAlgorithm abstract class + nb::class_(m, "JobShopAlgorithm") + .def("train", &JobShopAlgorithm::train) + .def("printBestSchedule", &JobShopAlgorithm::printBestSchedule) + .def("saveBestScheduleToFile", &JobShopAlgorithm::saveBestScheduleToFile); + + // Bind JobShopQLearning class + nb::class_(m, "JobShopQLearning") + .def(nb::init()) + .def("runEpisode", &JobShopQLearning::runEpisode) + .def("train", &JobShopQLearning::train) + .def("printBestSchedule", &JobShopQLearning::printBestSchedule) + .def("saveBestScheduleToFile", &JobShopQLearning::saveBestScheduleToFile) + .def("applyAndPrintSchedule", &JobShopQLearning::applyAndPrintSchedule); + + // Bind JobShopActorCritic class + nb::class_(m, "JobShopActorCritic") + .def(nb::init()) + .def("runEpisode", &JobShopActorCritic::runEpisode) + .def("train", &JobShopActorCritic::train) + .def("printBestSchedule", &JobShopActorCritic::printBestSchedule) + .def("saveBestScheduleToFile", &JobShopActorCritic::saveBestScheduleToFile) + .def("applyAndPrintSchedule", &JobShopActorCritic::applyAndPrintSchedule); + + // Bind TaillardInstance enum + nb::enum_(m, "TaillardInstance") + .value("TA01", TaillardInstance::TA01) + .value("TA02", TaillardInstance::TA02) + .value("TA03", TaillardInstance::TA03) + .value("TA04", TaillardInstance::TA04) + .value("TA05", TaillardInstance::TA05) + .value("TA06", TaillardInstance::TA06) + .value("TA07", TaillardInstance::TA07) + .value("TA08", TaillardInstance::TA08) + .value("TA09", TaillardInstance::TA09) + .value("TA10", TaillardInstance::TA10) + .value("TA11", TaillardInstance::TA11) + .value("TA12", TaillardInstance::TA12) + .value("TA13", TaillardInstance::TA13) + .value("TA14", TaillardInstance::TA14) + .value("TA15", TaillardInstance::TA15) + .value("TA16", TaillardInstance::TA16) + .value("TA17", TaillardInstance::TA17) + .value("TA18", TaillardInstance::TA18) + .value("TA19", TaillardInstance::TA19) + .value("TA20", TaillardInstance::TA20) + .value("TA21", TaillardInstance::TA21) + .value("TA22", TaillardInstance::TA22) + .value("TA23", TaillardInstance::TA23) + .value("TA24", TaillardInstance::TA24) + .value("TA25", TaillardInstance::TA25) + .value("TA26", TaillardInstance::TA26) + .value("TA27", TaillardInstance::TA27) + .value("TA28", TaillardInstance::TA28) + .value("TA29", TaillardInstance::TA29) + .value("TA30", TaillardInstance::TA30) + .value("TA31", TaillardInstance::TA31) + .value("TA32", TaillardInstance::TA32) + .value("TA33", TaillardInstance::TA33) + .value("TA34", TaillardInstance::TA34) + .value("TA35", TaillardInstance::TA35) + .value("TA36", TaillardInstance::TA36) + .value("TA37", TaillardInstance::TA37) + .value("TA38", TaillardInstance::TA38) + .value("TA39", TaillardInstance::TA39) + .value("TA40", TaillardInstance::TA40) + .value("TA41", TaillardInstance::TA41) + .value("TA42", TaillardInstance::TA42) + .value("TA43", TaillardInstance::TA43) + .value("TA44", TaillardInstance::TA44) + .value("TA45", TaillardInstance::TA45) + .value("TA46", TaillardInstance::TA46) + .value("TA47", TaillardInstance::TA47) + .value("TA48", TaillardInstance::TA48) + .value("TA49", TaillardInstance::TA49) + .value("TA50", TaillardInstance::TA50) + .value("TA51", TaillardInstance::TA51) + .value("TA52", TaillardInstance::TA52) + .value("TA53", TaillardInstance::TA53) + .value("TA54", TaillardInstance::TA54) + .value("TA55", TaillardInstance::TA55) + .value("TA56", TaillardInstance::TA56) + .value("TA57", TaillardInstance::TA57) + .value("TA58", TaillardInstance::TA58) + .value("TA59", TaillardInstance::TA59) + .value("TA60", TaillardInstance::TA60) + .value("TA61", TaillardInstance::TA61) + .value("TA62", TaillardInstance::TA62) + .value("TA63", TaillardInstance::TA63) + .value("TA64", TaillardInstance::TA64) + .value("TA65", TaillardInstance::TA65) + .value("TA66", TaillardInstance::TA66) + .value("TA67", TaillardInstance::TA67) + .value("TA68", TaillardInstance::TA68) + .value("TA69", TaillardInstance::TA69) + .value("TA70", TaillardInstance::TA70) + .value("TA71", TaillardInstance::TA71) + .value("TA72", TaillardInstance::TA72) + .value("TA73", TaillardInstance::TA73) + .value("TA74", TaillardInstance::TA74) + .value("TA75", TaillardInstance::TA75) + .value("TA76", TaillardInstance::TA76) + .value("TA77", TaillardInstance::TA77) + .value("TA78", TaillardInstance::TA78) + .value("TA79", TaillardInstance::TA79) + .value("TA80", TaillardInstance::TA80); + + // Bind TaillardJobShopGenerator class + nb::class_(m, "TaillardJobShopGenerator") + .def_static("loadProblem", &TaillardJobShopGenerator::loadProblem) + .def_static("verifyJobsData", &TaillardJobShopGenerator::verifyJobsData) + .def_static("verifyOptimalSolution", &TaillardJobShopGenerator::verifyOptimalSolution) + .def_static("runAllVerifications", &TaillardJobShopGenerator::runAllVerifications); + + // Bind LivePlotter class + nb::class_(m, "LivePlotter") + .def(nb::init()) + .def("render", &LivePlotter::render) + .def("updateSchedule", &LivePlotter::updateSchedule) + .def("shouldClose", &LivePlotter::shouldClose); } \ No newline at end of file diff --git a/jsp/bindings/multi_dim_array_bind.cpp b/jsp/bindings/multi_dim_array_bind.cpp index c44262c..fd4a643 100644 --- a/jsp/bindings/multi_dim_array_bind.cpp +++ b/jsp/bindings/multi_dim_array_bind.cpp @@ -1,40 +1,9 @@ #include -#include -#include + #include namespace nb = nanobind; -template -nb::ndarray, nb::device::cpu> wrap_multi_dim_array(MultiDimensionalArray& arr) { - return nb::ndarray, nb::device::cpu>( - arr.data_ptr(), - NDim, - arr.shape().data(), - arr.strides().data() - ); -} - -template -void bind_multi_dim_array(nb::module_ &m, const char* name) { - nb::class_>(m, name) - .def(nb::init&>()) - .def("__array__", [](MultiDimensionalArray& arr) { - return wrap_multi_dim_array(arr); - }) - .def("shape", &MultiDimensionalArray::shape) - .def("size", &MultiDimensionalArray::size) - .def("fill", &MultiDimensionalArray::fill); -} - NB_MODULE(multi_dim_array_ext, m) { -bind_multi_dim_array(m, "MultiDimArray1f"); -bind_multi_dim_array(m, "MultiDimArray2f"); -bind_multi_dim_array(m, "MultiDimArray3f"); -bind_multi_dim_array(m, "MultiDimArray1d"); -bind_multi_dim_array(m, "MultiDimArray2d"); -bind_multi_dim_array(m, "MultiDimArray3d"); -bind_multi_dim_array(m, "MultiDimArray1i"); -bind_multi_dim_array(m, "MultiDimArray2i"); -bind_multi_dim_array(m, "MultiDimArray3i"); + } \ No newline at end of file diff --git a/jsp/examples/ppo_action.py b/jsp/examples/ppo_action.py new file mode 100644 index 0000000..914c209 --- /dev/null +++ b/jsp/examples/ppo_action.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import argparse +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple, Optional, Callable + +import jobshop +import numpy as np +import gymnasium as gym +from gymnasium import spaces +from sb3_contrib import MaskablePPO +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.vec_env import DummyVecEnv + +class ObservationSpace(ABC): + @abstractmethod + def get_observation_space(self, env: JobShopGymEnv) -> spaces.Space: + pass + + @abstractmethod + def get_observation(self, env: JobShopGymEnv) -> np.ndarray: + pass + +class DefaultObservationSpace(ObservationSpace): + def get_observation_space(self, env: JobShopGymEnv) -> spaces.Space: + num_jobs = len(env.env.getJobs()) + num_machines = env.env.getNumMachines() + max_operations = max(len(job.operations) for job in env.env.getJobs()) + + low = np.zeros(num_jobs * max_operations + num_jobs * 3 + num_machines) + high = np.inf * np.ones(num_jobs * max_operations + num_jobs * 3 + num_machines) + + return spaces.Box(low=low, high=high, dtype=np.float32) + + def get_observation(self, env: JobShopGymEnv) -> np.ndarray: + state: jobshop.JobShopState = env.env.getState() + job_progress = np.array(state.jobProgress, copy=False).flatten() + completed_jobs = np.array(state.completedJobs, dtype=np.float32) + job_start_times = np.array(state.jobStartTimes, dtype=np.float32) + machine_availability = np.array(state.machineAvailability, dtype=np.float32) + next_operation_for_job = np.array(state.nextOperationForJob, dtype=np.float32) + + return np.concatenate([ + job_progress, + completed_jobs, + job_start_times, + machine_availability, + next_operation_for_job + ]) + +class NormalizedObservationSpace(ObservationSpace): + def get_observation_space(self, env: JobShopGymEnv) -> spaces.Space: + num_jobs = len(env.env.getJobs()) + num_machines = env.env.getNumMachines() + max_operations = max(len(job.operations) for job in env.env.getJobs()) + + return spaces.Box(low=0, high=1, shape=(num_jobs * max_operations + num_jobs * 3 + num_machines,), dtype=np.float32) + + def get_observation(self, env: JobShopGymEnv) -> np.ndarray: + state: jobshop.JobShopState = env.env.getState() + total_time = env.env.getTotalTime() + max_time = sum(op.duration for job in env.env.getJobs() for op in job.operations) + + job_progress = np.array(state.jobProgress, copy=False).flatten() / max_time + completed_jobs = np.array(state.completedJobs, dtype=np.float32) + job_start_times = np.array(state.jobStartTimes, dtype=np.float32) / max_time + machine_availability = np.array(state.machineAvailability, dtype=np.float32) / max_time + next_operation_for_job = np.array(state.nextOperationForJob, dtype=np.float32) / max(len(job.operations) for job in env.env.getJobs()) + + return np.concatenate([ + job_progress, + completed_jobs, + job_start_times, + machine_availability, + next_operation_for_job + ]) + +class RewardFunction(ABC): + @abstractmethod + def calculate_reward(self, env: JobShopGymEnv, done: bool) -> float: + pass + +class MakespanRewardFunction(RewardFunction): + def calculate_reward(self, env: JobShopGymEnv, done: bool) -> float: + if done: + return -env.env.getTotalTime() + return 0 + +class ProgressRewardFunction(RewardFunction): + def __init__(self, completion_bonus: float = 1000): + self.completion_bonus = completion_bonus + self.last_progress = 0 + + def calculate_reward(self, env: JobShopGymEnv, done: bool) -> float: + state = env.env.getState() + current_progress = sum(state.nextOperationForJob) / sum(len(job.operations) for job in env.env.getJobs()) + progress_reward = (current_progress - self.last_progress) * 100 + self.last_progress = current_progress + + if done: + return progress_reward + self.completion_bonus - env.env.getTotalTime() + return progress_reward + +class JobShopGymEnv(gym.Env): + metadata: Dict[str, List[str]] = {'render.modes': ['human']} + + def __init__(self, jobshop_env: jobshop.JobShopEnvironment, max_steps: int = 200, + observation_space: ObservationSpace = DefaultObservationSpace(), + reward_function: RewardFunction = MakespanRewardFunction()): + super().__init__() + self.env: jobshop.JobShopEnvironment = jobshop_env + self.num_jobs: int = len(self.env.getJobs()) + self.num_machines: int = self.env.getNumMachines() + self.max_operations: int = max(len(job.operations) for job in self.env.getJobs()) + self.max_num_actions: int = self.num_jobs * self.num_machines * self.max_operations + self.action_space: spaces.Discrete = spaces.Discrete(self.max_num_actions) + + self.observation_space_impl = observation_space + self.observation_space = self.observation_space_impl.get_observation_space(self) + self.reward_function = reward_function + + self.action_indices = np.array([ + [job * self.num_machines * self.max_operations + machine * self.max_operations + for machine in range(self.num_machines)] + for job in range(self.num_jobs) + ]) + + self.action_map: Dict[int, jobshop.Action] = {} + self.use_masking: bool = True + self._action_mask: Optional[np.ndarray] = None + self.max_steps: int = max_steps + self.current_step: int = 0 + + def reset(self, **kwargs: Any) -> Tuple[np.ndarray, Dict[str, Any]]: + self.env.reset() + self.current_step = 0 + obs: np.ndarray = self.observation_space_impl.get_observation(self) + self._update_action_mask() + return obs, {} + + def step(self, action_idx: int) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]: + self.current_step += 1 + + if self._action_mask[action_idx] == 0: + reward: float = -1 + done: bool = self.current_step >= self.max_steps + obs: np.ndarray = self.observation_space_impl.get_observation(self) + info: Dict[str, Any] = {'invalid_action': True, 'makespan': self.env.getTotalTime()} + + if done: + info["schedule_data"] = self.env.getScheduleData() + info["isDone"] = True + + return obs, reward, done, False, info + + action: jobshop.Action = self.action_map[action_idx] + self.env.step(action) + done: bool = self.env.isDone() or self.current_step >= self.max_steps + obs: np.ndarray = self.observation_space_impl.get_observation(self) + makespan: int = self.env.getTotalTime() + info: Dict[str, Any] = {'makespan': makespan} + if done: + info["schedule_data"] = self.env.getScheduleData() + info["isDone"] = self.env.isDone() + + reward: float = self.reward_function.calculate_reward(self, done) + self._update_action_mask() + + return obs, reward, done, False, info + + def _update_action_mask(self) -> None: + possible_actions: List[jobshop.Action] = self.env.getPossibleActions() + self._action_mask = np.zeros(self.max_num_actions, dtype=np.int8) + self.action_map.clear() + for action in possible_actions: + action_idx: int = self._action_to_index(action) + self._action_mask[action_idx] = 1 + self.action_map[action_idx] = action + + def action_masks(self) -> np.ndarray: + return self._action_mask + + def _action_to_index(self, action: jobshop.Action) -> int: + return self.action_indices[action.job, action.machine] + action.operation + + def get_jobshop_env(self) -> jobshop.JobShopEnvironment: + return self.env + +class MakespanCallback(BaseCallback): + def __init__(self, verbose: int = 0, plotter: Optional[jobshop.LivePlotter] = None): + super().__init__(verbose) + self.best_makespan: float = float('inf') + self.plotter = plotter + self.best_schedule_data: Optional[List[jobshop.ScheduleEntry]] = None + self.episode_count: int = 0 + self.episode_reward: float = 0 + self.episode_length: int = 0 + + def _on_step(self) -> bool: + self.episode_length += 1 + self.episode_reward += self.locals['rewards'][0] + + if self.locals['dones'][0]: + self._on_episode_end() + + return True + + def _on_episode_end(self) -> None: + self.episode_count += 1 + info = self.locals['infos'][0] + current_makespan: float = info['makespan'] + isDone: bool = info.get('isDone', False) + + if current_makespan < self.best_makespan and isDone: + self.best_makespan = current_makespan + self.best_schedule_data = info['schedule_data'] + + if self.plotter: + self.plotter.updateSchedule(self.best_schedule_data, current_makespan) + for _ in range(10): + self.plotter.render() + + self.logger.record("jobshop/best_makespan", self.best_makespan) + self.logger.record("jobshop/episode_reward", self.episode_reward) + self.logger.record("jobshop/episode_length", self.episode_length) + self.logger.record("jobshop/episode_makespan", current_makespan) + + self.episode_reward = 0 + self.episode_length = 0 + + def _on_training_end(self) -> None: + print(f"\nTraining completed.") + print(f"Total episodes: {self.episode_count}") + print(f"Best makespan achieved: {self.best_makespan}") + + if self.best_schedule_data: + print("\nBest Schedule:") + jobshop_env = self.training_env.env_method("get_jobshop_env")[0] + jobshop_env.printSchedule(self.best_schedule_data) + else: + print("No best schedule data available.") + +def run_experiment(algorithm_name: str, taillard_instance: str, use_gui: bool, max_steps: int, + observation_space: str, reward_function: str) -> None: + def make_env() -> Callable[[], gym.Env]: + instance: jobshop.TaillardInstance = getattr(jobshop.TaillardInstance, taillard_instance) + jobs, ta_optimal = jobshop.TaillardJobShopGenerator.loadProblem(instance, True) + print(f"Optimal makespan for {taillard_instance}: {ta_optimal}") + + obs_space = DefaultObservationSpace() if observation_space == "default" else NormalizedObservationSpace() + reward_func = MakespanRewardFunction() if reward_function == "makespan" else ProgressRewardFunction() + + return lambda: JobShopGymEnv(jobshop.JobShopEnvironment(jobs), max_steps, obs_space, reward_func) + + env: DummyVecEnv = DummyVecEnv([make_env()]) + model: MaskablePPO = MaskablePPO('MlpPolicy', env, verbose=1) + total_timesteps: int = 100000 + + plotter: Optional[jobshop.LivePlotter] = None + if use_gui: + jobshop_env = env.env_method("get_jobshop_env")[0] + plotter = jobshop.LivePlotter(jobshop_env.getNumMachines()) + + makespan_callback: MakespanCallback = MakespanCallback(plotter=plotter) + model.learn(total_timesteps=total_timesteps, callback=makespan_callback) + + print(f"Best makespan achieved: {makespan_callback.best_makespan}") + print(f"Optimal makespan: {ta_optimal}") + print(f"Gap: {(makespan_callback.best_makespan - ta_optimal) / ta_optimal * 100:.2f}%") + +if __name__ == "__main__": + parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Run Job Shop Scheduling experiment with PPO") + parser.add_argument("algorithm", choices=["PPO"], help="Algorithm type") + parser.add_argument("taillard_instance", choices=[f"TA{i:02d}" for i in range(1, 81)], help="Taillard instance") + parser.add_argument("--no-gui", action="store_true", help="Disable GUI") + parser.add_argument("--max-steps", type=int, default=1000, help="Maximum number of steps per episode") + parser.add_argument("--observation-space", choices=["default", "normalized"], default="normalized", help="Observation space type") + parser.add_argument("--reward-function", choices=["makespan", "progress"], default="progress", help="Reward function type") + args: argparse.Namespace = parser.parse_args() + + run_experiment(args.algorithm, args.taillard_instance, not args.no_gui, args.max_steps, + args.observation_space, args.reward_function) \ No newline at end of file diff --git a/jsp/examples/requirements.txt b/jsp/examples/requirements.txt new file mode 100644 index 0000000..fa9cf06 --- /dev/null +++ b/jsp/examples/requirements.txt @@ -0,0 +1 @@ +tqdm \ No newline at end of file diff --git a/jsp/examples/sample_application_simple.py b/jsp/examples/sample_application_simple.py new file mode 100644 index 0000000..7816415 --- /dev/null +++ b/jsp/examples/sample_application_simple.py @@ -0,0 +1,40 @@ +import argparse +import jobshop +from tqdm import tqdm + +def run_experiment(algorithm_name: str, taillard_instance: str): + algorithm_class = getattr(jobshop, f"JobShop{algorithm_name}") + instance = getattr(jobshop.TaillardInstance, taillard_instance) + jobs, ta_optimal = jobshop.TaillardJobShopGenerator.loadProblem(instance, True) + print(f"Optimal makespan for {taillard_instance}: {ta_optimal}") + + env = jobshop.JobShopEnvironment(jobs) + algorithm = algorithm_class(env, 0.1, 0.9, 0.3) + + print("Initial Schedule:") + algorithm.printBestSchedule() + + total_episodes = 10000 + best_makespan = float('inf') + + with tqdm(total=total_episodes, desc="Training") as pbar: + def callback(make_span: int): + nonlocal best_makespan + if make_span < best_makespan: + best_makespan = make_span + pbar.set_postfix_str(f"Best makespan: {best_makespan}") + pbar.update(1) + + algorithm.train(total_episodes, callback) + + print("\nFinal Best Schedule:") + algorithm.printBestSchedule() + print(f"Best makespan achieved: {best_makespan}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run Job Shop Scheduling experiment") + parser.add_argument("algorithm", choices=["QLearning", "ActorCritic"], help="Algorithm type") + parser.add_argument("taillard_instance", choices=[f"TA{i:02d}" for i in range(1, 81)], help="Taillard instance") + args = parser.parse_args() + + run_experiment(args.algorithm, args.taillard_instance) \ No newline at end of file diff --git a/jsp/examples/sample_application.py b/jsp/examples/sample_applicationmp.py similarity index 100% rename from jsp/examples/sample_application.py rename to jsp/examples/sample_applicationmp.py diff --git a/jsp/include/job_shop_environment.h b/jsp/include/job_shop_environment.h index cc780aa..a36dccc 100644 --- a/jsp/include/job_shop_environment.h +++ b/jsp/include/job_shop_environment.h @@ -49,6 +49,10 @@ struct State { , completedJobs(numJobs, false) , jobStartTimes(numJobs, -1) { jobProgress.fill(0); + + + + } }; diff --git a/pyproject.toml b/pyproject.toml index c84172e..af1acea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,27 +1,105 @@ -[build-system] -requires = ["setuptools>=42", "wheel", "cmake>=3.30", "nanobind==1.9.2", "ninja"] -build-backend = "setuptools.build_meta" - -[project] -name = "jobshop" -version = "0.1.0" -description = "Job Shop Scheduling Algorithms" -authors = [{name = "Per-Arne Andersen", email = "per@sysx.no"}] -license = {file = "LICENSE"} -readme = "README.md" -requires-python = ">=3.10" -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", -] - -[project.urls] -"Homepage" = "https://github.com/cair/jobshop" -"Bug Tracker" = "https://github.com/cair/jobshop/issues" - -[tool.setuptools] -packages = ["jsp"] - -[tool.pytest.ini_options] -pythonpath = ["."] \ No newline at end of file +[build-system] +requires = ["setuptools>=42", "wheel", "cmake>=3.30", "nanobind==1.9.2", "ninja"] +build-backend = "setuptools.build_meta" + +[project] +name = "jobshop" +version = "0.1.0" +description = "Job Shop Scheduling Algorithms" +authors = [{name = "Per-Arne Andersen", email = "per@sysx.no"}] +license = {file = "LICENSE"} +readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[project.urls] +"Homepage" = "https://github.com/cair/jobshop" +"Bug Tracker" = "https://github.com/cair/jobshop/issues" + +[tool.setuptools] +packages = ["jsp"] + +[tool.pytest.ini_options] +pythonpath = ["."] + + + + +[tool.cibuildwheel] +build = "*" +test-skip = "" +skip = [ + "pp*", + "*i686*", + "cp38-*", # Skip Python 3.8 + "cp39-*", # Skip Python 3.9 + "cp310-*", # Skip Python 3.10 + #"cp*-manylinux_x86_64", + "cp*-manylinux_i686", + #"cp*-manylinux_aarch64", + "cp*-manylinux_ppc64le", + "cp*-manylinux_s390x", + "cp*-manylinux_pypy_*_x86_64", + "cp*-manylinux_pypy_*_aarch64", + "cp*-win32", + #"cp*-win_amd64", + "cp*-linux_x86_64", + #"cp*-macosx_x86_64", + #"cp*-macosx_arm64", + "cp*-musllinux_x86_64", + "cp*-musllinux_aarch64", + "cp*-musllinux_ppc64le", + "cp*-musllinux_s390x", +] + +archs = ["auto"] +build-frontend = "build" +config-settings = {} +dependency-versions = "pinned" +environment = {} +environment-pass = [] +build-verbosity = "" + +before-all = "" +before-build = "rm -rf build" +repair-wheel-command = "" + +test-command = "" +before-test = "" +test-requires = [] +test-extras = [] + +container-engine = "docker" + +manylinux-x86_64-image = "manylinux_2_28" +manylinux-aarch64-image = "manylinux_2_28" +manylinux-ppc64le-image = "manylinux_2_28" +manylinux-s390x-image = "manylinux_2_28" +manylinux-pypy_x86_64-image = "manylinux_2_28" +manylinux-pypy_aarch64-image = "manylinux_2_28" + +musllinux-x86_64-image = "musllinux_1_1" +musllinux-aarch64-image = "musllinux_1_1" +musllinux-ppc64le-image = "musllinux_1_1" +musllinux-s390x-image = "musllinux_1_1" + +[tool.cibuildwheel.linux] +before-all = """ +dnf install -y perl-IPC-Cmd zip libXinerama-devel libXcursor-devel xorg-x11-server-devel mesa-libGLU-devel pkgconfig wayland-devel libxkbcommon-devel libXrandr-devel libXi-devel libXxf86vm-devel mesa-libGL-devel git python3-devel && \ +git clone https://github.com/Microsoft/vcpkg.git .vcpkg && \ +./.vcpkg/bootstrap-vcpkg.sh && \ +./.vcpkg/vcpkg integrate install +""" + +repair-wheel-command = "auditwheel repair -w {dest_dir} {wheel}" + +environment = { VCPKG_ROOT = "/project/.vcpkg"} + +[tool.cibuildwheel.macos] +repair-wheel-command = "delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel}" + +[tool.cibuildwheel.windows] \ No newline at end of file diff --git a/setup.py b/setup.py index 330f31a..727f802 100644 --- a/setup.py +++ b/setup.py @@ -1,97 +1,95 @@ -import os -import subprocess -import sys -from pathlib import Path -from setuptools import setup, Extension -from setuptools.command.build_ext import build_ext -import site -import sysconfig -import shutil - -class CMakeExtension(Extension): - def __init__(self, name, sourcedir=""): - Extension.__init__(self, name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) - -class CMakeBuild(build_ext): - def run(self): - try: - subprocess.check_output(["cmake", "--version"]) - except OSError: - raise RuntimeError("CMake must be installed to build the following extensions: " + - ", ".join(e.name for e in self.extensions)) - - for ext in self.extensions: - self.build_extension(ext) - - def build_extension(self, ext): - extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - - # Find nanobind - nanobind_path = None - for path in site.getsitepackages(): - candidate = os.path.join(path, 'nanobind', 'cmake') - if os.path.exists(os.path.join(candidate, 'nanobind-config.cmake')): - nanobind_path = candidate - break - - if nanobind_path is None: - raise RuntimeError("Could not find nanobind installation.") - - # Get Python information - python_include = sysconfig.get_path('include') - python_lib = sysconfig.get_config_var('LIBDIR') - - # Create build directory - build_temp = os.path.join(self.build_temp, ext.name) - if not os.path.exists(build_temp): - os.makedirs(build_temp) - - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DPYTHON_INCLUDE_DIR={python_include}", - f"-DPYTHON_LIBRARY={python_lib}", - f"-Dnanobind_DIR={nanobind_path}", - ] - - # Find Ninja - ninja_path = shutil.which("ninja") - if ninja_path: - print(f"Found Ninja at: {ninja_path}") - cmake_args.extend(["-GNinja", f"-DCMAKE_MAKE_PROGRAM={ninja_path}"]) - else: - print("Ninja not found, using default CMake generator") - - build_args = [ - "--config", "Release", - "--target", "jobshop" - ] - - env = os.environ.copy() - env["CXXFLAGS"] = f"{env.get('CXXFLAGS', '')} -DVERSION_INFO=\\\"{self.distribution.get_version()}\\\"" - - preset_file = os.path.join(ext.sourcedir, "jsp", "CMakePresets.json") - if os.path.exists(preset_file): - cmake_args.extend(["--preset", "pip-install"]) - - # Configure - subprocess.check_call(["cmake", os.path.join(ext.sourcedir, "jsp")] + cmake_args, cwd=build_temp, env=env) - - - # Build - subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) - -setup( - name="jobshop", - version="0.1.0", - author="Per-Arne Andersen", - author_email="per@sysx.no", - description="Job Shop Scheduling Algorithms", - long_description=Path("README.md").read_text(), - long_description_content_type="text/markdown", - ext_modules=[CMakeExtension("jobshop")], - cmdclass={"build_ext": CMakeBuild}, - zip_safe=False, - python_requires=">=3.10", +import os +import subprocess +import sys +from pathlib import Path +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext +import site +import sysconfig +import shutil + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + +class CMakeBuild(build_ext): + def run(self): + try: + subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError("CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions)) + + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # Find nanobind + nanobind_path = None + for path in site.getsitepackages(): + candidate = os.path.join(path, 'nanobind', 'cmake') + if os.path.exists(os.path.join(candidate, 'nanobind-config.cmake')): + nanobind_path = candidate + break + + if nanobind_path is None: + raise RuntimeError("Could not find nanobind installation.") + + # Get Python information + python_include = sysconfig.get_path('include') + python_lib = sysconfig.get_config_var('LIBDIR') + + # Create build directory + build_temp = os.path.join(self.build_temp, ext.name) + if not os.path.exists(build_temp): + os.makedirs(build_temp) + + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPython_EXECUTABLE={sys.executable}", + f"-Dnanobind_DIR={nanobind_path}", + ] + + # Find Ninja + ninja_path = shutil.which("ninja") + if ninja_path: + print(f"Found Ninja at: {ninja_path}") + cmake_args.extend(["-GNinja", f"-DCMAKE_MAKE_PROGRAM={ninja_path}"]) + else: + print("Ninja not found, using default CMake generator") + + build_args = [ + "--config", "Release", + "--target", "jobshop" + ] + + env = os.environ.copy() + env["CXXFLAGS"] = f"{env.get('CXXFLAGS', '')} -DVERSION_INFO=\\\"{self.distribution.get_version()}\\\"" + + preset_file = os.path.join(ext.sourcedir, "jsp", "CMakePresets.json") + if os.path.exists(preset_file): + cmake_args.extend(["--preset", "pip-install"]) + + # Configure + subprocess.check_call(["cmake", os.path.join(ext.sourcedir, "jsp")] + cmake_args, cwd=build_temp, env=env) + + + # Build + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) + +setup( + name="jobshop", + version="0.1.0", + author="Per-Arne Andersen", + author_email="per@sysx.no", + description="Job Shop Scheduling Algorithms", + long_description=Path("README.md").read_text(), + long_description_content_type="text/markdown", + ext_modules=[CMakeExtension("jobshop")], + cmdclass={"build_ext": CMakeBuild}, + zip_safe=False, + python_requires=">=3.10", ) \ No newline at end of file