diff --git a/dependencies.sh b/dependencies.sh index 01ba78c..0e035bc 100755 --- a/dependencies.sh +++ b/dependencies.sh @@ -58,4 +58,13 @@ echo "*** Download and installing [golang-1.22] ***" wget https://go.dev/dl/go1.22.10.linux-amd64.tar.gz sudo tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz +echo "*** Download and installing [abseil-cpp] ***" +cd ${REPO_ROOT}/thirdparties +git clone ${GITHUB_PROXY}/abseil/abseil-cpp.git +cd abseil-cpp +mkdir -p build +cd build +CXXFLAGS="-fPIC" CFLAGS="-fPIC" cmake .. -DCMAKE_POSITION_INDEPENDENT_CODE=ON +make -j$(nproc) && sudo make install + echo "*** Dependencies Installed! ***" diff --git a/mooncake-integration/vllm/vllm_adaptor.cpp b/mooncake-integration/vllm/vllm_adaptor.cpp index 85f45e3..474dd71 100644 --- a/mooncake-integration/vllm/vllm_adaptor.cpp +++ b/mooncake-integration/vllm/vllm_adaptor.cpp @@ -199,13 +199,13 @@ int VLLMAdaptor::transferSync(const char *target_hostname, uintptr_t buffer, entry.target_id = handle; entry.target_offset = peer_buffer_address; - int ret = engine_->submitTransfer(batch_id, {entry}); - if (ret < 0) return -1; + Status s = engine_->submitTransfer(batch_id, {entry}); + if (!s.ok()) return -1; TransferStatus status; while (true) { - int ret = engine_->getTransferStatus(batch_id, 0, status); - LOG_ASSERT(!ret); + Status s = engine_->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) { engine_->freeBatchID(batch_id); return 0; diff --git a/mooncake-integration/vllm/vllm_adaptor.h b/mooncake-integration/vllm/vllm_adaptor.h index b20e42a..2161c03 100644 --- a/mooncake-integration/vllm/vllm_adaptor.h +++ b/mooncake-integration/vllm/vllm_adaptor.h @@ -24,6 +24,7 @@ #include #include +#include "common/base/status.h" #include "transfer_engine.h" #include "transport/rdma_transport/rdma_transport.h" #include "transport/transport.h" diff --git a/mooncake-transfer-engine/example/transfer_engine_bench.cpp 
b/mooncake-transfer-engine/example/transfer_engine_bench.cpp index 233355b..49f395e 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench.cpp @@ -24,6 +24,7 @@ #include #include +#include "common/base/status.h" #include "transfer_engine.h" #include "transport/transport.h" @@ -159,7 +160,7 @@ static inline std::string calculateRate(uint64_t data_bytes, double duration) { volatile bool running = true; std::atomic total_batch_count(0); -int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, +Status initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, void *addr) { bindToSocket(thread_id % NR_SOCKETS); TransferRequest::OpCode opcode; @@ -183,7 +184,7 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, size_t batch_count = 0; while (running) { auto batch_id = engine->allocateBatchID(FLAGS_batch_size); - int ret = 0; + Status s; std::vector requests; for (int i = 0; i < FLAGS_batch_size; ++i) { TransferRequest entry; @@ -198,14 +199,14 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, requests.emplace_back(entry); } - ret = engine->submitTransfer(batch_id, requests); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, requests); + LOG_ASSERT(s.ok()); for (int task_id = 0; task_id < FLAGS_batch_size; ++task_id) { bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, task_id, status); - LOG_ASSERT(!ret); + Status s = engine->getTransferStatus(batch_id, task_id, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if (status.s == TransferStatusEnum::FAILED) { @@ -216,13 +217,13 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, } } - ret = engine->freeBatchID(batch_id); - LOG_ASSERT(!ret); + s = engine->freeBatchID(batch_id); + 
LOG_ASSERT(s.ok()); batch_count++; } LOG(INFO) << "Worker " << thread_id << " stopped!"; total_batch_count.fetch_add(batch_count); - return 0; + return Status::OK(); } std::string formatDeviceNames(const std::string &device_names) { diff --git a/mooncake-transfer-engine/include/common.h b/mooncake-transfer-engine/include/common.h index b717f65..95c8370 100644 --- a/mooncake-transfer-engine/include/common.h +++ b/mooncake-transfer-engine/include/common.h @@ -26,6 +26,8 @@ #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "error.h" #if defined(__x86_64__) diff --git a/mooncake-transfer-engine/include/common/base/status.h b/mooncake-transfer-engine/include/common/base/status.h new file mode 100644 index 0000000..4c46109 --- /dev/null +++ b/mooncake-transfer-engine/include/common/base/status.h @@ -0,0 +1,284 @@ +// Copyright 2025 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// The design of this code is adapted from the RocksDB project with some +// modifications. +// https://github.com/facebook/rocksdb/blob/main/include/rocksdb/status.h + +#ifndef STATUS_H +#define STATUS_H + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" + +namespace mooncake { + +class Status final { + public: + // The code of the status. 
+ enum class Code : uint16_t { + kOk = 0, + kInvalidArgument = 1, + kTooManyRequests = 2, + kAddressNotRegistered = 3, + kBatchBusy = 4, + kDeviceNotFound = 6, + kAddressOverlapped = 7, + kDns = 101, + kSocket = 102, + kMalformedJson = 103, + kRejectHandshake = 104, + kMetadata = 200, + kEndpoint = 201, + kContext = 202, + kNuma = 300, + kClock = 301, + kMemory = 302, + kNotImplmented = 999, + kMaxCode + }; + + // Builds an OK Status. + Status() = default; + + ~Status() { delete[] message_; } + + // Constructs a Status object containing a status code and message. + // If 'code == Code::kOk', 'msg' is ignored and an object identical to an OK + // status is constructed. + Status(Code code, absl::string_view message); + + Status(const Status& s); + Status& operator=(const Status& s); + Status(Status&& s); + Status& operator=(Status&& s); + + // Returns the stored status code. + Code code() const { return code_; } + + // Return the error message (if any). + absl::string_view message() const { + if (message_) { + return message_; + } else { + return absl::string_view(); + } + } + + // Returns true if the Status is OK. + ABSL_MUST_USE_RESULT bool ok() const { return Code::kOk == code_; } + + // Returns true iff the status indicates an InvalidArgument error. + ABSL_MUST_USE_RESULT bool IsInvalidArgument() const { + return Code::kInvalidArgument == code_; + } + + // Returns true iff the status indicates a TooManyRequests error. + ABSL_MUST_USE_RESULT bool IsTooManyRequests() const { + return Code::kTooManyRequests == code_; + } + + // Returns true iff the status indicates an AddressNotRegistered error. + ABSL_MUST_USE_RESULT bool IsAddressNotRegistered() const { + return Code::kAddressNotRegistered == code_; + } + + // Returns true iff the status indicates a BatchBusy error. + ABSL_MUST_USE_RESULT bool IsBatchBusy() const { + return Code::kBatchBusy == code_; + } + + // Returns true iff the status indicates an DeviceNotFound error. 
+ ABSL_MUST_USE_RESULT bool IsDeviceNotFound() const { + return Code::kDeviceNotFound == code_; + } + + // Returns true iff the status indicates an AddressOverlapped error. + ABSL_MUST_USE_RESULT bool IsAddressOverlapped() const { + return Code::kAddressOverlapped == code_; + } + + // Returns true iff the status indicates a dns error. + ABSL_MUST_USE_RESULT bool IsDns() const { + return Code::kDns == code_; + } + + // Returns true iff the status indicates an Socket error. + ABSL_MUST_USE_RESULT bool IsSocket() const { + return Code::kSocket == code_; + } + + // Returns true iff the status indicates a MalformedJson error. + ABSL_MUST_USE_RESULT bool IsMalformedJson() const { + return Code::kMalformedJson == code_; + } + + // Returns true iff the status indicates a RejectHandshake error. + ABSL_MUST_USE_RESULT bool IsRejectHandshake() const { + return Code::kRejectHandshake == code_; + } + + // Returns true iff the status indicates a Metadata error. + ABSL_MUST_USE_RESULT bool IsMetadata() const { + return Code::kMetadata == code_; + } + + // Returns true iff the status indicates an Endpoint error. + ABSL_MUST_USE_RESULT bool IsEndpoint() const { + return Code::kEndpoint == code_; + } + + // Returns true iff the status indicates a Context error. + ABSL_MUST_USE_RESULT bool IsContext() const { + return Code::kContext == code_; + } + + // Returns true iff the status indicates a Numa error. + ABSL_MUST_USE_RESULT bool IsNuma() const { + return Code::kNuma == code_; + } + + // Returns true iff the status indicates a Clock error. + ABSL_MUST_USE_RESULT bool IsClock() const { + return Code::kClock == code_; + } + + // Returns true iff the status indicates a Memory error. + ABSL_MUST_USE_RESULT bool IsMemory() const { + return Code::kMemory == code_; + } + + // Returns true iff the status indicates a NotImplmented error. 
+ ABSL_MUST_USE_RESULT bool IsNotImplmented() const { + return Code::kNotImplmented == code_; + } + + // Return a combination of the error code name and message. + std::string ToString() const; + + bool operator==(const Status& s) const; + bool operator!=(const Status& s) const; + + // Return a status of an appropriate type. + static Status OK() { return Status(); } + static Status InvalidArgument(absl::string_view msg) { + return Status(Code::kInvalidArgument, msg); + } + static Status TooManyRequests(absl::string_view msg) { + return Status(Code::kTooManyRequests, msg); + } + static Status AddressNotRegistered(absl::string_view msg) { + return Status(Code::kAddressNotRegistered, msg); + } + static Status BatchBusy(absl::string_view msg) { + return Status(Code::kBatchBusy, msg); + } + static Status DeviceNotFound(absl::string_view msg) { + return Status(Code::kDeviceNotFound, msg); + } + static Status AddressOverlapped(absl::string_view msg) { + return Status(Code::kAddressOverlapped, msg); + } + static Status Dns(absl::string_view msg) { + return Status(Code::kDns, msg); + } + static Status Socket(absl::string_view msg) { + return Status(Code::kSocket, msg); + } + static Status MalformedJson(absl::string_view msg) { + return Status(Code::kMalformedJson, msg); + } + static Status RejectHandshake(absl::string_view msg) { + return Status(Code::kRejectHandshake, msg); + } + static Status Metadata(absl::string_view msg) { + return Status(Code::kMetadata, msg); + } + static Status Endpoint(absl::string_view msg) { + return Status(Code::kEndpoint, msg); + } + static Status Context(absl::string_view msg) { + return Status(Code::kContext, msg); + } + static Status Numa(absl::string_view msg) { + return Status(Code::kNuma, msg); + } + static Status Clock(absl::string_view msg) { + return Status(Code::kClock, msg); + } + static Status Memory(absl::string_view msg) { + return Status(Code::kMemory, msg); + } + static Status NotImplmented(absl::string_view msg) { + return 
Status(Code::kNotImplmented, msg); + } + + // Return a human-readable name of the 'code'. + static std::string_view CodeToString(Code code); + + private: + // Return a copy of the message 'msg'. + static const char* CopyMessage(const char* msg); + + // The code of the status. + Code code_ = Code::kOk; + // The error message of the status. Refer to the Status definition in RocksDB, + // we don't use 'std::string' type message but 'const char*' type one for the + // performance considerations. A memory allocation in the std::string + // construction could be avoid for the most cases that the Status is OK. And + // the total size of 'message_' is only 8 bytes on a x86-64 platform, while + // the size of a uninitialized strings with SSO (Small String Optimization) + // will be 24 to 32 bytes big, excluding the dynamically allocated memory. + const char* message_ = nullptr; +}; + +inline Status::Status(const Status& s) : code_(s.code_) { + message_ = (s.message_ == nullptr) ? nullptr : CopyMessage(s.message_); +} + +inline Status& Status::operator=(const Status& s) { + if (this != &s) { + code_ = s.code_; + delete[] message_; + message_ = (s.message_ == nullptr) ? nullptr : CopyMessage(s.message_); + } + return *this; +} + +inline Status::Status(Status&& s) : Status() { *this = std::move(s); } + +inline Status& Status::operator=(Status&& s) { + if (this != &s) { + code_ = std::move(s.code_); + s.code_ = Code::kOk; + delete[] message_; + message_ = nullptr; + std::swap(message_, s.message_); + } + return *this; +} + +// Prints a human-readable representation name of the 'code' to 'os'. +std::ostream& operator<<(std::ostream& os, Status::Code code); + +// Prints a human-readable representation of 's' to 'os'. 
+std::ostream& operator<<(std::ostream& os, const Status& s); + +} // namespace mooncake + +#endif // STATUS_H diff --git a/mooncake-transfer-engine/include/multi_transport.h b/mooncake-transfer-engine/include/multi_transport.h index cdbd37f..10294ae 100644 --- a/mooncake-transfer-engine/include/multi_transport.h +++ b/mooncake-transfer-engine/include/multi_transport.h @@ -36,12 +36,12 @@ class MultiTransport { BatchID allocateBatchID(size_t batch_size); - int freeBatchID(BatchID batch_id); + Status freeBatchID(BatchID batch_id); - int submitTransfer(BatchID batch_id, + Status submitTransfer(BatchID batch_id, const std::vector &entries); - int getTransferStatus(BatchID batch_id, size_t task_id, + Status getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status); Transport *installTransport(const std::string &proto, diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 0a0da73..ff09731 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -83,17 +83,17 @@ class TransferEngine { return multi_transports_->allocateBatchID(batch_size); } - int freeBatchID(BatchID batch_id) { + Status freeBatchID(BatchID batch_id) { return multi_transports_->freeBatchID(batch_id); } - int submitTransfer(BatchID batch_id, - const std::vector &entries) { + Status submitTransfer(BatchID batch_id, + const std::vector &entries) { return multi_transports_->submitTransfer(batch_id, entries); } - int getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) { + Status getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) { return multi_transports_->getTransferStatus(batch_id, task_id, status); } diff --git a/mooncake-transfer-engine/include/transfer_engine_c.h b/mooncake-transfer-engine/include/transfer_engine_c.h index 488ddcc..8515638 100644 --- 
a/mooncake-transfer-engine/include/transfer_engine_c.h +++ b/mooncake-transfer-engine/include/transfer_engine_c.h @@ -18,6 +18,8 @@ #include #include +#include "common/base/status.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -126,13 +128,14 @@ int unregisterLocalMemoryBatch(transfer_engine_t engine, void **addr_list, batch_id_t allocateBatchID(transfer_engine_t engine, size_t batch_size); -int submitTransfer(transfer_engine_t engine, batch_id_t batch_id, - struct transfer_request *entries, size_t count); +mooncake::Status submitTransfer(transfer_engine_t engine, batch_id_t batch_id, + struct transfer_request *entries, size_t count); -int getTransferStatus(transfer_engine_t engine, batch_id_t batch_id, - size_t task_id, struct transfer_status *status); +mooncake::Status getTransferStatus(transfer_engine_t engine, + batch_id_t batch_id, size_t task_id, + struct transfer_status *status); -int freeBatchID(transfer_engine_t engine, batch_id_t batch_id); +mooncake::Status freeBatchID(transfer_engine_t engine, batch_id_t batch_id); int syncSegmentCache(transfer_engine_t engine); diff --git a/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h b/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h index efc1a57..a64cb6a 100644 --- a/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h +++ b/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h @@ -49,10 +49,10 @@ class CxlTransport : public Transport { - int submitTransfer(BatchID batch_id, + Status submitTransfer(BatchID batch_id, const std::vector &entries) override; - int getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) override; + Status getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) override; - int freeBatchID(BatchID batch_id) override; + Status freeBatchID(BatchID batch_id) override; private: int install(std::string &local_server_name, diff --git 
a/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h b/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h index 1db3deb..fa51633 100644 --- a/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h +++ b/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h @@ -37,13 +37,13 @@ class NVMeoFTransport : public Transport { BatchID allocateBatchID(size_t batch_size) override; - int submitTransfer(BatchID batch_id, - const std::vector &entries) override; + Status submitTransfer(BatchID batch_id, + const std::vector &entries) override; - int getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) override; + Status getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) override; - int freeBatchID(BatchID batch_id) override; + Status freeBatchID(BatchID batch_id) override; private: struct NVMeoFBatchDesc { diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h index 55b2913..4f2bccb 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h @@ -73,18 +73,18 @@ class RdmaTransport : public Transport { // TRANSFER - int submitTransfer(BatchID batch_id, + Status submitTransfer(BatchID batch_id, const std::vector &entries) override; - int submitTransferTask( + Status submitTransferTask( const std::vector &request_list, const std::vector &task_list) override; - int getTransferStatus(BatchID batch_id, - std::vector &status); + Status getTransferStatus(BatchID batch_id, + std::vector &status); - int getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) override; + Status getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) override; SegmentID getSegmentID(const 
std::string &segment_name); diff --git a/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h b/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h index a302e7c..21c1bce 100644 --- a/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h +++ b/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h @@ -46,14 +46,14 @@ class TcpTransport : public Transport { ~TcpTransport(); - int submitTransfer(BatchID batch_id, + Status submitTransfer(BatchID batch_id, const std::vector &entries) override; - int submitTransferTask( + Status submitTransferTask( const std::vector &request_list, const std::vector &task_list) override; - int getTransferStatus(BatchID batch_id, size_t task_id, + Status getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) override; private: diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h index b6180d8..c281dc8 100644 --- a/mooncake-transfer-engine/include/transport/transport.h +++ b/mooncake-transfer-engine/include/transport/transport.h @@ -26,6 +26,7 @@ #include #include +#include "common/base/status.h" #include "transfer_metadata.h" namespace mooncake { @@ -153,26 +154,27 @@ class Transport { virtual BatchID allocateBatchID(size_t batch_size); /// @brief Free an allocated batch. - virtual int freeBatchID(BatchID batch_id); + virtual Status freeBatchID(BatchID batch_id); /// @brief Submit a batch of transfer requests to the batch. /// @return The number of successfully submitted transfers on success. If /// that number is less than nr, errno is set. 
- virtual int submitTransfer(BatchID batch_id, + virtual Status submitTransfer(BatchID batch_id, const std::vector &entries) = 0; - virtual int submitTransferTask( + virtual Status submitTransferTask( const std::vector &request_list, const std::vector &task_list) { - return ERR_NOT_IMPLEMENTED; + return Status::NotImplmented( + "Transport::submitTransferTask is not implemented"); } /// @brief Get the status of a submitted transfer. This function shall not /// be called again after completion. /// @return Return 1 on completed (either success or failure); 0 if still in /// progress. - virtual int getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) = 0; + virtual Status getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) = 0; std::shared_ptr &meta() { return metadata_; } diff --git a/mooncake-transfer-engine/rust/src/transfer_engine.rs b/mooncake-transfer-engine/rust/src/transfer_engine.rs index c4ecb74..84f1ec4 100644 --- a/mooncake-transfer-engine/rust/src/transfer_engine.rs +++ b/mooncake-transfer-engine/rust/src/transfer_engine.rs @@ -193,7 +193,7 @@ impl TransferEngine { let ret = unsafe { bindings::submitTransfer(self.engine, batch_id, requests_c.as_mut_ptr(), requests.len()) }; - if ret < 0 { + if !ret.ok() { bail!("Failed to submit transfer") } else { Ok(()) @@ -207,7 +207,7 @@ impl TransferEngine { }; let ret = unsafe { bindings::getTransferStatus(self.engine, batch_id, task_id as usize, &mut status) }; - if ret < 0 { + if !ret.ok() { bail!("Failed to get transfer status") } else { Ok((status.status, status.transferred_bytes)) @@ -216,7 +216,7 @@ impl TransferEngine { pub fn free_batch_id(&self, batch_id: BatchID) -> Result<()> { let ret = unsafe { bindings::freeBatchID(self.engine, batch_id) }; - if ret < 0 { + if !ret.ok() { bail!("Failed to free batch ID") } else { Ok(()) diff --git a/mooncake-transfer-engine/src/CMakeLists.txt b/mooncake-transfer-engine/src/CMakeLists.txt index 0c8f8e7..1c2d2ac 100644 
--- a/mooncake-transfer-engine/src/CMakeLists.txt +++ b/mooncake-transfer-engine/src/CMakeLists.txt @@ -1,8 +1,27 @@ file(GLOB ENGINE_SOURCES "*.cpp") +add_subdirectory(common) add_subdirectory(transport) SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) +# Find Abseil's CMake targets +find_package(absl REQUIRED) +include_directories(${absl_INCLUDE_DIR}) +# Add all libs from absl into third party libs. +list(APPEND ABSL_LIBS absl::algorithm) +list(APPEND ABSL_LIBS absl::base) +list(APPEND ABSL_LIBS absl::debugging) +list(APPEND ABSL_LIBS absl::flat_hash_map) +list(APPEND ABSL_LIBS absl::flags) +list(APPEND ABSL_LIBS absl::memory) +list(APPEND ABSL_LIBS absl::meta) +list(APPEND ABSL_LIBS absl::numeric) +list(APPEND ABSL_LIBS absl::random_random) +list(APPEND ABSL_LIBS absl::strings) +list(APPEND ABSL_LIBS absl::synchronization) +list(APPEND ABSL_LIBS absl::time) +list(APPEND ABSL_LIBS absl::utility) + add_library(transfer_engine ${ENGINE_SOURCES} $) if (BUILD_SHARED_LIBS) install(TARGETS transfer_engine DESTINATION lib) @@ -23,7 +42,12 @@ if (USE_HTTP) find_package(CURL REQUIRED) target_link_libraries(transfer_engine PUBLIC ${CURL_LIBRARIES}) endif() -target_link_libraries(transfer_engine PUBLIC transport rdma_transport ibverbs glog gflags pthread jsoncpp numa) +target_link_libraries( + transfer_engine + PUBLIC + base transport rdma_transport ibverbs glog gflags pthread jsoncpp numa + ${ABSL_LIBS} + ) if (USE_CUDA) target_include_directories(transfer_engine PRIVATE /usr/local/cuda/include) diff --git a/mooncake-transfer-engine/src/common/CMakeLists.txt b/mooncake-transfer-engine/src/common/CMakeLists.txt new file mode 100644 index 0000000..0297e14 --- /dev/null +++ b/mooncake-transfer-engine/src/common/CMakeLists.txt @@ -0,0 +1,9 @@ +cmake_minimum_required(VERSION 3.20) + +# Add base sub directory. 
+add_subdirectory(base) +list(APPEND SUB_STATIC_LIBS base) + +set(SUB_STATIC_LIBS + ${SUB_STATIC_LIBS} + PARENT_SCOPE) diff --git a/mooncake-transfer-engine/src/common/base/CMakeLists.txt b/mooncake-transfer-engine/src/common/base/CMakeLists.txt new file mode 100644 index 0000000..de7e35d --- /dev/null +++ b/mooncake-transfer-engine/src/common/base/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.20) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# Add source files for base. +list( + APPEND + SRC + ${CMAKE_CURRENT_LIST_DIR}/status.cpp) + +# # Build all the source files of base dir into a static lib('base') +add_library(base STATIC ${SRC}) \ No newline at end of file diff --git a/mooncake-transfer-engine/src/common/base/status.cpp b/mooncake-transfer-engine/src/common/base/status.cpp new file mode 100644 index 0000000..9ebe0c9 --- /dev/null +++ b/mooncake-transfer-engine/src/common/base/status.cpp @@ -0,0 +1,126 @@ +// Copyright 2025 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// The design of this code is adapted from the RocksDB project with some +// modifications. +// https://github.com/facebook/rocksdb/blob/main/util/status.cc + +#include "common/base/status.h" + +#include + +#include "absl/strings/str_cat.h" +#include "glog/logging.h" + +namespace mooncake { + +Status::Status(Status::Code code, absl::string_view message) + : code_(code) { + if (code != Code::kOk) { + // Only store the message when it is not empty. 
+ if (!message.empty()) { + const size_t len = message.size(); + // +1 for null terminator + char* const result = new char[len + 1]; + memcpy(result, message.data(), len); + result[len] = '\0'; + message_ = result; + } + } +} + +std::string Status::ToString() const { + if (ok()) { + return "OK"; + } else { + return absl::StrCat(CodeToString(code()), ": ", message()); + } +} + +std::string_view Status::CodeToString(Status::Code code) { + switch (code) { + case Code::kOk: + return "OK"; + case Code::kInvalidArgument: + return "InvalidArgument"; + case Code::kTooManyRequests: + return "TooManyRequests"; + case Code::kAddressNotRegistered: + return "AddressNotRegistered"; + case Code::kBatchBusy: + return "BatchBusy"; + case Code::kDeviceNotFound: + return "DeviceNotFound"; + case Code::kAddressOverlapped: + return "AddressOverlapped"; + case Code::kDns: + return "Dns"; + case Code::kSocket: + return "Socket"; + case Code::kMalformedJson: + return "MalformedJson"; + case Code::kRejectHandshake: + return "RejectHandshake"; + case Code::kMetadata: + return "Metadata"; + case Code::kEndpoint: + return "Endpoint"; + case Code::kContext: + return "Context"; + case Code::kNuma: + return "Numa"; + case Code::kClock: + return "Clock"; + case Code::kMemory: + return "Memory"; + case Code::kNotImplmented: + return "NotImplmented"; + default: + LOG(ERROR) << "Unknown code: " << static_cast<int>(code); + return "Unknown"; // literal: a temporary would dangle + } +} + +const char* Status::CopyMessage(const char* msg) { + // +1 for the null terminator + const size_t len = std::strlen(msg) + 1; + return std::strncpy(new char[len], msg, len); +} + +bool Status::operator==(const Status& s) const { + // Compare the code. + if (code_ != s.code_) { + return false; + } + // Compare the message content. 
+ if (message_ == nullptr && s.message_ == nullptr) { + return true; + } + if (message_ != nullptr && s.message_ != nullptr) { + return strcmp(message_, s.message_) == 0; + } + return false; +} + +bool Status::operator!=(const Status& s) const { return !(*this == s); } + +std::ostream& operator<<(std::ostream& os, Status::Code code) { + return os << Status::CodeToString(code); +} + +std::ostream& operator<<(std::ostream& os, const Status& s) { + return os << s.ToString(); +} + +} // namespace mooncake diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index ca19c16..fcdd240 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -47,29 +47,30 @@ MultiTransport::BatchID MultiTransport::allocateBatchID(size_t batch_size) { return batch_desc->id; } -int MultiTransport::freeBatchID(BatchID batch_id) { +Status MultiTransport::freeBatchID(BatchID batch_id) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); for (size_t task_id = 0; task_id < task_count; task_id++) { if (!batch_desc.task_list[task_id].is_finished) { LOG(ERROR) << "BatchID cannot be freed until all tasks are done"; - return ERR_BATCH_BUSY; - } + return Status::BatchBusy( + "BatchID cannot be freed until all tasks are done"); } } delete &batch_desc; #ifdef CONFIG_USE_BATCH_DESC_SET RWSpinlock::WriteGuard guard(batch_desc_lock_); batch_desc_set_.erase(batch_id); #endif - return 0; + return Status::OK(); } -int MultiTransport::submitTransfer( +Status MultiTransport::submitTransfer( BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { LOG(ERROR) << "MultiTransport: Exceed the limitation of batch capacity"; - return ERR_TOO_MANY_REQUESTS; + return Status::TooManyRequests( + "Exceed the limitation of batch capacity"); } 
size_t task_id = batch_desc.task_list.size(); @@ -81,7 +82,10 @@ int MultiTransport::submitTransfer( std::unordered_map submit_tasks; for (auto &request : entries) { auto transport = selectTransport(request); - if (!transport) return ERR_INVALID_ARGUMENT; + if (!transport) { + return Status::InvalidArgument(absl::StrCat( + "SelectTransport failed for SegmentID: ", request.target_id)); + } auto &task = batch_desc.task_list[task_id]; ++task_id; submit_tasks[transport].request_list.push_back( @@ -89,22 +93,25 @@ int MultiTransport::submitTransfer( submit_tasks[transport].task_list.push_back(&task); } for (auto &entry : submit_tasks) { - int ret = entry.first->submitTransferTask(entry.second.request_list, + auto status = entry.first->submitTransferTask(entry.second.request_list, entry.second.task_list); - if (ret) { + if (!status.ok()) { LOG(ERROR) << "MultiTransport: Failed to submit transfer task to " << entry.first->getName(); - return ret; + return status; } } - return 0; + return Status::OK(); } -int MultiTransport::getTransferStatus(BatchID batch_id, size_t task_id, +Status MultiTransport::getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); - if (task_id >= task_count) return ERR_INVALID_ARGUMENT; + if (task_id >= task_count) { + return Status::InvalidArgument( + "MultiTransport: task id is equal to or larger than task_count"); + } auto &task = batch_desc.task_list[task_id]; status.transferred_bytes = task.transferred_bytes; uint64_t success_slice_count = task.success_slice_count; @@ -119,7 +126,7 @@ int MultiTransport::getTransferStatus(BatchID batch_id, size_t task_id, } else { status.s = Transport::TransferStatusEnum::WAITING; } - return 0; + return Status::OK(); } Transport *MultiTransport::installTransport(const std::string &proto, diff --git a/mooncake-transfer-engine/src/transfer_engine_c.cpp 
b/mooncake-transfer-engine/src/transfer_engine_c.cpp index f2c6524..7881820 100644 --- a/mooncake-transfer-engine/src/transfer_engine_c.cpp +++ b/mooncake-transfer-engine/src/transfer_engine_c.cpp @@ -110,8 +110,9 @@ batch_id_t allocateBatchID(transfer_engine_t engine, size_t batch_size) { return (batch_id_t)native->allocateBatchID(batch_size); } -int submitTransfer(transfer_engine_t engine, batch_id_t batch_id, - struct transfer_request *entries, size_t count) { +Status submitTransfer(transfer_engine_t engine, batch_id_t batch_id, + struct transfer_request *entries, + size_t count) { TransferEngine *native = (TransferEngine *)engine; std::vector native_entries; native_entries.resize(count); @@ -126,20 +127,21 @@ int submitTransfer(transfer_engine_t engine, batch_id_t batch_id, return native->submitTransfer((Transport::BatchID)batch_id, native_entries); } -int getTransferStatus(transfer_engine_t engine, batch_id_t batch_id, - size_t task_id, struct transfer_status *status) { +Status getTransferStatus(transfer_engine_t engine, + batch_id_t batch_id, size_t task_id, + struct transfer_status *status) { TransferEngine *native = (TransferEngine *)engine; Transport::TransferStatus native_status; - int rc = native->getTransferStatus((Transport::BatchID)batch_id, task_id, - native_status); - if (rc == 0) { + Status s = native->getTransferStatus((Transport::BatchID)batch_id, + task_id, native_status); + if (s.ok()) { status->status = (int)native_status.s; status->transferred_bytes = native_status.transferred_bytes; } - return rc; + return s; } -int freeBatchID(transfer_engine_t engine, batch_id_t batch_id) { +Status freeBatchID(transfer_engine_t engine, batch_id_t batch_id) { TransferEngine *native = (TransferEngine *)engine; return native->freeBatchID(batch_id); } diff --git a/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp b/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp index 862e723..e0dc0d4 100644 --- 
a/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp +++ b/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp @@ -41,12 +41,12 @@ CxlTransport::BatchID CxlTransport::allocateBatchID(size_t batch_size) { return batch_id; } -int CxlTransport::getTransferStatus(BatchID batch_id, size_t task_id, +Status CxlTransport::getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) { return 0; } -int CxlTransport::submitTransfer(BatchID batch_id, +Status CxlTransport::submitTransfer(BatchID batch_id, const std::vector &entries) { return 0; } diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index ef92efc..1d7c1f4 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -193,13 +193,15 @@ int RdmaTransport::unregisterLocalMemoryBatch( return metadata_->updateLocalSegmentDesc(); } -int RdmaTransport::submitTransfer(BatchID batch_id, +Status RdmaTransport::submitTransfer(BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { LOG(ERROR) << "RdmaTransport: Exceed the limitation of current batch's " "capacity"; - return ERR_TOO_MANY_REQUESTS; + return Status::InvalidArgument(absl::StrCat( + "RdmaTransport: Exceed the limitation of capacity, batch id: ", + batch_id)); } std::unordered_map, std::vector> @@ -245,16 +247,18 @@ int RdmaTransport::submitTransfer(BatchID batch_id, LOG(ERROR) << "RdmaTransport: Address not registered by any device(s) " << slice->source_addr; - return ERR_ADDRESS_NOT_REGISTERED; + return Status::AddressNotRegistered(absl::StrCat( + "RdmaTransport: not registered by any device(s), address:", + absl::StrFormat("%p", slice->source_addr))); } } } for (auto &entry : 
slices_to_post) entry.first->submitPostSend(entry.second); - return 0; + return Status::OK(); } -int RdmaTransport::submitTransferTask( +Status RdmaTransport::submitTransferTask( const std::vector &request_list, const std::vector &task_list) { std::unordered_map, std::vector> @@ -298,17 +302,19 @@ int RdmaTransport::submitTransferTask( LOG(ERROR) << "RdmaTransport: Address not registered by any device(s) " << slice->source_addr; - return ERR_ADDRESS_NOT_REGISTERED; + return Status::AddressNotRegistered(absl::StrCat( + "RdmaTransport: not registered by any device(s), address:", + absl::StrFormat("%p", slice->source_addr))); } } } for (auto &entry : slices_to_post) entry.first->submitPostSend(entry.second); - return 0; + return Status::OK(); } -int RdmaTransport::getTransferStatus(BatchID batch_id, - std::vector &status) { +Status RdmaTransport::getTransferStatus(BatchID batch_id, + std::vector &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); status.resize(task_count); @@ -328,14 +334,18 @@ int RdmaTransport::getTransferStatus(BatchID batch_id, status[task_id].s = TransferStatusEnum::WAITING; } } - return 0; + return Status::OK(); } -int RdmaTransport::getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) { +Status RdmaTransport::getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); - if (task_id >= task_count) return ERR_INVALID_ARGUMENT; + if (task_id >= task_count) { + return Status::InvalidArgument(absl::StrCat( + "RdmaTransport::getTransportStatus invalid argument, batch id:", + batch_id)); + } auto &task = batch_desc.task_list[task_id]; status.transferred_bytes = task.transferred_bytes; uint64_t success_slice_count = task.success_slice_count; @@ -350,7 +360,7 @@ int RdmaTransport::getTransferStatus(BatchID batch_id, size_t task_id, } else { status.s 
= TransferStatusEnum::WAITING; } - return 0; + return Status::OK(); } RdmaTransport::SegmentID RdmaTransport::getSegmentID( diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp index a9d4cc4..b3cd5b9 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp +++ b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp @@ -266,11 +266,15 @@ int TcpTransport::unregisterLocalMemoryBatch( return metadata_->updateLocalSegmentDesc(); } -int TcpTransport::getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) { +Status TcpTransport::getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); - if (task_id >= task_count) return ERR_INVALID_ARGUMENT; + if (task_id >= task_count) { + return Status::InvalidArgument(absl::StrCat( + "TcpTransport::getTransportStatus invalid argument, batch id:", + batch_id)); + } auto &task = batch_desc.task_list[task_id]; status.transferred_bytes = task.transferred_bytes; uint64_t success_slice_count = task.success_slice_count; @@ -286,16 +290,18 @@ int TcpTransport::getTransferStatus(BatchID batch_id, size_t task_id, } else { status.s = TransferStatusEnum::WAITING; } - return 0; + return Status::OK(); } -int TcpTransport::submitTransfer(BatchID batch_id, +Status TcpTransport::submitTransfer(BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { LOG(ERROR) << "TcpTransport: Exceed the limitation of current batch's " "capacity"; - return ERR_TOO_MANY_REQUESTS; + return Status::InvalidArgument(absl::StrCat( + "TcpTransport: Exceed the limitation of capacity, batch id: ", + batch_id)); } size_t task_id = batch_desc.task_list.size(); @@ -317,10 
+323,10 @@ int TcpTransport::submitTransfer(BatchID batch_id, startTransfer(slice); } - return 0; + return Status::OK(); } -int TcpTransport::submitTransferTask( +Status TcpTransport::submitTransferTask( const std::vector &request_list, const std::vector &task_list) { for (size_t index = 0; index < request_list.size(); ++index) { @@ -338,7 +344,7 @@ int TcpTransport::submitTransferTask( task.slice_count += 1; startTransfer(slice); } - return 0; + return Status::OK(); } void TcpTransport::worker() { diff --git a/mooncake-transfer-engine/src/transport/transport.cpp b/mooncake-transfer-engine/src/transport/transport.cpp index 12df7eb..038825a 100644 --- a/mooncake-transfer-engine/src/transport/transport.cpp +++ b/mooncake-transfer-engine/src/transport/transport.cpp @@ -33,13 +33,14 @@ Transport::BatchID Transport::allocateBatchID(size_t batch_size) { return batch_desc->id; } -int Transport::freeBatchID(BatchID batch_id) { +Status Transport::freeBatchID(BatchID batch_id) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); for (size_t task_id = 0; task_id < task_count; task_id++) { if (!batch_desc.task_list[task_id].is_finished) { LOG(ERROR) << "BatchID cannot be freed until all tasks are done"; - return ERR_BATCH_BUSY; + return Status::BatchBusy( + "BatchID cannot be freed until all tasks are done"); } } delete &batch_desc; @@ -47,7 +48,7 @@ int Transport::freeBatchID(BatchID batch_id) { RWSpinlock::WriteGuard guard(batch_desc_lock_); batch_desc_set_.erase(batch_id); #endif - return 0; + return Status::OK(); } int Transport::install(std::string &local_server_name, diff --git a/mooncake-transfer-engine/tests/rdma_transport_test.cpp b/mooncake-transfer-engine/tests/rdma_transport_test.cpp index 2019aa4..92e7b0a 100644 --- a/mooncake-transfer-engine/tests/rdma_transport_test.cpp +++ b/mooncake-transfer-engine/tests/rdma_transport_test.cpp @@ -137,7 +137,7 @@ int initiatorWorker(TransferEngine *engine, SegmentID 
segment_id, int thread_id, LOG(INFO) << "Write Data: " << std::string((char *)(addr), 16) << "..."; auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::WRITE; @@ -145,13 +145,13 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - LOG_ASSERT(!ret); + Status s = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if (status.s == TransferStatusEnum::FAILED) { @@ -159,14 +159,14 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, completed = true; } } - ret = engine->freeBatchID(batch_id); - LOG_ASSERT(!ret); + s = engine->freeBatchID(batch_id); + LOG_ASSERT(s.ok()); } { LOG(INFO) << "Stage 2: Read Data"; auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::READ; @@ -174,13 +174,13 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - LOG_ASSERT(!ret); + Status s = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if 
(status.s == TransferStatusEnum::FAILED) { @@ -188,8 +188,8 @@ int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, completed = true; } } - ret = engine->freeBatchID(batch_id); - LOG_ASSERT(!ret); + s = engine->freeBatchID(batch_id); + LOG_ASSERT(s.ok()); } int ret = diff --git a/mooncake-transfer-engine/tests/rdma_transport_test2.cpp b/mooncake-transfer-engine/tests/rdma_transport_test2.cpp index a71048d..c9c065d 100644 --- a/mooncake-transfer-engine/tests/rdma_transport_test2.cpp +++ b/mooncake-transfer-engine/tests/rdma_transport_test2.cpp @@ -157,20 +157,20 @@ TEST_F(RDMATransportTest, MultiWrite) { for (size_t offset = 0; offset < kDataLength; ++offset) *((char *)(addr) + offset) = 'a' + lrand48() % 26; auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::WRITE; entry.length = kDataLength; entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if (status.s == TransferStatusEnum::FAILED) { @@ -178,8 +178,8 @@ TEST_F(RDMATransportTest, MultiWrite) { completed = true; } } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); } } @@ -191,20 +191,20 @@ TEST_F(RDMATransportTest, MultipleRead) { *((char *)(addr) + offset) = 'a' + lrand48() % 26; auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::WRITE; entry.length = kDataLength; 
entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if (status.s == TransferStatusEnum::FAILED) { @@ -212,8 +212,8 @@ TEST_F(RDMATransportTest, MultipleRead) { completed = true; } } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); } times = 10; while (times--) { @@ -225,21 +225,22 @@ TEST_F(RDMATransportTest, MultipleRead) { entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - ASSERT_EQ(ret, 0); + Status s; + s = engine->submitTransfer(batch_id, {entry}); + ASSERT_EQ(s, Status::OK()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if (status.s == TransferStatusEnum::FAILED) { completed = true; } } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); ret = memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, kDataLength); ASSERT_EQ(ret, 0); diff --git a/mooncake-transfer-engine/tests/tcp_transport_test.cpp b/mooncake-transfer-engine/tests/tcp_transport_test.cpp index 1269d3b..1915732 100644 --- a/mooncake-transfer-engine/tests/tcp_transport_test.cpp +++ 
b/mooncake-transfer-engine/tests/tcp_transport_test.cpp @@ -32,6 +32,8 @@ #include +#include "common/base/status.h" + static void checkCudaError(cudaError_t result, const char *message) { if (result != cudaSuccess) { LOG(ERROR) << message << " (Error code: " << result << " - " @@ -121,7 +123,7 @@ TEST_F(TCPTransportTest, Writetest) { for (size_t offset = 0; offset < kDataLength; ++offset) *((char *)(addr) + offset) = 'a' + lrand48() % 26; auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; auto segment_id = engine->openSegment(local_server_name); TransferRequest entry; auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); @@ -131,18 +133,18 @@ TEST_F(TCPTransportTest, Writetest) { entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); } TEST_F(TCPTransportTest, WriteAndReadtest) { @@ -169,30 +171,30 @@ TEST_F(TCPTransportTest, WriteAndReadtest) { uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; { auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::WRITE; entry.length = kDataLength; entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = 
engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); } { auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::READ; @@ -200,18 +202,18 @@ TEST_F(TCPTransportTest, WriteAndReadtest) { entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - LOG_ASSERT(!ret); + Status s = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; LOG_ASSERT(status.s != TransferStatusEnum::FAILED); } - ret = engine->freeBatchID(batch_id); - LOG_ASSERT(!ret); + s = engine->freeBatchID(batch_id); + LOG_ASSERT(s.ok()); } LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, kDataLength)); @@ -242,48 +244,48 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { { auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::WRITE; entry.length = kDataLength; entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = 
engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); } { auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::READ; entry.length = kDataLength; entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - LOG_ASSERT(!ret); + Status s = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; LOG_ASSERT(status.s != TransferStatusEnum::FAILED); } - ret = engine->freeBatchID(batch_id); - LOG_ASSERT(!ret); + s = engine->freeBatchID(batch_id); + LOG_ASSERT(s.ok()); } LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, kDataLength)); @@ -292,48 +294,48 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { *((char *)(addr) + offset) = 'a' + lrand48() % 26; { auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::WRITE; entry.length = kDataLength; entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = 
engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - ASSERT_EQ(ret, 0); + Status s = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(s, Status::OK()); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = engine->freeBatchID(batch_id); - ASSERT_EQ(ret, 0); + s = engine->freeBatchID(batch_id); + ASSERT_EQ(s, Status::OK()); } { auto batch_id = engine->allocateBatchID(1); - int ret = 0; + Status s; TransferRequest entry; entry.opcode = TransferRequest::READ; entry.length = kDataLength; entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = engine->submitTransfer(batch_id, {entry}); - LOG_ASSERT(!ret); + s = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(s.ok()); bool completed = false; TransferStatus status; while (!completed) { - int ret = engine->getTransferStatus(batch_id, 0, status); - LOG_ASSERT(!ret); + Status s = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; LOG_ASSERT(status.s != TransferStatusEnum::FAILED); } - ret = engine->freeBatchID(batch_id); - LOG_ASSERT(!ret); + s = engine->freeBatchID(batch_id); + LOG_ASSERT(s.ok()); } LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, kDataLength));