From 8da6bdced6a2bb14c7fae7ac494a801cdcbf4553 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 9 Mar 2023 13:32:51 -0500 Subject: [PATCH 01/10] feat: use lru_dedup_dict for rank call --- external_parser/CMakeLists.txt | 4 +- external_parser/joiners/example_joiner.h | 2 +- external_parser/joiners/i_joiner.h | 2 +- .../unit_tests/test_lru_dedup_cache.cc | 2 +- include/live_model.h | 10 ++++ include/model_mgmt.h | 1 + rlclientlib/CMakeLists.txt | 2 + rlclientlib/live_model.cc | 6 +++ rlclientlib/live_model_impl.cc | 5 ++ rlclientlib/live_model_impl.h | 1 + .../lru_dedup_cache.cc | 0 .../lru_dedup_cache.h | 1 + rlclientlib/vw_model/pdf_model.cc | 8 ++- rlclientlib/vw_model/pdf_model.h | 1 + rlclientlib/vw_model/safe_vw.cc | 50 +++++++++++++------ rlclientlib/vw_model/safe_vw.h | 22 ++++---- rlclientlib/vw_model/vw_model.cc | 13 +++-- rlclientlib/vw_model/vw_model.h | 3 ++ unit_test/live_model_test.cc | 50 ++++++++++++++++++- unit_test/safe_vw_test.cc | 18 +++---- 20 files changed, 157 insertions(+), 44 deletions(-) rename {external_parser => rlclientlib}/lru_dedup_cache.cc (100%) rename {external_parser => rlclientlib}/lru_dedup_cache.h (95%) diff --git a/external_parser/CMakeLists.txt b/external_parser/CMakeLists.txt index 99a822c0d..1a187291d 100644 --- a/external_parser/CMakeLists.txt +++ b/external_parser/CMakeLists.txt @@ -135,7 +135,7 @@ set(binary_parser_headers ${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h ${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h ${CMAKE_CURRENT_LIST_DIR}/log_converter.h - ${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.h + ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h @@ -146,7 +146,7 @@ set(binary_parser_sources ${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc ${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc ${CMAKE_CURRENT_LIST_DIR}/log_converter.cc - ${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.cc + ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc diff --git a/external_parser/joiners/example_joiner.h b/external_parser/joiners/example_joiner.h index efc11127c..bd95ac5f0 100644 --- a/external_parser/joiners/example_joiner.h +++ b/external_parser/joiners/example_joiner.h @@ -3,7 +3,7 @@ #include "event_processors/joined_event.h" #include "event_processors/loop.h" #include "joiners/i_joiner.h" -#include "lru_dedup_cache.h" +#include "../rlclientlib/lru_dedup_cache.h" #include "metrics/metrics.h" #include "parse_example_external.h" #include "vw/core/error_constants.h" diff --git a/external_parser/joiners/i_joiner.h b/external_parser/joiners/i_joiner.h index c7dd5c8f4..3130fcb66 100644 --- a/external_parser/joiners/i_joiner.h +++ b/external_parser/joiners/i_joiner.h @@ -4,7 +4,7 @@ #include "generated/v2/CbEvent_generated.h" #include "generated/v2/FileFormat_generated.h" #include "generated/v2/Metadata_generated.h" -#include "lru_dedup_cache.h" +#include "../rlclientlib/lru_dedup_cache.h" #include "metrics/metrics.h" #include "parse_example_external.h" #include "vw/core/error_constants.h" diff --git a/external_parser/unit_tests/test_lru_dedup_cache.cc b/external_parser/unit_tests/test_lru_dedup_cache.cc index f6abdc267..020231ff1 100644 --- a/external_parser/unit_tests/test_lru_dedup_cache.cc +++ b/external_parser/unit_tests/test_lru_dedup_cache.cc @@ -1,6 +1,6 @@ #include -#include "lru_dedup_cache.h" +#include "../rlclientlib/lru_dedup_cache.h" #include "parse_example_external.h" #include "test_common.h" #include "vw/config/options_cli.h" diff --git a/include/live_model.h b/include/live_model.h index 7ece185da..47a3df7e5 100644 --- a/include/live_model.h +++ b/include/live_model.h @@ -107,6 +107,16 @@ class live_model */ int init(api_status* status = nullptr); + /** + * @brief Load dedup cache. + * Load the dedup cache from the specified file. This cache is used to + * prevent duplicate actions from being sent to the online trainer. + * @param hash Hash of the dedup cache + * @param action_str Action string + * @return int Return error code. This will also be returned in the api_status object + */ + int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status); + /** * @brief Choose an action, given a list of actions, action features and context features. The * inference library chooses an action by creating a probability distribution over the actions diff --git a/include/model_mgmt.h b/include/model_mgmt.h index 4b381dba7..aec5c33df 100644 --- a/include/model_mgmt.h +++ b/include/model_mgmt.h @@ -74,6 +74,7 @@ class i_model { public: virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0; + virtual int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) = 0; virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) = 0; virtual int choose_continuous_action(string_view features, float& action, float& pdf_value, diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index f2a58afb7..fa0367e4f 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -68,6 +68,7 @@ set(PROJECT_SOURCES logger/logger_facade.cc logger/preamble.cc logger/preamble_sender.cc + lru_dedup_cache.cc model_mgmt/data_callback_fn.cc model_mgmt/empty_data_transport.cc model_mgmt/file_model_loader.cc @@ -149,6 +150,7 @@ set(PROJECT_PRIVATE_HEADERS logger/async_batcher.h logger/event_logger.h logger/logger_facade.h + lru_dedup_cache.h model_mgmt/data_callback_fn.h model_mgmt/empty_data_transport.h model_mgmt/file_model_loader.h diff --git a/rlclientlib/live_model.cc b/rlclientlib/live_model.cc index ab4776764..cb1abc6f8 100644 --- a/rlclientlib/live_model.cc +++ b/rlclientlib/live_model.cc @@ -61,6 +61,12 @@ std::vector live_model::c_array_to_vector(const int* c_array, size_t array_ return std::vector(c_array, c_array + array_size); } +int live_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +{ + INIT_CHECK(); + return _pimpl->add_lru_dedup_cache(hash, action_str, status); +} + int live_model::choose_rank( const char* event_id, string_view context_json, ranking_response& response, api_status* status) { diff --git a/rlclientlib/live_model_impl.cc b/rlclientlib/live_model_impl.cc index ebd1c8c22..173ac2210 100644 --- a/rlclientlib/live_model_impl.cc +++ b/rlclientlib/live_model_impl.cc @@ -74,6 +74,11 @@ int live_model_impl::init(api_status* status) return error_code::success; } +int live_model_impl::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +{ + return _model->add_lru_dedup_cache(hash, action_str, status); +} + int live_model_impl::choose_rank( const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status) { diff --git a/rlclientlib/live_model_impl.h b/rlclientlib/live_model_impl.h index 8a5a0f237..da4f155ec 100644 --- a/rlclientlib/live_model_impl.h +++ b/rlclientlib/live_model_impl.h @@ -28,6 +28,7 @@ class live_model_impl int init(api_status* status); + int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status); int choose_rank( const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status); // here the event_id is auto-generated diff --git a/external_parser/lru_dedup_cache.cc b/rlclientlib/lru_dedup_cache.cc similarity index 100% rename from external_parser/lru_dedup_cache.cc rename to rlclientlib/lru_dedup_cache.cc diff --git a/external_parser/lru_dedup_cache.h b/rlclientlib/lru_dedup_cache.h similarity index 95% rename from external_parser/lru_dedup_cache.h rename to rlclientlib/lru_dedup_cache.h index 408dedda6..b51d29479 100644 --- a/external_parser/lru_dedup_cache.h +++ b/rlclientlib/lru_dedup_cache.h @@ -35,6 +35,7 @@ struct lru_dedup_cache void* context = nullptr); bool exists(uint64_t dedup_id); void clear(release_example_f release_example = lru_dedup_cache::noop_release_example_f, void* context = nullptr); + std::unordered_map* get_dict() { return &dedup_examples; } lru_dedup_cache() = default; ~lru_dedup_cache() = default; diff --git a/rlclientlib/vw_model/pdf_model.cc b/rlclientlib/vw_model/pdf_model.cc index 8c08da0b5..4ebb46a02 100644 --- a/rlclientlib/vw_model/pdf_model.cc +++ b/rlclientlib/vw_model/pdf_model.cc @@ -13,7 +13,7 @@ namespace model_management // We construct a VW object here to use the example parser to parse joined dsjson-style examples // to extract the PDF. pdf_model::pdf_model(i_trace* trace_logger, const utility::configuration& /*unused*/) - : _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf")) + : _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf", nullptr)) { } @@ -23,6 +23,12 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta return error_code::success; } +// TODO: Implement LRU cache for PDF models. +int pdf_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +{ + return error_code::not_supported; +} + int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { diff --git a/rlclientlib/vw_model/pdf_model.h b/rlclientlib/vw_model/pdf_model.h index 883e1fbee..438b08bb9 100644 --- a/rlclientlib/vw_model/pdf_model.h +++ b/rlclientlib/vw_model/pdf_model.h @@ -20,6 +20,7 @@ class pdf_model : public i_model public: pdf_model(i_trace* trace_logger, const utility::configuration& config); int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override; + int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version, diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index 93db5e62b..eabf3dcde 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -19,13 +19,13 @@ namespace reinforcement_learning { static const std::string SEED_TAG = "seed="; -safe_vw::safe_vw(std::shared_ptr master) : _master(std::move(master)) +safe_vw::safe_vw(std::shared_ptr master, lru_dedup_cache* dedup_cache) : _master(std::move(master)), _dedup_cache(dedup_cache) { _vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr); init(); } -safe_vw::safe_vw(const char* model_data, size_t len) +safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache) { io_buf buf; buf.add_file(VW::io::create_buffer_view(model_data, len)); @@ -34,7 +34,7 @@ safe_vw::safe_vw(const char* model_data, size_t len) init(); } -safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline) +safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache) { io_buf buf; buf.add_file(VW::io::create_buffer_view(model_data, len)); @@ -43,7 +43,7 @@ safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_comma init(); } -safe_vw::safe_vw(const std::string& vw_commandline) +safe_vw::safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache) { _vw = VW::initialize(vw_commandline); init(); @@ -120,6 +120,24 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector& acti for (auto&& ex : examples) { _example_pool.emplace_back(ex); } } +void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str) +{ + if (_dedup_cache == nullptr) + { + _dedup_cache = new lru_dedup_cache(); + } + VW::multi_ex examples; + examples.push_back(get_or_create_example()); + + if (_vw->audit) + { + _vw->audit_buffer->clear(); + VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); + } + else { VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); } + _dedup_cache->add(hash, examples[0]); +} + void safe_vw::rank(string_view context, std::vector& actions, std::vector& scores) { VW::multi_ex examples; @@ -131,9 +149,9 @@ void safe_vw::rank(string_view context, std::vector& actions, std::vectoraudit) { _vw->audit_buffer->clear(); - VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); + VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); } - else { VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); } + else { VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); } // finalize example VW::setup_examples(*_vw, examples); @@ -372,19 +390,19 @@ void safe_vw::init() } } -safe_vw_factory::safe_vw_factory(std::string command_line) : _command_line(std::move(command_line)) {} +safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache) : _command_line(std::move(command_line)), _dedup_cache(dedup_cache) {} -safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) : _master_data(master_data) {} +safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache) : _master_data(master_data), _dedup_cache(dedup_cache) {} -safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) : _master_data(master_data) {} +safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache) : _master_data(master_data), _dedup_cache(dedup_cache) {} -safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line) - : _master_data(master_data), _command_line(std::move(command_line)) +safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache) + : _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache) { } -safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line) - : _master_data(master_data), _command_line(std::move(command_line)) +safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache) + : _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache) { } @@ -393,13 +411,13 @@ safe_vw* safe_vw_factory::operator()() if ((_master_data.data() != nullptr) && !_command_line.empty()) { // Construct new vw object from raw model data and command line argument - return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line); + return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line, _dedup_cache); } if (_master_data.data() != nullptr) { // Construct new vw object from raw model data. - return new safe_vw(_master_data.data(), _master_data.data_sz()); + return new safe_vw(_master_data.data(), _master_data.data_sz(), _dedup_cache); } - return new safe_vw(_command_line); + return new safe_vw(_command_line, _dedup_cache); } } // namespace reinforcement_learning diff --git a/rlclientlib/vw_model/safe_vw.h b/rlclientlib/vw_model/safe_vw.h index c97b4b064..56adf9a0a 100644 --- a/rlclientlib/vw_model/safe_vw.h +++ b/rlclientlib/vw_model/safe_vw.h @@ -1,6 +1,7 @@ #pragma once #include "model_mgmt.h" +#include "lru_dedup_cache.h" #include "vw/core/vw.h" #include @@ -14,19 +15,21 @@ class safe_vw std::shared_ptr _master; VW::workspace* _vw; std::vector _example_pool; + lru_dedup_cache* _dedup_cache; VW::example* get_or_create_example(); static VW::example& get_or_create_example_f(void* vw); public: - safe_vw(std::shared_ptr master); - safe_vw(const char* model_data, size_t len, const std::string& vw_commandline); - safe_vw(const char* model_data, size_t len); - safe_vw(const std::string& vw_commandline); + safe_vw(std::shared_ptr master, lru_dedup_cache* dedup_cache); + safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache); + safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache); + safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache); ~safe_vw(); void parse_context_with_pdf(string_view context, std::vector& actions, std::vector& scores); + void add_lru_dedup_cache(uint64_t hash, std::string action_str); void rank(string_view context, std::vector& actions, std::vector& scores); void choose_continuous_action(string_view context, float& action, float& pdf_value); // Used for CCB @@ -57,14 +60,15 @@ class safe_vw_factory { model_management::model_data _master_data; std::string _command_line; + lru_dedup_cache* _dedup_cache; public: // model_data is copied and stored in the factory object. - safe_vw_factory(std::string command_line); - safe_vw_factory(const model_management::model_data& master_data); - safe_vw_factory(const model_management::model_data&& master_data); - safe_vw_factory(const model_management::model_data& master_data, std::string command_line); - safe_vw_factory(const model_management::model_data&& master_data, std::string command_line); + safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache); + safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache); + safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache); + safe_vw_factory(const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache); + safe_vw_factory(const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache); safe_vw* operator()(); }; diff --git a/rlclientlib/vw_model/vw_model.cc b/rlclientlib/vw_model/vw_model.cc index 53d951a93..a30cc4ad0 100644 --- a/rlclientlib/vw_model/vw_model.cc +++ b/rlclientlib/vw_model/vw_model.cc @@ -18,7 +18,7 @@ vw_model::vw_model(i_trace* trace_logger, const utility::configuration& config) , _initial_command_line(std::string(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, "--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id N/A")) + (_audit ? " --audit" : "")) - , _vw_pool(safe_vw_factory(_initial_command_line), + , _vw_pool(safe_vw_factory(_initial_command_line, _dedup_cache), config.get_int(name::VW_POOL_INIT_SIZE, value::DEFAULT_VW_POOL_INIT_SIZE), trace_logger) , _trace_logger(trace_logger) { @@ -34,13 +34,13 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat { std::string cmd_line = add_optional_audit_flag(_quiet_commandline_options); - std::unique_ptr init_vw(new safe_vw(data.data(), data.data_sz(), cmd_line)); + std::unique_ptr init_vw(new safe_vw(data.data(), data.data_sz(), cmd_line, _dedup_cache)); if (init_vw->is_CB_to_CCB_model_upgrade(_initial_command_line)) { cmd_line = add_optional_audit_flag(_upgrade_to_CCB_vw_commandline_options); } - safe_vw_factory factory(data, cmd_line); + safe_vw_factory factory(data, cmd_line, _dedup_cache); std::unique_ptr test_vw(factory()); if (test_vw->is_compatible(_initial_command_line)) { @@ -67,6 +67,13 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat return error_code::success; } +int vw_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +{ + auto vw = _vw_pool.get_or_create(); + vw->add_lru_dedup_cache(hash, action_str); + return error_code::success; +} + int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { diff --git a/rlclientlib/vw_model/vw_model.h b/rlclientlib/vw_model/vw_model.h index dd872b6e9..233075786 100644 --- a/rlclientlib/vw_model/vw_model.h +++ b/rlclientlib/vw_model/vw_model.h @@ -1,5 +1,6 @@ #pragma once #include "../utility/versioned_object_pool.h" +#include "lru_dedup_cache.h" #include "model_mgmt.h" #include "multistep.h" #include "safe_vw.h" @@ -26,6 +27,7 @@ class vw_model : public i_model vw_model(i_trace* trace_logger, const utility::configuration& config); int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override; + int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version, @@ -48,6 +50,7 @@ class vw_model : public i_model const std::string _quiet_commandline_options{"--json --quiet"}; const std::string _upgrade_to_CCB_vw_commandline_options{"--ccb_explore_adf --json --quiet"}; utility::versioned_object_pool _vw_pool; + lru_dedup_cache* _dedup_cache = nullptr; i_trace* _trace_logger; }; } // namespace model_management diff --git a/unit_test/live_model_test.cc b/unit_test/live_model_test.cc index 83de9cd1a..8e38cb516 100644 --- a/unit_test/live_model_test.cc +++ b/unit_test/live_model_test.cc @@ -448,7 +448,7 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only) r::ranking_response response; const auto JSON_CB_CONTEXT_3ACTIONS = - R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"TAction":{"a1":"f1"} },{"TAction":{"a2":"f2"}},{"TAction":{"a3":"f3"}}]})"; + R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"TAction":{"a1":"f1"}},{"TAction":{"a2":"f2"}},{"TAction":{"a3":"f3"}}]})"; // request ranking BOOST_CHECK_EQUAL(model.choose_rank(event_id, JSON_CB_CONTEXT_3ACTIONS, response), err::success); @@ -468,6 +468,54 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only) } } +BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only_dedup) +{ + // create a simple ds configuration + u::configuration config; + cfg::create_from_json(JSON_CFG, config); + config.set(r::name::EH_TEST, "true"); + config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA); + config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER); + config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER); + config.set( + r::name::MODEL_VW_INITIAL_COMMAND_LINE, "--cb_explore_adf --json --quiet --epsilon 0.3 --first_only --id N/A"); + + r::api_status status; + + // create the ds live_model, and initialize it with the config + r::live_model model(config); + + BOOST_CHECK_EQUAL(model.init(&status), err::success); + const auto event_id = "event_id"; + + r::ranking_response response; + + const auto JSON_CB_CONTEXT_3ACTIONS_DEDUP = + R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"__aid":1},{"__aid":2},{"__aid":3}]})"; + + // add dedup + BOOST_CHECK_EQUAL(model.add_lru_dedup_cache(1, "{\"TAction\":{\"a1\":\"f1\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.add_lru_dedup_cache(2, "{\"TAction\":{\"a2\":\"f2\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.add_lru_dedup_cache(3, "{\"TAction\":{\"a3\":\"f3\"}}", &status), err::success); + + // request ranking + BOOST_CHECK_EQUAL(model.choose_rank(event_id, JSON_CB_CONTEXT_3ACTIONS_DEDUP, response), err::success); + + size_t num_actions = response.size(); + BOOST_CHECK_EQUAL(num_actions, 3); + + const float EXPECTED_PROBS[3] = {0.8f, 0.1f, 0.1f}; + + // check that our PDF is what we expected + r::ranking_response::iterator it = response.begin(); + + for (uint32_t i = 0; i < num_actions; i++) + { + auto action_probability = *(it + i); + BOOST_CHECK_CLOSE(action_probability.probability, EXPECTED_PROBS[action_probability.action_id], FLOAT_TOL); + } +} + BOOST_AUTO_TEST_CASE(live_model_ranking_w_las_request_check_response_pdf_explore_only) { // create a simple ds configuration diff --git a/unit_test/safe_vw_test.cc b/unit_test/safe_vw_test.cc index 705a3534e..7784674a6 100644 --- a/unit_test/safe_vw_test.cc +++ b/unit_test/safe_vw_test.cc @@ -20,7 +20,7 @@ void get_model_data_from_raw(const char* data, unsigned int len, model_managemen BOOST_AUTO_TEST_CASE(safe_vw_1) { - safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len); + safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len, nullptr); const auto json = R"({"a":{"0":1,"5":2},"_multi":[{"b":{"0":1}},{"b":{"0":2}},{"b":{"0":3}}]})"; std::vector actions; @@ -34,7 +34,7 @@ BOOST_AUTO_TEST_CASE(safe_vw_1) BOOST_AUTO_TEST_CASE(safe_vw_audit_logs) { - safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len, "--json --quiet"); + safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len, "--json --quiet", nullptr); const auto json = R"({"a":{"0":1,"5":2},"_multi":[{"b":{"0":1}},{"b":{"0":2}},{"b":{"0":3}}]})"; std::vector actions; @@ -43,7 +43,7 @@ BOOST_AUTO_TEST_CASE(safe_vw_audit_logs) BOOST_CHECK_EQUAL(0, vw.get_audit_data().size()); - safe_vw vw_w_audit((const char*)cb_data_5_model, cb_data_5_model_len, "--json --audit"); + safe_vw vw_w_audit((const char*)cb_data_5_model, cb_data_5_model_len, "--json --audit", nullptr); vw_w_audit.rank(json, actions, ranking); BOOST_CHECK_LT(0, vw_w_audit.get_audit_data().size()); @@ -59,7 +59,7 @@ BOOST_AUTO_TEST_CASE(factory_with_cb_model_and_ccb_arguments) model_management::model_data model_data; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &model_data); - const safe_vw_factory factory(model_data, vw_commandLine); + const safe_vw_factory factory(model_data, vw_commandLine, nullptr); versioned_object_pool pool(factory); { @@ -90,7 +90,7 @@ BOOST_AUTO_TEST_CASE(factory_with_initial_model) model_management::model_data model_data; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &model_data); - const safe_vw_factory factory(model_data); + const safe_vw_factory factory(model_data, nullptr); versioned_object_pool pool(factory); { @@ -99,7 +99,7 @@ BOOST_AUTO_TEST_CASE(factory_with_initial_model) // Update factory while an object is floating around model_management::model_data updated_model; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &updated_model); - pool.update_factory(safe_vw_factory(updated_model)); + pool.update_factory(safe_vw_factory(updated_model, nullptr)); std::vector actions; std::vector ranking; @@ -127,14 +127,14 @@ BOOST_AUTO_TEST_CASE(factory_with_empty_model) // Start with empty model data model_management::model_data empty_data; - const safe_vw_factory factory(empty_data); + const safe_vw_factory factory(empty_data, nullptr); versioned_object_pool pool(factory); // Initial model & rank call { model_management::model_data new_model; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &new_model); - pool.update_factory(safe_vw_factory(new_model)); + pool.update_factory(safe_vw_factory(new_model, nullptr)); auto vw = pool.get_or_create(); std::vector actions; @@ -148,7 +148,7 @@ BOOST_AUTO_TEST_CASE(factory_with_empty_model) { model_management::model_data new_model; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &new_model); - pool.update_factory(safe_vw_factory(new_model)); + pool.update_factory(safe_vw_factory(new_model, nullptr)); auto vw = pool.get_or_create(); std::vector actions; From 47cac141e0ed97c5cb1c85199c96af60b9a12a89 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 9 Mar 2023 13:33:15 -0500 Subject: [PATCH 02/10] clang --- external_parser/joiners/example_joiner.h | 2 +- external_parser/joiners/i_joiner.h | 2 +- rlclientlib/vw_model/safe_vw.cc | 48 ++++++++++++++++-------- rlclientlib/vw_model/safe_vw.h | 8 ++-- 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/external_parser/joiners/example_joiner.h b/external_parser/joiners/example_joiner.h index bd95ac5f0..4b3062aa7 100644 --- a/external_parser/joiners/example_joiner.h +++ b/external_parser/joiners/example_joiner.h @@ -1,9 +1,9 @@ #pragma once +#include "../rlclientlib/lru_dedup_cache.h" #include "event_processors/joined_event.h" #include "event_processors/loop.h" #include "joiners/i_joiner.h" -#include "../rlclientlib/lru_dedup_cache.h" #include "metrics/metrics.h" #include "parse_example_external.h" #include "vw/core/error_constants.h" diff --git a/external_parser/joiners/i_joiner.h b/external_parser/joiners/i_joiner.h index 3130fcb66..dc1b3bd21 100644 --- a/external_parser/joiners/i_joiner.h +++ b/external_parser/joiners/i_joiner.h @@ -1,10 +1,10 @@ #pragma once +#include "../rlclientlib/lru_dedup_cache.h" #include "event_processors/reward.h" #include "generated/v2/CbEvent_generated.h" #include "generated/v2/FileFormat_generated.h" #include "generated/v2/Metadata_generated.h" -#include "../rlclientlib/lru_dedup_cache.h" #include "metrics/metrics.h" #include "parse_example_external.h" #include "vw/core/error_constants.h" diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index eabf3dcde..a843a6749 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -19,7 +19,8 @@ namespace reinforcement_learning { static const std::string SEED_TAG = "seed="; -safe_vw::safe_vw(std::shared_ptr master, lru_dedup_cache* dedup_cache) : _master(std::move(master)), _dedup_cache(dedup_cache) +safe_vw::safe_vw(std::shared_ptr master, lru_dedup_cache* dedup_cache) + : _master(std::move(master)), _dedup_cache(dedup_cache) { _vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr); init(); @@ -34,7 +35,8 @@ safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cach init(); } -safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache) +safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache) + : _dedup_cache(dedup_cache) { io_buf buf; buf.add_file(VW::io::create_buffer_view(model_data, len)); @@ -122,19 +124,19 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector& acti void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str) { - if (_dedup_cache == nullptr) - { - _dedup_cache = new lru_dedup_cache(); - } + if (_dedup_cache == nullptr) { _dedup_cache = new lru_dedup_cache(); } VW::multi_ex examples; examples.push_back(get_or_create_example()); - + if (_vw->audit) { _vw->audit_buffer->clear(); VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); } - else { VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); } + else + { + VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); + } _dedup_cache->add(hash, examples[0]); } @@ -149,9 +151,14 @@ void safe_vw::rank(string_view context, std::vector& actions, std::vectoraudit) { _vw->audit_buffer->clear(); - VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); + VW::read_line_json_s( + *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); + } + else + { + VW::read_line_json_s( + *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); } - else { VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); } // finalize example VW::setup_examples(*_vw, examples); @@ -390,18 +397,29 @@ void safe_vw::init() } } -safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache) : _command_line(std::move(command_line)), _dedup_cache(dedup_cache) {} +safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache) + : _command_line(std::move(command_line)), _dedup_cache(dedup_cache) +{ +} -safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache) : _master_data(master_data), _dedup_cache(dedup_cache) {} +safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache) + : _master_data(master_data), _dedup_cache(dedup_cache) +{ +} -safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache) : _master_data(master_data), _dedup_cache(dedup_cache) {} +safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache) + : _master_data(master_data), _dedup_cache(dedup_cache) +{ +} -safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache) +safe_vw_factory::safe_vw_factory( + const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache) : _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache) { } -safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache) +safe_vw_factory::safe_vw_factory( + const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache) : _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache) { } diff --git a/rlclientlib/vw_model/safe_vw.h b/rlclientlib/vw_model/safe_vw.h index 56adf9a0a..24ef33080 100644 --- a/rlclientlib/vw_model/safe_vw.h +++ b/rlclientlib/vw_model/safe_vw.h @@ -1,7 +1,7 @@ #pragma once -#include "model_mgmt.h" #include "lru_dedup_cache.h" +#include "model_mgmt.h" #include "vw/core/vw.h" #include @@ -67,8 +67,10 @@ class safe_vw_factory safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache); safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache); safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache); - safe_vw_factory(const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache); - safe_vw_factory(const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache); + safe_vw_factory( + const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache); + safe_vw_factory( + const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache); safe_vw* operator()(); }; From 3410f6e704d725228a9e9b58c8a2b8d69f52e0f3 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 9 Mar 2023 13:45:11 -0500 Subject: [PATCH 03/10] mutex --- rlclientlib/vw_model/vw_model.cc | 3 +++ rlclientlib/vw_model/vw_model.h | 3 +++ 2 files changed, 6 insertions(+) diff --git a/rlclientlib/vw_model/vw_model.cc b/rlclientlib/vw_model/vw_model.cc index a30cc4ad0..82fada014 100644 --- a/rlclientlib/vw_model/vw_model.cc +++ b/rlclientlib/vw_model/vw_model.cc @@ -69,6 +69,7 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat int vw_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) { + std::lock_guard lock(_mutex); auto vw = _vw_pool.get_or_create(); vw->add_lru_dedup_cache(hash, action_str); return error_code::success; @@ -82,6 +83,7 @@ int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view f auto vw = _vw_pool.get_or_create(); // Get a ranked list of action_ids and corresponding pdf + std::lock_guard lock(_mutex); vw->rank(features, action_ids, action_pdf); if (_audit) { write_audit_log(event_id, vw->get_audit_data()); } @@ -104,6 +106,7 @@ int vw_model::choose_rank_multistep(const char* event_id, uint64_t rnd_seed, str const episode_history& history, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { + std::lock_guard lock(_mutex); return choose_rank(event_id, rnd_seed, features, action_ids, action_pdf, model_version, status); } diff --git a/rlclientlib/vw_model/vw_model.h b/rlclientlib/vw_model/vw_model.h index 233075786..edd4b28c7 100644 --- a/rlclientlib/vw_model/vw_model.h +++ b/rlclientlib/vw_model/vw_model.h @@ -6,6 +6,8 @@ #include "safe_vw.h" #include "trace_logger.h" +#include + namespace reinforcement_learning { namespace utility @@ -51,6 +53,7 @@ class vw_model : public i_model const std::string _upgrade_to_CCB_vw_commandline_options{"--ccb_explore_adf --json --quiet"}; utility::versioned_object_pool _vw_pool; lru_dedup_cache* _dedup_cache = nullptr; + std::mutex _mutex; i_trace* _trace_logger; }; } // namespace model_management From a6fd7fc2b19ca3a501c8e89d60eaf4f20c697b8e Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 9 Mar 2023 13:48:30 -0500 Subject: [PATCH 04/10] onnx impl --- rlclientlib/extensions/onnx/src/onnx_model.cc | 6 ++++++ rlclientlib/extensions/onnx/src/onnx_model.h | 1 + 2 files changed, 7 insertions(+) diff --git a/rlclientlib/extensions/onnx/src/onnx_model.cc b/rlclientlib/extensions/onnx/src/onnx_model.cc index 7d4241941..b0f571701 100644 --- a/rlclientlib/extensions/onnx/src/onnx_model.cc +++ b/rlclientlib/extensions/onnx/src/onnx_model.cc @@ -138,6 +138,12 @@ int onnx_model::update(const model_management::model_data& data, bool& model_rea return error_code::success; } +// TODO: Implement LRU cache for ONNX models. +int onnx_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +{ + return error_code::not_supported; +} + int onnx_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { diff --git a/rlclientlib/extensions/onnx/src/onnx_model.h b/rlclientlib/extensions/onnx/src/onnx_model.h index d57aa943a..4b73d44cb 100644 --- a/rlclientlib/extensions/onnx/src/onnx_model.h +++ b/rlclientlib/extensions/onnx/src/onnx_model.h @@ -20,6 +20,7 @@ class onnx_model : public model_management::i_model public: onnx_model(i_trace* trace_logger, const char* app_id, const char* output_name, bool use_unstructured_input); int update(const model_management::model_data& data, bool& model_ready, api_status* status = nullptr) override; + int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; From 6ef732995114baf8636c3f6f9d09773191727fea Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 9 Mar 2023 13:55:54 -0500 Subject: [PATCH 05/10] clang tidy --- rlclientlib/live_model.cc | 2 +- rlclientlib/live_model_impl.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rlclientlib/live_model.cc b/rlclientlib/live_model.cc index cb1abc6f8..4619bdcc6 100644 --- a/rlclientlib/live_model.cc +++ b/rlclientlib/live_model.cc @@ -64,7 +64,7 @@ std::vector live_model::c_array_to_vector(const int* c_array, size_t array_ int live_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) { INIT_CHECK(); - return _pimpl->add_lru_dedup_cache(hash, action_str, status); + return _pimpl->add_lru_dedup_cache(hash, std::move(action_str), status); } int live_model::choose_rank( diff --git a/rlclientlib/live_model_impl.cc b/rlclientlib/live_model_impl.cc index 173ac2210..ecc3dc865 100644 --- a/rlclientlib/live_model_impl.cc +++ b/rlclientlib/live_model_impl.cc @@ -76,7 +76,7 @@ int live_model_impl::init(api_status* status) int live_model_impl::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) { - return _model->add_lru_dedup_cache(hash, action_str, status); + return _model->add_lru_dedup_cache(hash, std::move(action_str), status); } int live_model_impl::choose_rank( From afdce95232abf856ba02ec18ae58240d95741947 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 9 Mar 2023 14:00:43 -0500 Subject: [PATCH 06/10] load action naming --- include/live_model.h | 4 ++-- include/model_mgmt.h | 2 +- rlclientlib/extensions/onnx/src/onnx_model.cc | 5 +---- rlclientlib/extensions/onnx/src/onnx_model.h | 2 +- rlclientlib/live_model.cc | 4 ++-- rlclientlib/live_model_impl.cc | 4 ++-- rlclientlib/live_model_impl.h | 2 +- rlclientlib/vw_model/pdf_model.cc | 5 +---- rlclientlib/vw_model/pdf_model.h | 2 +- rlclientlib/vw_model/safe_vw.cc | 4 ++-- rlclientlib/vw_model/safe_vw.h | 2 +- rlclientlib/vw_model/vw_model.cc | 4 ++-- rlclientlib/vw_model/vw_model.h | 2 +- unit_test/live_model_test.cc | 6 +++--- 14 files changed, 21 insertions(+), 27 deletions(-) diff --git a/include/live_model.h b/include/live_model.h index 47a3df7e5..0ec565ee7 100644 --- a/include/live_model.h +++ b/include/live_model.h @@ -111,11 +111,11 @@ class live_model * @brief Load dedup cache. * Load the dedup cache from the specified file. This cache is used to * prevent duplicate actions from being sent to the online trainer. - * @param hash Hash of the dedup cache + * @param action_id Hash of the dedup cache * @param action_str Action string * @return int Return error code. This will also be returned in the api_status object */ - int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status); + int load_action(uint64_t action_id, std::string action_str, api_status* status); /** * @brief Choose an action, given a list of actions, action features and context features. The diff --git a/include/model_mgmt.h b/include/model_mgmt.h index aec5c33df..0b7528f4a 100644 --- a/include/model_mgmt.h +++ b/include/model_mgmt.h @@ -74,7 +74,7 @@ class i_model { public: virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0; - virtual int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) = 0; + virtual int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) = 0; virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) = 0; virtual int choose_continuous_action(string_view features, float& action, float& pdf_value, diff --git a/rlclientlib/extensions/onnx/src/onnx_model.cc b/rlclientlib/extensions/onnx/src/onnx_model.cc index b0f571701..16f548a85 100644 --- a/rlclientlib/extensions/onnx/src/onnx_model.cc +++ b/rlclientlib/extensions/onnx/src/onnx_model.cc @@ -139,10 +139,7 @@ int onnx_model::update(const model_management::model_data& data, bool& model_rea } // TODO: Implement LRU cache for ONNX models. -int onnx_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) -{ - return error_code::not_supported; -} +int onnx_model::load_action(uint64_t, std::string, api_status*) { return error_code::not_supported; } int onnx_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) diff --git a/rlclientlib/extensions/onnx/src/onnx_model.h b/rlclientlib/extensions/onnx/src/onnx_model.h index 4b73d44cb..08bf1f805 100644 --- a/rlclientlib/extensions/onnx/src/onnx_model.h +++ b/rlclientlib/extensions/onnx/src/onnx_model.h @@ -20,7 +20,7 @@ class onnx_model : public model_management::i_model public: onnx_model(i_trace* trace_logger, const char* app_id, const char* output_name, bool use_unstructured_input); int update(const model_management::model_data& data, bool& model_ready, api_status* status = nullptr) override; - int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override; + int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; diff --git a/rlclientlib/live_model.cc b/rlclientlib/live_model.cc index 4619bdcc6..e15b39548 100644 --- a/rlclientlib/live_model.cc +++ b/rlclientlib/live_model.cc @@ -61,10 +61,10 @@ std::vector live_model::c_array_to_vector(const int* c_array, size_t array_ return std::vector(c_array, c_array + array_size); } -int live_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +int live_model::load_action(uint64_t action_id, std::string action_str, api_status* status) { INIT_CHECK(); - return _pimpl->add_lru_dedup_cache(hash, std::move(action_str), status); + return _pimpl->load_action(action_id, std::move(action_str), status); } int live_model::choose_rank( diff --git a/rlclientlib/live_model_impl.cc b/rlclientlib/live_model_impl.cc index ecc3dc865..681edad3e 100644 --- a/rlclientlib/live_model_impl.cc +++ b/rlclientlib/live_model_impl.cc @@ -74,9 +74,9 @@ int live_model_impl::init(api_status* status) return error_code::success; } -int live_model_impl::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +int live_model_impl::load_action(uint64_t action_id, std::string action_str, api_status* status) { - return _model->add_lru_dedup_cache(hash, std::move(action_str), status); + return _model->load_action(action_id, std::move(action_str), status); } int live_model_impl::choose_rank( diff --git a/rlclientlib/live_model_impl.h b/rlclientlib/live_model_impl.h index da4f155ec..6bce84b16 100644 --- a/rlclientlib/live_model_impl.h +++ b/rlclientlib/live_model_impl.h @@ -28,7 +28,7 @@ class live_model_impl int init(api_status* status); - int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status); + int load_action(uint64_t action_id, std::string action_str, api_status* status); int choose_rank( const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status); // here the event_id is auto-generated diff --git a/rlclientlib/vw_model/pdf_model.cc b/rlclientlib/vw_model/pdf_model.cc index 4ebb46a02..60e6fb28d 100644 --- a/rlclientlib/vw_model/pdf_model.cc +++ b/rlclientlib/vw_model/pdf_model.cc @@ -24,10 +24,7 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta } // TODO: Implement LRU cache for PDF models. -int pdf_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) -{ - return error_code::not_supported; -} +int pdf_model::load_action(uint64_t, std::string, api_status*) { return error_code::not_supported; } int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) diff --git a/rlclientlib/vw_model/pdf_model.h b/rlclientlib/vw_model/pdf_model.h index 438b08bb9..77bdac3a5 100644 --- a/rlclientlib/vw_model/pdf_model.h +++ b/rlclientlib/vw_model/pdf_model.h @@ -20,7 +20,7 @@ class pdf_model : public i_model public: pdf_model(i_trace* trace_logger, const utility::configuration& config); int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override; - int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override; + int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version, diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index a843a6749..5b9037e69 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -122,7 +122,7 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector& acti for (auto&& ex : examples) { _example_pool.emplace_back(ex); } } -void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str) +void safe_vw::load_action(uint64_t action_id, std::string action_str) { if (_dedup_cache == nullptr) { _dedup_cache = new lru_dedup_cache(); } VW::multi_ex examples; @@ -137,7 +137,7 @@ void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str) { VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); } - _dedup_cache->add(hash, examples[0]); + _dedup_cache->add(action_id, examples[0]); } void safe_vw::rank(string_view context, std::vector& actions, std::vector& scores) diff --git a/rlclientlib/vw_model/safe_vw.h b/rlclientlib/vw_model/safe_vw.h index 24ef33080..89c69d935 100644 --- a/rlclientlib/vw_model/safe_vw.h +++ b/rlclientlib/vw_model/safe_vw.h @@ -29,7 +29,7 @@ class safe_vw ~safe_vw(); void parse_context_with_pdf(string_view context, std::vector& actions, std::vector& scores); - void add_lru_dedup_cache(uint64_t hash, std::string action_str); + void load_action(uint64_t action_id, std::string action_str); void rank(string_view context, std::vector& actions, std::vector& scores); void choose_continuous_action(string_view context, float& action, float& pdf_value); // Used for CCB diff --git a/rlclientlib/vw_model/vw_model.cc b/rlclientlib/vw_model/vw_model.cc index 82fada014..68fe79f8e 100644 --- a/rlclientlib/vw_model/vw_model.cc +++ b/rlclientlib/vw_model/vw_model.cc @@ -67,11 +67,11 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat return error_code::success; } -int vw_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status) +int vw_model::load_action(uint64_t action_id, std::string action_str, api_status* status) { std::lock_guard lock(_mutex); auto vw = _vw_pool.get_or_create(); - vw->add_lru_dedup_cache(hash, action_str); + vw->load_action(action_id, action_str); return error_code::success; } diff --git a/rlclientlib/vw_model/vw_model.h b/rlclientlib/vw_model/vw_model.h index edd4b28c7..a635a0fd8 100644 --- a/rlclientlib/vw_model/vw_model.h +++ b/rlclientlib/vw_model/vw_model.h @@ -29,7 +29,7 @@ class vw_model : public i_model vw_model(i_trace* trace_logger, const utility::configuration& config); int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override; - int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override; + int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version, diff --git a/unit_test/live_model_test.cc b/unit_test/live_model_test.cc index 8e38cb516..fdcd05aa0 100644 --- a/unit_test/live_model_test.cc +++ b/unit_test/live_model_test.cc @@ -494,9 +494,9 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only_ R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"__aid":1},{"__aid":2},{"__aid":3}]})"; // add dedup - BOOST_CHECK_EQUAL(model.add_lru_dedup_cache(1, "{\"TAction\":{\"a1\":\"f1\"}}", &status), err::success); - BOOST_CHECK_EQUAL(model.add_lru_dedup_cache(2, "{\"TAction\":{\"a2\":\"f2\"}}", &status), err::success); - BOOST_CHECK_EQUAL(model.add_lru_dedup_cache(3, "{\"TAction\":{\"a3\":\"f3\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.load_action(1, "{\"TAction\":{\"a1\":\"f1\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.load_action(2, "{\"TAction\":{\"a2\":\"f2\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.load_action(3, "{\"TAction\":{\"a3\":\"f3\"}}", &status), err::success); // request ranking BOOST_CHECK_EQUAL(model.choose_rank(event_id, JSON_CB_CONTEXT_3ACTIONS_DEDUP, response), err::success); From 98f34fdaf439605d172136f1b9f9fdc95c290bdd Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Fri, 10 Mar 2023 10:25:40 -0500 Subject: [PATCH 07/10] comments --- external_parser/CMakeLists.txt | 5 +- external_parser/joiners/example_joiner.h | 2 +- external_parser/joiners/i_joiner.h | 2 +- .../unit_tests/test_lru_dedup_cache.cc | 2 +- rlclientlib/CMakeLists.txt | 5 +- rlclientlib/lru_dedup_cache.cc | 45 ----------------- rlclientlib/lru_dedup_cache.h | 46 ----------------- rlclientlib/vw_model/pdf_model.cc | 2 +- rlclientlib/vw_model/safe_vw.cc | 50 +++++++++---------- rlclientlib/vw_model/safe_vw.h | 24 ++++----- rlclientlib/vw_model/vw_model.cc | 11 ++-- rlclientlib/vw_model/vw_model.h | 2 +- unit_test/safe_vw_test.cc | 18 +++---- 13 files changed, 62 insertions(+), 152 deletions(-) delete mode 100644 rlclientlib/lru_dedup_cache.cc delete mode 100644 rlclientlib/lru_dedup_cache.h diff --git a/external_parser/CMakeLists.txt b/external_parser/CMakeLists.txt index 1a187291d..9b316806a 100644 --- a/external_parser/CMakeLists.txt +++ b/external_parser/CMakeLists.txt @@ -135,7 +135,7 @@ set(binary_parser_headers ${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h ${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h ${CMAKE_CURRENT_LIST_DIR}/log_converter.h - ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.h + ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/lru_dedup_cache.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h @@ -146,7 +146,7 @@ set(binary_parser_sources ${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc ${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc ${CMAKE_CURRENT_LIST_DIR}/log_converter.cc - ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.cc + ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/lru_dedup_cache.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc @@ -161,6 +161,7 @@ target_include_directories(rl_binary_parser ${CMAKE_CURRENT_LIST_DIR}/../ext_libs/zstd/lib/ ${CMAKE_CURRENT_LIST_DIR}/../ext_libs/date/ ) +target_include_directories(rl_binary_parser PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/) # If flatbuffers found via CONFIG, add its target as a library dependency # Otherwise, the flatbuffers MODULE defines FLATBUFFERS_INCLUDE_DIR to add to the include path diff --git a/external_parser/joiners/example_joiner.h b/external_parser/joiners/example_joiner.h index 4b3062aa7..c44346aa4 100644 --- a/external_parser/joiners/example_joiner.h +++ b/external_parser/joiners/example_joiner.h @@ -1,6 +1,6 @@ #pragma once -#include "../rlclientlib/lru_dedup_cache.h" +#include "../rlclientlib/example_cache/lru_dedup_cache.h" #include "event_processors/joined_event.h" #include "event_processors/loop.h" #include "joiners/i_joiner.h" diff --git a/external_parser/joiners/i_joiner.h b/external_parser/joiners/i_joiner.h index dc1b3bd21..fad083684 100644 --- a/external_parser/joiners/i_joiner.h +++ b/external_parser/joiners/i_joiner.h @@ -1,6 +1,6 @@ #pragma once -#include "../rlclientlib/lru_dedup_cache.h" +#include "../rlclientlib/example_cache/lru_dedup_cache.h" #include "event_processors/reward.h" #include "generated/v2/CbEvent_generated.h" #include "generated/v2/FileFormat_generated.h" diff --git a/external_parser/unit_tests/test_lru_dedup_cache.cc b/external_parser/unit_tests/test_lru_dedup_cache.cc index 020231ff1..11973cf4a 100644 --- a/external_parser/unit_tests/test_lru_dedup_cache.cc +++ b/external_parser/unit_tests/test_lru_dedup_cache.cc @@ -1,6 +1,6 @@ #include -#include "../rlclientlib/lru_dedup_cache.h" +#include "../../rlclientlib/example_cache/lru_dedup_cache.h" #include "parse_example_external.h" #include "test_common.h" #include "vw/config/options_cli.h" diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index fa0367e4f..3c53068cc 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -55,6 +55,7 @@ set(PROJECT_SOURCES decision_response.cc dedup.cc error_callback_fn.cc + example_cache/lru_dedup_cache.cc factory_resolver.cc generic_event.cc learning_mode.cc @@ -68,7 +69,6 @@ set(PROJECT_SOURCES logger/logger_facade.cc logger/preamble.cc logger/preamble_sender.cc - lru_dedup_cache.cc model_mgmt/data_callback_fn.cc model_mgmt/empty_data_transport.cc model_mgmt/file_model_loader.cc @@ -143,6 +143,7 @@ set(PROJECT_PUBLIC_HEADERS set(PROJECT_PRIVATE_HEADERS console_tracer.h dedup.h + example_cache/lru_dedup_cache.h federation/federated_client.h federation/joined_log_provider.h generic_event.h @@ -150,7 +151,6 @@ set(PROJECT_PRIVATE_HEADERS logger/async_batcher.h logger/event_logger.h logger/logger_facade.h - lru_dedup_cache.h model_mgmt/data_callback_fn.h model_mgmt/empty_data_transport.h model_mgmt/file_model_loader.h @@ -213,6 +213,7 @@ target_include_directories(rlclientlib ${CMAKE_CURRENT_SOURCE_DIR}/../include PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/example_cache ${CMAKE_CURRENT_SOURCE_DIR}/../ext_libs/date ) diff --git a/rlclientlib/lru_dedup_cache.cc b/rlclientlib/lru_dedup_cache.cc deleted file mode 100644 index 74a46337d..000000000 --- a/rlclientlib/lru_dedup_cache.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include "lru_dedup_cache.h" - -void lru_dedup_cache::add(uint64_t dedup_id, VW::example* ex) -{ - dedup_examples.emplace(dedup_id, ex); - lru.push_front(dedup_id); - lru_pos.emplace(dedup_id, lru.begin()); -} - -void lru_dedup_cache::update(uint64_t dedup_id) -{ - // existing move to front - auto position = lru_pos[dedup_id]; - lru.erase(position); - lru.push_front(dedup_id); - lru_pos[dedup_id] = lru.begin(); -} - -void lru_dedup_cache::clear_after(uint64_t first_id, release_example_f release_example, void* context) -{ - // erase the rest - auto iter = lru_pos[first_id]; - // point to the element right after - iter++; - auto first_pos = iter; - while (iter != lru.end()) - { - auto dedup_id = *iter; - lru_pos.erase(dedup_id); - release_example(context, dedup_examples[dedup_id]); - dedup_examples.erase(dedup_id); - iter++; - } - lru.erase(first_pos, lru.end()); -} - -void lru_dedup_cache::clear(release_example_f release_example, void* context) -{ - for (auto& dedup_item : dedup_examples) { release_example(context, dedup_item.second); } - dedup_examples.clear(); - lru_pos.clear(); - lru.clear(); -} - -bool lru_dedup_cache::exists(uint64_t dedup_id) { return dedup_examples.find(dedup_id) != dedup_examples.end(); } \ No newline at end of file diff --git a/rlclientlib/lru_dedup_cache.h b/rlclientlib/lru_dedup_cache.h deleted file mode 100644 index b51d29479..000000000 --- a/rlclientlib/lru_dedup_cache.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include "vw/core/example.h" - -#include -#include - -/* -LRU dedup cache -When a new dedup payload is deserialized, the examples that were not found in -the new dedup payload will have moved to the end of the lru list. -If we call clear_after with the first example of the new dedup payload we can -assume that anything after that can be evicted as it was not in the new dedup -payload. If two dedup payloads are identical then nothing will be evicted. - -Assumption: dedup payloads are dictionaries and so they have unique items -*/ -struct lru_dedup_cache -{ - // from dictionary id to example object - // right now holding one dedup dictionary at a time, could be exented to a - // map of maps holding more than one dedup dictionaries at a time - std::unordered_map dedup_examples; - std::list lru; - using list_iterator = std::list::iterator; - std::unordered_map lru_pos; - - using release_example_f = void (*)(void*, VW::example*); - static void noop_release_example_f(void*, VW::example*) {} - -public: - void add(uint64_t dedup_id, VW::example* ex); - void update(uint64_t dedup_id); - void clear_after(uint64_t dedup_id, release_example_f release_example = lru_dedup_cache::noop_release_example_f, - void* context = nullptr); - bool exists(uint64_t dedup_id); - void clear(release_example_f release_example = lru_dedup_cache::noop_release_example_f, void* context = nullptr); - std::unordered_map* get_dict() { return &dedup_examples; } - - lru_dedup_cache() = default; - ~lru_dedup_cache() = default; - lru_dedup_cache(const lru_dedup_cache&) = delete; - lru_dedup_cache(lru_dedup_cache&&) = delete; - lru_dedup_cache& operator=(const lru_dedup_cache&) = delete; - lru_dedup_cache& operator=(lru_dedup_cache&&) = delete; -}; \ No newline at end of file diff --git a/rlclientlib/vw_model/pdf_model.cc b/rlclientlib/vw_model/pdf_model.cc index 60e6fb28d..f1ed8c44a 100644 --- a/rlclientlib/vw_model/pdf_model.cc +++ b/rlclientlib/vw_model/pdf_model.cc @@ -13,7 +13,7 @@ namespace model_management // We construct a VW object here to use the example parser to parse joined dsjson-style examples // to extract the PDF. pdf_model::pdf_model(i_trace* trace_logger, const utility::configuration& /*unused*/) - : _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf", nullptr)) + : _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf")) { } diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index 5b9037e69..3fb54dabb 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -19,14 +19,14 @@ namespace reinforcement_learning { static const std::string SEED_TAG = "seed="; -safe_vw::safe_vw(std::shared_ptr master, lru_dedup_cache* dedup_cache) - : _master(std::move(master)), _dedup_cache(dedup_cache) +safe_vw::safe_vw(std::shared_ptr master) + : _master(std::move(master)) { _vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr); init(); } -safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache) +safe_vw::safe_vw(const char* model_data, size_t len) { io_buf buf; buf.add_file(VW::io::create_buffer_view(model_data, len)); @@ -35,8 +35,7 @@ safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cach init(); } -safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache) - : _dedup_cache(dedup_cache) +safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline) { io_buf buf; buf.add_file(VW::io::create_buffer_view(model_data, len)); @@ -45,7 +44,7 @@ safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_comma init(); } -safe_vw::safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache) +safe_vw::safe_vw(const std::string& vw_commandline) { _vw = VW::initialize(vw_commandline); init(); @@ -122,9 +121,8 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector& acti for (auto&& ex : examples) { _example_pool.emplace_back(ex); } } -void safe_vw::load_action(uint64_t action_id, std::string action_str) +void safe_vw::load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache) { - if (_dedup_cache == nullptr) { _dedup_cache = new lru_dedup_cache(); } VW::multi_ex examples; examples.push_back(get_or_create_example()); @@ -137,10 +135,10 @@ void safe_vw::load_action(uint64_t action_id, std::string action_str) { VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); } - _dedup_cache->add(action_id, examples[0]); + action_cache->add(action_id, examples[0]); } -void safe_vw::rank(string_view context, std::vector& actions, std::vector& scores) +void safe_vw::rank(string_view context, std::vector& actions, std::vector& scores, lru_dedup_cache* action_cache) { VW::multi_ex examples; examples.push_back(get_or_create_example()); @@ -148,16 +146,18 @@ void safe_vw::rank(string_view context, std::vector& actions, std::vectordedup_examples; + // check for null if (_vw->audit) { _vw->audit_buffer->clear(); VW::read_line_json_s( - *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); + *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, action_dict); } else { VW::read_line_json_s( - *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict()); + *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, action_dict); } // finalize example @@ -397,30 +397,30 @@ void safe_vw::init() } } -safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache) - : _command_line(std::move(command_line)), _dedup_cache(dedup_cache) +safe_vw_factory::safe_vw_factory(std::string command_line) + : _command_line(std::move(command_line)) { } -safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache) - : _master_data(master_data), _dedup_cache(dedup_cache) +safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) + : _master_data(master_data) { } -safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache) - : _master_data(master_data), _dedup_cache(dedup_cache) +safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) + : _master_data(master_data) { } safe_vw_factory::safe_vw_factory( - const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache) - : _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache) + const model_management::model_data& master_data, std::string command_line) + : _master_data(master_data), _command_line(std::move(command_line)) { } safe_vw_factory::safe_vw_factory( - const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache) - : _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache) + const model_management::model_data&& master_data, std::string command_line) + : _master_data(master_data), _command_line(std::move(command_line)) { } @@ -429,13 +429,13 @@ safe_vw* safe_vw_factory::operator()() if ((_master_data.data() != nullptr) && !_command_line.empty()) { // Construct new vw object from raw model data and command line argument - return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line, _dedup_cache); + return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line); } if (_master_data.data() != nullptr) { // Construct new vw object from raw model data. - return new safe_vw(_master_data.data(), _master_data.data_sz(), _dedup_cache); + return new safe_vw(_master_data.data(), _master_data.data_sz()); } - return new safe_vw(_command_line, _dedup_cache); + return new safe_vw(_command_line); } } // namespace reinforcement_learning diff --git a/rlclientlib/vw_model/safe_vw.h b/rlclientlib/vw_model/safe_vw.h index 89c69d935..318f27bae 100644 --- a/rlclientlib/vw_model/safe_vw.h +++ b/rlclientlib/vw_model/safe_vw.h @@ -15,22 +15,21 @@ class safe_vw std::shared_ptr _master; VW::workspace* _vw; std::vector _example_pool; - lru_dedup_cache* _dedup_cache; VW::example* get_or_create_example(); static VW::example& get_or_create_example_f(void* vw); public: - safe_vw(std::shared_ptr master, lru_dedup_cache* dedup_cache); - safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache); - safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache); - safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache); + safe_vw(std::shared_ptr master); + safe_vw(const char* model_data, size_t len, const std::string& vw_commandline); + safe_vw(const char* model_data, size_t len); + safe_vw(const std::string& vw_commandline); ~safe_vw(); void parse_context_with_pdf(string_view context, std::vector& actions, std::vector& scores); - void load_action(uint64_t action_id, std::string action_str); - void rank(string_view context, std::vector& actions, std::vector& scores); + void load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache); + void rank(string_view context, std::vector& actions, std::vector& scores, lru_dedup_cache* action_cache = nullptr); void choose_continuous_action(string_view context, float& action, float& pdf_value); // Used for CCB void rank_decisions(const std::vector& event_ids, string_view context, @@ -60,17 +59,16 @@ class safe_vw_factory { model_management::model_data _master_data; std::string _command_line; - lru_dedup_cache* _dedup_cache; public: // model_data is copied and stored in the factory object. - safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache); - safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache); - safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache); + safe_vw_factory(std::string command_line); + safe_vw_factory(const model_management::model_data& master_data); + safe_vw_factory(const model_management::model_data&& master_data); safe_vw_factory( - const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache); + const model_management::model_data& master_data, std::string command_line); safe_vw_factory( - const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache); + const model_management::model_data&& master_data, std::string command_line); safe_vw* operator()(); }; diff --git a/rlclientlib/vw_model/vw_model.cc b/rlclientlib/vw_model/vw_model.cc index 68fe79f8e..3a2e818c4 100644 --- a/rlclientlib/vw_model/vw_model.cc +++ b/rlclientlib/vw_model/vw_model.cc @@ -18,7 +18,7 @@ vw_model::vw_model(i_trace* trace_logger, const utility::configuration& config) , _initial_command_line(std::string(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, "--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id N/A")) + (_audit ? " --audit" : "")) - , _vw_pool(safe_vw_factory(_initial_command_line, _dedup_cache), + , _vw_pool(safe_vw_factory(_initial_command_line), config.get_int(name::VW_POOL_INIT_SIZE, value::DEFAULT_VW_POOL_INIT_SIZE), trace_logger) , _trace_logger(trace_logger) { @@ -34,13 +34,13 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat { std::string cmd_line = add_optional_audit_flag(_quiet_commandline_options); - std::unique_ptr init_vw(new safe_vw(data.data(), data.data_sz(), cmd_line, _dedup_cache)); + std::unique_ptr init_vw(new safe_vw(data.data(), data.data_sz(), cmd_line)); if (init_vw->is_CB_to_CCB_model_upgrade(_initial_command_line)) { cmd_line = add_optional_audit_flag(_upgrade_to_CCB_vw_commandline_options); } - safe_vw_factory factory(data, cmd_line, _dedup_cache); + safe_vw_factory factory(data, cmd_line); std::unique_ptr test_vw(factory()); if (test_vw->is_compatible(_initial_command_line)) { @@ -71,7 +71,7 @@ int vw_model::load_action(uint64_t action_id, std::string action_str, api_status { std::lock_guard lock(_mutex); auto vw = _vw_pool.get_or_create(); - vw->load_action(action_id, action_str); + vw->load_action(action_id, action_str, &_action_cache); return error_code::success; } @@ -84,7 +84,7 @@ int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view f // Get a ranked list of action_ids and corresponding pdf std::lock_guard lock(_mutex); - vw->rank(features, action_ids, action_pdf); + vw->rank(features, action_ids, action_pdf, &_action_cache); if (_audit) { write_audit_log(event_id, vw->get_audit_data()); } @@ -142,6 +142,7 @@ int vw_model::request_decision(const std::vector& event_ids, string auto vw = _vw_pool.get_or_create(); // Get a ranked list of action_ids and corresponding pdf + std::lock_guard lock(_mutex); vw->rank_decisions(event_ids, features, actions_ids, action_pdfs); model_version = vw->id(); diff --git a/rlclientlib/vw_model/vw_model.h b/rlclientlib/vw_model/vw_model.h index a635a0fd8..68f8772f1 100644 --- a/rlclientlib/vw_model/vw_model.h +++ b/rlclientlib/vw_model/vw_model.h @@ -52,7 +52,7 @@ class vw_model : public i_model const std::string _quiet_commandline_options{"--json --quiet"}; const std::string _upgrade_to_CCB_vw_commandline_options{"--ccb_explore_adf --json --quiet"}; utility::versioned_object_pool _vw_pool; - lru_dedup_cache* _dedup_cache = nullptr; + lru_dedup_cache _action_cache; std::mutex _mutex; i_trace* _trace_logger; }; diff --git a/unit_test/safe_vw_test.cc b/unit_test/safe_vw_test.cc index 7784674a6..705a3534e 100644 --- a/unit_test/safe_vw_test.cc +++ b/unit_test/safe_vw_test.cc @@ -20,7 +20,7 @@ void get_model_data_from_raw(const char* data, unsigned int len, model_managemen BOOST_AUTO_TEST_CASE(safe_vw_1) { - safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len, nullptr); + safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len); const auto json = R"({"a":{"0":1,"5":2},"_multi":[{"b":{"0":1}},{"b":{"0":2}},{"b":{"0":3}}]})"; std::vector actions; @@ -34,7 +34,7 @@ BOOST_AUTO_TEST_CASE(safe_vw_1) BOOST_AUTO_TEST_CASE(safe_vw_audit_logs) { - safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len, "--json --quiet", nullptr); + safe_vw vw((const char*)cb_data_5_model, cb_data_5_model_len, "--json --quiet"); const auto json = R"({"a":{"0":1,"5":2},"_multi":[{"b":{"0":1}},{"b":{"0":2}},{"b":{"0":3}}]})"; std::vector actions; @@ -43,7 +43,7 @@ BOOST_AUTO_TEST_CASE(safe_vw_audit_logs) BOOST_CHECK_EQUAL(0, vw.get_audit_data().size()); - safe_vw vw_w_audit((const char*)cb_data_5_model, cb_data_5_model_len, "--json --audit", nullptr); + safe_vw vw_w_audit((const char*)cb_data_5_model, cb_data_5_model_len, "--json --audit"); vw_w_audit.rank(json, actions, ranking); BOOST_CHECK_LT(0, vw_w_audit.get_audit_data().size()); @@ -59,7 +59,7 @@ BOOST_AUTO_TEST_CASE(factory_with_cb_model_and_ccb_arguments) model_management::model_data model_data; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &model_data); - const safe_vw_factory factory(model_data, vw_commandLine, nullptr); + const safe_vw_factory factory(model_data, vw_commandLine); versioned_object_pool pool(factory); { @@ -90,7 +90,7 @@ BOOST_AUTO_TEST_CASE(factory_with_initial_model) model_management::model_data model_data; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &model_data); - const safe_vw_factory factory(model_data, nullptr); + const safe_vw_factory factory(model_data); versioned_object_pool pool(factory); { @@ -99,7 +99,7 @@ BOOST_AUTO_TEST_CASE(factory_with_initial_model) // Update factory while an object is floating around model_management::model_data updated_model; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &updated_model); - pool.update_factory(safe_vw_factory(updated_model, nullptr)); + pool.update_factory(safe_vw_factory(updated_model)); std::vector actions; std::vector ranking; @@ -127,14 +127,14 @@ BOOST_AUTO_TEST_CASE(factory_with_empty_model) // Start with empty model data model_management::model_data empty_data; - const safe_vw_factory factory(empty_data, nullptr); + const safe_vw_factory factory(empty_data); versioned_object_pool pool(factory); // Initial model & rank call { model_management::model_data new_model; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &new_model); - pool.update_factory(safe_vw_factory(new_model, nullptr)); + pool.update_factory(safe_vw_factory(new_model)); auto vw = pool.get_or_create(); std::vector actions; @@ -148,7 +148,7 @@ BOOST_AUTO_TEST_CASE(factory_with_empty_model) { model_management::model_data new_model; get_model_data_from_raw((const char*)cb_data_5_model, cb_data_5_model_len, &new_model); - pool.update_factory(safe_vw_factory(new_model, nullptr)); + pool.update_factory(safe_vw_factory(new_model)); auto vw = pool.get_or_create(); std::vector actions; From 7c238ef4610c3875b1e712b17e2ef26aa23f730f Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Fri, 10 Mar 2023 10:26:00 -0500 Subject: [PATCH 08/10] clang --- rlclientlib/vw_model/safe_vw.cc | 27 ++++++++------------------- rlclientlib/vw_model/safe_vw.h | 9 ++++----- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index 3fb54dabb..91e9bfbda 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -19,8 +19,7 @@ namespace reinforcement_learning { static const std::string SEED_TAG = "seed="; -safe_vw::safe_vw(std::shared_ptr master) - : _master(std::move(master)) +safe_vw::safe_vw(std::shared_ptr master) : _master(std::move(master)) { _vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr); init(); @@ -138,7 +137,8 @@ void safe_vw::load_action(uint64_t action_id, std::string action_str, lru_dedup_ action_cache->add(action_id, examples[0]); } -void safe_vw::rank(string_view context, std::vector& actions, std::vector& scores, lru_dedup_cache* action_cache) +void safe_vw::rank( + string_view context, std::vector& actions, std::vector& scores, lru_dedup_cache* action_cache) { VW::multi_ex examples; examples.push_back(get_or_create_example()); @@ -397,29 +397,18 @@ void safe_vw::init() } } -safe_vw_factory::safe_vw_factory(std::string command_line) - : _command_line(std::move(command_line)) -{ -} +safe_vw_factory::safe_vw_factory(std::string command_line) : _command_line(std::move(command_line)) {} -safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) - : _master_data(master_data) -{ -} +safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) : _master_data(master_data) {} -safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) - : _master_data(master_data) -{ -} +safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) : _master_data(master_data) {} -safe_vw_factory::safe_vw_factory( - const model_management::model_data& master_data, std::string command_line) +safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line) : _master_data(master_data), _command_line(std::move(command_line)) { } -safe_vw_factory::safe_vw_factory( - const model_management::model_data&& master_data, std::string command_line) +safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line) : _master_data(master_data), _command_line(std::move(command_line)) { } diff --git a/rlclientlib/vw_model/safe_vw.h b/rlclientlib/vw_model/safe_vw.h index 318f27bae..0405a8994 100644 --- a/rlclientlib/vw_model/safe_vw.h +++ b/rlclientlib/vw_model/safe_vw.h @@ -29,7 +29,8 @@ class safe_vw void parse_context_with_pdf(string_view context, std::vector& actions, std::vector& scores); void load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache); - void rank(string_view context, std::vector& actions, std::vector& scores, lru_dedup_cache* action_cache = nullptr); + void rank(string_view context, std::vector& actions, std::vector& scores, + lru_dedup_cache* action_cache = nullptr); void choose_continuous_action(string_view context, float& action, float& pdf_value); // Used for CCB void rank_decisions(const std::vector& event_ids, string_view context, @@ -65,10 +66,8 @@ class safe_vw_factory safe_vw_factory(std::string command_line); safe_vw_factory(const model_management::model_data& master_data); safe_vw_factory(const model_management::model_data&& master_data); - safe_vw_factory( - const model_management::model_data& master_data, std::string command_line); - safe_vw_factory( - const model_management::model_data&& master_data, std::string command_line); + safe_vw_factory(const model_management::model_data& master_data, std::string command_line); + safe_vw_factory(const model_management::model_data&& master_data, std::string command_line); safe_vw* operator()(); }; From de01d5f9bdb1116eaa590e29690f69baedee9c9c Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Fri, 10 Mar 2023 10:30:10 -0500 Subject: [PATCH 09/10] add files --- rlclientlib/example_cache/lru_dedup_cache.cc | 45 ++++++++++++++++++++ rlclientlib/example_cache/lru_dedup_cache.h | 45 ++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 rlclientlib/example_cache/lru_dedup_cache.cc create mode 100644 rlclientlib/example_cache/lru_dedup_cache.h diff --git a/rlclientlib/example_cache/lru_dedup_cache.cc b/rlclientlib/example_cache/lru_dedup_cache.cc new file mode 100644 index 000000000..74a46337d --- /dev/null +++ b/rlclientlib/example_cache/lru_dedup_cache.cc @@ -0,0 +1,45 @@ +#include "lru_dedup_cache.h" + +void lru_dedup_cache::add(uint64_t dedup_id, VW::example* ex) +{ + dedup_examples.emplace(dedup_id, ex); + lru.push_front(dedup_id); + lru_pos.emplace(dedup_id, lru.begin()); +} + +void lru_dedup_cache::update(uint64_t dedup_id) +{ + // existing move to front + auto position = lru_pos[dedup_id]; + lru.erase(position); + lru.push_front(dedup_id); + lru_pos[dedup_id] = lru.begin(); +} + +void lru_dedup_cache::clear_after(uint64_t first_id, release_example_f release_example, void* context) +{ + // erase the rest + auto iter = lru_pos[first_id]; + // point to the element right after + iter++; + auto first_pos = iter; + while (iter != lru.end()) + { + auto dedup_id = *iter; + lru_pos.erase(dedup_id); + release_example(context, dedup_examples[dedup_id]); + dedup_examples.erase(dedup_id); + iter++; + } + lru.erase(first_pos, lru.end()); +} + +void lru_dedup_cache::clear(release_example_f release_example, void* context) +{ + for (auto& dedup_item : dedup_examples) { release_example(context, dedup_item.second); } + dedup_examples.clear(); + lru_pos.clear(); + lru.clear(); +} + +bool lru_dedup_cache::exists(uint64_t dedup_id) { return dedup_examples.find(dedup_id) != dedup_examples.end(); } \ No newline at end of file diff --git a/rlclientlib/example_cache/lru_dedup_cache.h b/rlclientlib/example_cache/lru_dedup_cache.h new file mode 100644 index 000000000..408dedda6 --- /dev/null +++ b/rlclientlib/example_cache/lru_dedup_cache.h @@ -0,0 +1,45 @@ +#pragma once + +#include "vw/core/example.h" + +#include +#include + +/* +LRU dedup cache +When a new dedup payload is deserialized, the examples that were not found in +the new dedup payload will have moved to the end of the lru list. +If we call clear_after with the first example of the new dedup payload we can +assume that anything after that can be evicted as it was not in the new dedup +payload. If two dedup payloads are identical then nothing will be evicted. + +Assumption: dedup payloads are dictionaries and so they have unique items +*/ +struct lru_dedup_cache +{ + // from dictionary id to example object + // right now holding one dedup dictionary at a time, could be exented to a + // map of maps holding more than one dedup dictionaries at a time + std::unordered_map dedup_examples; + std::list lru; + using list_iterator = std::list::iterator; + std::unordered_map lru_pos; + + using release_example_f = void (*)(void*, VW::example*); + static void noop_release_example_f(void*, VW::example*) {} + +public: + void add(uint64_t dedup_id, VW::example* ex); + void update(uint64_t dedup_id); + void clear_after(uint64_t dedup_id, release_example_f release_example = lru_dedup_cache::noop_release_example_f, + void* context = nullptr); + bool exists(uint64_t dedup_id); + void clear(release_example_f release_example = lru_dedup_cache::noop_release_example_f, void* context = nullptr); + + lru_dedup_cache() = default; + ~lru_dedup_cache() = default; + lru_dedup_cache(const lru_dedup_cache&) = delete; + lru_dedup_cache(lru_dedup_cache&&) = delete; + lru_dedup_cache& operator=(const lru_dedup_cache&) = delete; + lru_dedup_cache& operator=(lru_dedup_cache&&) = delete; +}; \ No newline at end of file From ed4cb6cef038906b7ccf20cd5a9835b2e2e32b7d Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Fri, 10 Mar 2023 12:20:31 -0500 Subject: [PATCH 10/10] fix valgrind --- rlclientlib/vw_model/safe_vw.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index 91e9bfbda..92e9262c7 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -135,6 +135,9 @@ void safe_vw::load_action(uint64_t action_id, std::string action_str, lru_dedup_ VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); } action_cache->add(action_id, examples[0]); + + // clean up examples and push examples back into pool for re-use + for (auto&& ex : examples) { _example_pool.emplace_back(ex); } } void safe_vw::rank(