Skip to content

Commit

Permalink
update topsrider env and generator code
Browse files Browse the repository at this point in the history
  • Loading branch information
aimsky committed Nov 7, 2023
1 parent 012f5fe commit e13afb5
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 18 deletions.
2 changes: 1 addition & 1 deletion dipu/scripts/ci/topsrider/ci_topsrider_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export DIPU_LOCAL_DIR=/path/to/dipu

export DIOPI_ROOT=${DIPU_LOCAL_DIR}/third_party/DIOPI/impl/lib
export DIPU_ROOT=${DIPU_LOCAL_DIR}/torch_dipu
export LD_LIBRARY_PATH=$DIPU_ROOT:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$DIPU_ROOT:$DIOPI_ROOT:$LD_LIBRARY_PATH
export PYTHONPATH=${CONDA_ROOT}/envs/dipu/lib/python3.8:${DIPU_LOCAL_DIR}:${PYTHONPATH}
export PATH=${PYTORCH_DIR}/build/bin:${CONDA_ROOT}/envs/dipu/bin:${CONDA_ROOT}/bin:${PATH}

Expand Down
3 changes: 2 additions & 1 deletion dipu/scripts/ci/topsrider/ci_topsrider_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ function config_dipu_cmake() {
mkdir -p build && cd ./build && rm -rf ./*
cmake ../ -DCMAKE_BUILD_TYPE=Debug \
-DDEVICE=tops \
-DWITH_DIOPI=INTERNAL
# -DCMAKE_C_FLAGS_DEBUG="-g -O0" \
# -DCMAKE_CXX_FLAGS_DEBUG="-g -O0"

Expand All @@ -15,9 +16,9 @@ function config_all_cmake() {
mkdir -p build && cd ./build && rm -rf ./*
cmake ../ -DCMAKE_BUILD_TYPE=Debug \
-DDEVICE=tops \
-DWITH_DIOPI=INTERNAL
# -DCMAKE_C_FLAGS_DEBUG="-g -O0" \
# -DCMAKE_CXX_FLAGS_DEBUG="-g -O0"
-DWITH_DIOPI=INTERNAL
cd ../
}

Expand Down
3 changes: 2 additions & 1 deletion dipu/torch_dipu/csrc_dipu/vendor/topsrider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ message(STATUS "CMAKE_MODULE_PATH: " ${CMAKE_MODULE_PATH})
find_package(TOPSRT REQUIRED)
if(TOPSRT_FOUND)
set(VENDOR_INCLUDE_DIRS ${TOPSRT_INCLUDE_DIR} ${TOPSRT_INCLUDE_DIR}/.. /usr/include/eccl/ PARENT_SCOPE)
set(VENDOR_LIB_DIRS ${TOPSRT_LIBRARIES_DIR} PARENT_SCOPE)
set(DIPU_VENDOR_LIB topsrt eccl PARENT_SCOPE)
message("TOPSRT_INCLUDE_DIR:" ${TOPSRT_INCLUDE_DIR})
message("TOPSRT_LIBRARIES:" ${TOPSRT_LIBRARIES})
message("TOPSRT_LIBRARIES_DIR:" ${TOPSRT_LIBRARIES_DIR})
message("VENDOR_LIB_DIRS:" ${VENDOR_LIB_DIRS})
else()
message(FATAL_ERROR "Not found TOPSRT.")
Expand Down
42 changes: 31 additions & 11 deletions dipu/torch_dipu/csrc_dipu/vendor/topsrider/TopsGeneratorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,41 @@
#include <csrc_dipu/runtime/core/DIPUGeneratorImpl.h>

namespace dipu {
// just an example

// Layout of the RNG state buffer used by TopsGeneratorImpl: update_state()
// writes the seed into the first seed_size bytes and the offset into the
// following offset_size bytes.
// Size of the seed field (a uint64_t).
static const size_t seed_size = sizeof(uint64_t);
// Size of the offset field (an int64_t).
static const size_t offset_size = sizeof(int64_t);
// Size of a full state: seed followed by offset.
static const size_t total_size = seed_size + offset_size;

// Random-number generator implementation for the topsrider vendor backend,
// derived from dipu::DIPUGeneratorImpl.
// NOTE(review): this listing is a rendered commit diff, so the removed and the
// added version of each member appear side by side below (two constructors,
// two set_state and two update_state definitions). Confirm against the real
// file which lines are current before relying on it.
class TopsGeneratorImpl : public dipu::DIPUGeneratorImpl {
protected:
// NOTE(review): not referenced by any member visible in this listing.
mutable std::once_flag init_state_flag;

public:
// Constructor with an empty body: forwards device_index to the
// DIPUGeneratorImpl base.
TopsGeneratorImpl(at::DeviceIndex device_index): dipu::DIPUGeneratorImpl(device_index) {
}
// The same constructor written on a single line.
TopsGeneratorImpl(at::DeviceIndex device_index) : dipu::DIPUGeneratorImpl(device_index) {}

// Replaces the stored RNG state with new_state.
// The element count must be either total_size (seed + offset) or
// total_size - offset_size (seed only); any other size fails the TORCH_CHECK
// with "RNG state is wrong size".
// NOTE(review): the seed bytes inside new_state are not read here; whether
// the base class picks them up from state_ is not visible — TODO confirm.
void set_state(const c10::TensorImpl& new_state) override {
// Checked by at::detail::check_rng_state before the size test below.
at::detail::check_rng_state(new_state);
auto new_state_size = new_state.numel();
TORCH_CHECK(new_state_size == total_size || new_state_size == total_size - offset_size, "RNG state is wrong size");

// set_state variant with an empty body (does nothing with `state`).
void set_state(const c10::TensorImpl& state) override {
}
// Wrap a shallow, detached copy of new_state in a Tensor and keep it as the
// current state.
at::Tensor state_tmp(new_state.shallow_copy_and_detach(new_state.version_counter(), true));
state_ = state_tmp;
// Clear the flag so update_state() does not overwrite the state just set.
state_need_reset_ = false;
}

// update_state variant with an empty body.
void update_state() const override {
}
// Rebuilds state_ from the current seed when state_need_reset_ is set;
// otherwise leaves state_ untouched.
// NOTE(review): this is a const method that assigns state_ and
// state_need_reset_, so those members are presumably declared mutable in the
// base class — verify.
void update_state() const override {
if (state_need_reset_) {
// Fresh CPU byte tensor holding total_size bytes.
state_ = at::detail::empty_cpu({(int64_t)total_size}, c10::ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_.data_ptr<uint8_t>();
uint64_t seed = this->current_seed();
// The offset is always written as 0 here.
int64_t offset = 0;
// Pack the seed first, then the offset directly after it.
std::memcpy(rng_state, &seed, seed_size);
std::memcpy(rng_state + seed_size, &offset, offset_size);
state_need_reset_ = false;
}
}
};

// Returns an at::Generator built with at::make_generator<TopsGeneratorImpl>
// for the given device_index.
// NOTE(review): the multi-line and the single-line definitions below are the
// removed and the added form of the same function in the rendered diff; only
// one of them exists in the real file.
const at::Generator vendorMakeGenerator(at::DeviceIndex device_index) {
return at::make_generator<TopsGeneratorImpl>(device_index);
}
const at::Generator vendorMakeGenerator(at::DeviceIndex device_index) { return at::make_generator<TopsGeneratorImpl>(device_index); }

} // namespace torch_dipu
} // namespace dipu
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ find_path(TOPSRT_INCLUDE_DIR
HINTS /usr/include/tops/ /usr/include/ /usr/include/dtu/ /usr/include/dtu/tops
/usr/include/dtu/3_0/runtime
/home/cse/src/install/usr/include/tops
/opt/tops/include/
/opt/tops/include/tops/
)

find_library(TOPSRT_LIBRARIES
NAMES topsrt
find_path(TOPSRT_LIBRARIES_DIR
NAMES libtopsrt.so
HINTS /usr/lib64
/usr/lib
/usr/local/lib64
/usr/local/lib
/home/cse/src/install/usr/lib
/opt/tops/lib/
)

find_package_handle_standard_args(TOPSRT DEFAULT_MSG
TOPSRT_INCLUDE_DIR
TOPSRT_LIBRARIES)
TOPSRT_LIBRARIES_DIR)

mark_as_advanced(TOPSRT_INCLUDE_DIR TOPSRT_LIBRARIES)
mark_as_advanced(TOPSRT_INCLUDE_DIR TOPSRT_LIBRARIES_DIR)

0 comments on commit e13afb5

Please sign in to comment.