Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use lru_dedup_dict for rank call #569

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions external_parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ set(binary_parser_headers
${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h
${CMAKE_CURRENT_LIST_DIR}/log_converter.h
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h
Expand All @@ -146,7 +146,7 @@ set(binary_parser_sources
${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/log_converter.cc
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc
Expand All @@ -161,6 +161,7 @@ target_include_directories(rl_binary_parser
${CMAKE_CURRENT_LIST_DIR}/../ext_libs/zstd/lib/
${CMAKE_CURRENT_LIST_DIR}/../ext_libs/date/
)
target_include_directories(rl_binary_parser PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/)

# If flatbuffers found via CONFIG, add its target as a library dependency
# Otherwise, the flatbuffers MODULE defines FLATBUFFERS_INCLUDE_DIR to add to the include path
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/example_joiner.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "../rlclientlib/example_cache/lru_dedup_cache.h"
#include "event_processors/joined_event.h"
#include "event_processors/loop.h"
#include "joiners/i_joiner.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/i_joiner.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#pragma once

#include "../rlclientlib/example_cache/lru_dedup_cache.h"
#include "event_processors/reward.h"
#include "generated/v2/CbEvent_generated.h"
#include "generated/v2/FileFormat_generated.h"
#include "generated/v2/Metadata_generated.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/unit_tests/test_lru_dedup_cache.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <boost/test/unit_test.hpp>

#include "lru_dedup_cache.h"
#include "../../rlclientlib/example_cache/lru_dedup_cache.h"
#include "parse_example_external.h"
#include "test_common.h"
#include "vw/config/options_cli.h"
Expand Down
10 changes: 10 additions & 0 deletions include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ class live_model
*/
int init(api_status* status = nullptr);

/**
* @brief Load dedup cache.
* Load the dedup cache from the specified file. This cache is used to
* prevent duplicate actions from being sent to the online trainer.
* @param action_id Hash of the dedup cache
* @param action_str Action string
* @return int Return error code. This will also be returned in the api_status object
*/
int load_action(uint64_t action_id, std::string action_str, api_status* status);

/**
* @brief Choose an action, given a list of actions, action features and context features. The
* inference library chooses an action by creating a probability distribution over the actions
Expand Down
1 change: 1 addition & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class i_model
{
public:
virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0;
virtual int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) = 0;
virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) = 0;
virtual int choose_continuous_action(string_view features, float& action, float& pdf_value,
Expand Down
3 changes: 3 additions & 0 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ set(PROJECT_SOURCES
decision_response.cc
dedup.cc
error_callback_fn.cc
example_cache/lru_dedup_cache.cc
factory_resolver.cc
generic_event.cc
learning_mode.cc
Expand Down Expand Up @@ -142,6 +143,7 @@ set(PROJECT_PUBLIC_HEADERS
set(PROJECT_PRIVATE_HEADERS
console_tracer.h
dedup.h
example_cache/lru_dedup_cache.h
federation/federated_client.h
federation/joined_log_provider.h
generic_event.h
Expand Down Expand Up @@ -211,6 +213,7 @@ target_include_directories(rlclientlib
${CMAKE_CURRENT_SOURCE_DIR}/../include
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/example_cache
${CMAKE_CURRENT_SOURCE_DIR}/../ext_libs/date
)

Expand Down
3 changes: 3 additions & 0 deletions rlclientlib/extensions/onnx/src/onnx_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ int onnx_model::update(const model_management::model_data& data, bool& model_rea
return error_code::success;
}

// TODO: Implement LRU cache for ONNX models.
int onnx_model::load_action(uint64_t, std::string, api_status*) { return error_code::not_supported; }

int onnx_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/extensions/onnx/src/onnx_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class onnx_model : public model_management::i_model
public:
onnx_model(i_trace* trace_logger, const char* app_id, const char* output_name, bool use_unstructured_input);
int update(const model_management::model_data& data, bool& model_ready, api_status* status = nullptr) override;
int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;

Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ std::vector<int> live_model::c_array_to_vector(const int* c_array, size_t array_
return std::vector<int>(c_array, c_array + array_size);
}

int live_model::load_action(uint64_t action_id, std::string action_str, api_status* status)
{
INIT_CHECK();
return _pimpl->load_action(action_id, std::move(action_str), status);
}

int live_model::choose_rank(
const char* event_id, string_view context_json, ranking_response& response, api_status* status)
{
Expand Down
5 changes: 5 additions & 0 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ int live_model_impl::init(api_status* status)
return error_code::success;
}

int live_model_impl::load_action(uint64_t action_id, std::string action_str, api_status* status)
{
return _model->load_action(action_id, std::move(action_str), status);
}

int live_model_impl::choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class live_model_impl

int init(api_status* status);

int load_action(uint64_t action_id, std::string action_str, api_status* status);
int choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status);
// here the event_id is auto-generated
Expand Down
3 changes: 3 additions & 0 deletions rlclientlib/vw_model/pdf_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta
return error_code::success;
}

// TODO: Implement LRU cache for PDF models.
int pdf_model::load_action(uint64_t, std::string, api_status*) { return error_code::not_supported; }

int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/vw_model/pdf_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class pdf_model : public i_model
public:
pdf_model(i_trace* trace_logger, const utility::configuration& config);
int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override;
int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;
int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version,
Expand Down
34 changes: 31 additions & 3 deletions rlclientlib/vw_model/safe_vw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,48 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector<int>& acti
for (auto&& ex : examples) { _example_pool.emplace_back(ex); }
}

void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<float>& scores)
void safe_vw::load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache)
{
VW::multi_ex examples;
examples.push_back(get_or_create_example());

if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
else
{
VW::read_line_json_s<false>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
action_cache->add(action_id, examples[0]);

// clean up examples and push examples back into pool for re-use
for (auto&& ex : examples) { _example_pool.emplace_back(ex); }
}

void safe_vw::rank(
string_view context, std::vector<int>& actions, std::vector<float>& scores, lru_dedup_cache* action_cache)
{
VW::multi_ex examples;
examples.push_back(get_or_create_example());

// copy due to destructive parsing by rapidjson
std::string line_vec(context);

auto* action_dict = (action_cache == nullptr) ? nullptr : &action_cache->dedup_examples;
// check for null
if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this);
VW::read_line_json_s<true>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, action_dict);
}
else
{
VW::read_line_json_s<false>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, action_dict);
}
else { VW::read_line_json_s<false>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); }

// finalize example
VW::setup_examples(*_vw, examples);
Expand Down
5 changes: 4 additions & 1 deletion rlclientlib/vw_model/safe_vw.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "lru_dedup_cache.h"
#include "model_mgmt.h"
#include "vw/core/vw.h"

Expand Down Expand Up @@ -27,7 +28,9 @@ class safe_vw
~safe_vw();

void parse_context_with_pdf(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void rank(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache);
void rank(string_view context, std::vector<int>& actions, std::vector<float>& scores,
lru_dedup_cache* action_cache = nullptr);
void choose_continuous_action(string_view context, float& action, float& pdf_value);
// Used for CCB
void rank_decisions(const std::vector<const char*>& event_ids, string_view context,
Expand Down
13 changes: 12 additions & 1 deletion rlclientlib/vw_model/vw_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat
return error_code::success;
}

int vw_model::load_action(uint64_t action_id, std::string action_str, api_status* status)
{
std::lock_guard<std::mutex> lock(_mutex);
auto vw = _vw_pool.get_or_create();
vw->load_action(action_id, action_str, &_action_cache);
return error_code::success;
}

int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand All @@ -75,7 +83,8 @@ int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view f
auto vw = _vw_pool.get_or_create();

// Get a ranked list of action_ids and corresponding pdf
vw->rank(features, action_ids, action_pdf);
std::lock_guard<std::mutex> lock(_mutex);
vw->rank(features, action_ids, action_pdf, &_action_cache);

if (_audit) { write_audit_log(event_id, vw->get_audit_data()); }

Expand All @@ -97,6 +106,7 @@ int vw_model::choose_rank_multistep(const char* event_id, uint64_t rnd_seed, str
const episode_history& history, std::vector<int>& action_ids, std::vector<float>& action_pdf,
std::string& model_version, api_status* status)
{
std::lock_guard<std::mutex> lock(_mutex);
return choose_rank(event_id, rnd_seed, features, action_ids, action_pdf, model_version, status);
}

Expand Down Expand Up @@ -132,6 +142,7 @@ int vw_model::request_decision(const std::vector<const char*>& event_ids, string
auto vw = _vw_pool.get_or_create();

// Get a ranked list of action_ids and corresponding pdf
std::lock_guard<std::mutex> lock(_mutex);
vw->rank_decisions(event_ids, features, actions_ids, action_pdfs);

model_version = vw->id();
Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/vw_model/vw_model.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#pragma once
#include "../utility/versioned_object_pool.h"
#include "lru_dedup_cache.h"
#include "model_mgmt.h"
#include "multistep.h"
#include "safe_vw.h"
#include "trace_logger.h"

#include <mutex>

namespace reinforcement_learning
{
namespace utility
Expand All @@ -26,6 +29,7 @@ class vw_model : public i_model
vw_model(i_trace* trace_logger, const utility::configuration& config);

int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override;
int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;
int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version,
Expand All @@ -48,6 +52,8 @@ class vw_model : public i_model
const std::string _quiet_commandline_options{"--json --quiet"};
const std::string _upgrade_to_CCB_vw_commandline_options{"--ccb_explore_adf --json --quiet"};
utility::versioned_object_pool<safe_vw> _vw_pool;
lru_dedup_cache _action_cache;
std::mutex _mutex;
i_trace* _trace_logger;
};
} // namespace model_management
Expand Down
50 changes: 49 additions & 1 deletion unit_test/live_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only)
r::ranking_response response;

const auto JSON_CB_CONTEXT_3ACTIONS =
R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"TAction":{"a1":"f1"} },{"TAction":{"a2":"f2"}},{"TAction":{"a3":"f3"}}]})";
R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"TAction":{"a1":"f1"}},{"TAction":{"a2":"f2"}},{"TAction":{"a3":"f3"}}]})";

// request ranking
BOOST_CHECK_EQUAL(model.choose_rank(event_id, JSON_CB_CONTEXT_3ACTIONS, response), err::success);
Expand All @@ -468,6 +468,54 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only)
}
}

BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only_dedup)
{
// create a simple ds configuration
u::configuration config;
cfg::create_from_json(JSON_CFG, config);
config.set(r::name::EH_TEST, "true");
config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER);
config.set(
r::name::MODEL_VW_INITIAL_COMMAND_LINE, "--cb_explore_adf --json --quiet --epsilon 0.3 --first_only --id N/A");

r::api_status status;

// create the ds live_model, and initialize it with the config
r::live_model model(config);

BOOST_CHECK_EQUAL(model.init(&status), err::success);
const auto event_id = "event_id";

r::ranking_response response;

const auto JSON_CB_CONTEXT_3ACTIONS_DEDUP =
R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"__aid":1},{"__aid":2},{"__aid":3}]})";

// add dedup
BOOST_CHECK_EQUAL(model.load_action(1, "{\"TAction\":{\"a1\":\"f1\"}}", &status), err::success);
BOOST_CHECK_EQUAL(model.load_action(2, "{\"TAction\":{\"a2\":\"f2\"}}", &status), err::success);
BOOST_CHECK_EQUAL(model.load_action(3, "{\"TAction\":{\"a3\":\"f3\"}}", &status), err::success);

// request ranking
BOOST_CHECK_EQUAL(model.choose_rank(event_id, JSON_CB_CONTEXT_3ACTIONS_DEDUP, response), err::success);

size_t num_actions = response.size();
BOOST_CHECK_EQUAL(num_actions, 3);

const float EXPECTED_PROBS[3] = {0.8f, 0.1f, 0.1f};

// check that our PDF is what we expected
r::ranking_response::iterator it = response.begin();

for (uint32_t i = 0; i < num_actions; i++)
{
auto action_probability = *(it + i);
BOOST_CHECK_CLOSE(action_probability.probability, EXPECTED_PROBS[action_probability.action_id], FLOAT_TOL);
}
}

BOOST_AUTO_TEST_CASE(live_model_ranking_w_las_request_check_response_pdf_explore_only)
{
// create a simple ds configuration
Expand Down