Skip to content

Commit

Permalink
feat(python): add bandits
Browse files Browse the repository at this point in the history
  • Loading branch information
rasendubi committed Aug 22, 2024
1 parent 676d644 commit b135f45
Show file tree
Hide file tree
Showing 8 changed files with 422 additions and 90 deletions.
64 changes: 64 additions & 0 deletions eppo_core/src/context_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{AttributeValue, Attributes};
/// `ContextAttributes` are subject or action attributes split by their semantics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "pyo3", pyo3::pyclass(module = "eppo_client"))]
pub struct ContextAttributes {
/// Numeric attributes are quantitative (e.g., real numbers) and define a scale.
///
Expand Down Expand Up @@ -73,3 +74,66 @@ impl ContextAttributes {
result
}
}

#[cfg(feature = "pyo3")]
mod pyo3_impl {
use std::collections::HashMap;

use pyo3::prelude::*;

use crate::Attributes;

use super::ContextAttributes;

#[pymethods]
impl ContextAttributes {
#[new]
fn new(
numeric_attributes: HashMap<String, f64>,
categorical_attributes: HashMap<String, String>,
) -> ContextAttributes {
ContextAttributes {
numeric: numeric_attributes,
categorical: categorical_attributes,
}
}

/// Create an empty Attributes instance with no numeric or categorical attributes.
///
/// Returns:
/// ContextAttributes: An instance of the ContextAttributes class with empty dictionaries
/// for numeric and categorical attributes.
#[staticmethod]
fn empty() -> ContextAttributes {
ContextAttributes::default()
}

/// Create an ContextAttributes instance from a dictionary of attributes.
/// Args:
/// attributes (Dict[str, Union[float, int, bool, str]]): A dictionary where keys are attribute names
/// and values are attribute values which can be of type float, int, bool, or str.
/// Returns:
/// ContextAttributes: An instance of the ContextAttributes class
/// with numeric and categorical attributes separated.
#[staticmethod]
fn from_dict(attributes: Attributes) -> ContextAttributes {
attributes.into()
}

/// Note that this copies internal attributes, so changes to returned value won't have
/// effect. This may be mitigated by setting numeric attributes again.
#[getter]
fn get_numeric_attributes(&self, py: Python) -> PyObject {
self.numeric.to_object(py)
}

/// Note that this copies internal attributes, so changes to returned value won't have
/// effect. This may be mitigated by setting categorical attributes again.
#[getter]
fn get_categorical_attributes(&self, py: Python) -> PyObject {
self.categorical.to_object(py)
}
}
}
2 changes: 1 addition & 1 deletion eppo_core/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ mod eval_visitor;
pub mod eval_details;

pub use eval_assignment::{get_assignment, get_assignment_details};
pub use eval_bandits::{get_bandit_action, get_bandit_action_details};
pub use eval_bandits::{get_bandit_action, get_bandit_action_details, BanditResult};
3 changes: 3 additions & 0 deletions python-sdk/python/eppo_client/bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from eppo_client import ContextAttributes, EvaluationResult

BanditResult = EvaluationResult
74 changes: 74 additions & 0 deletions python-sdk/python/tests/test_bandits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
import os
import json
from time import sleep

import eppo_client
from eppo_client.assignment_logger import AssignmentLogger
from eppo_client.bandit import ContextAttributes, BanditResult

TEST_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../../../sdk-test-data/ufc/bandit-tests",
)
test_data = []
for file_name in os.listdir(TEST_DIR):
with open("{}/{}".format(TEST_DIR, file_name)) as test_case_json:
test_case_dict = json.load(test_case_json)
test_case_dict["file_name"] = file_name
test_data.append(test_case_dict)

MOCK_BASE_URL = "http://localhost:8378/"


@pytest.fixture(scope="session", autouse=True)
def init_fixture():
eppo_client.init(
eppo_client.config.Config(
base_url=MOCK_BASE_URL + "bandit/api",
api_key="dummy",
assignment_logger=AssignmentLogger(),
)
)
sleep(0.1) # wait for initialization
yield


@pytest.mark.parametrize("test_case", test_data, ids=lambda x: x["file_name"])
def test_bandit_generic_test_cases(test_case):
client = eppo_client.get_instance()

flag = test_case["flag"]
default_value = test_case["defaultValue"]

for subject in test_case["subjects"]:
result = client.get_bandit_action(
flag,
subject["subjectKey"],
ContextAttributes(
numeric_attributes=subject["subjectAttributes"]["numericAttributes"],
categorical_attributes=subject["subjectAttributes"][
"categoricalAttributes"
],
),
{
action["actionKey"]: ContextAttributes(
action["numericAttributes"], action["categoricalAttributes"]
)
for action in subject["actions"]
},
default_value,
)

expected_result = BanditResult(
subject["assignment"]["variation"], subject["assignment"]["action"]
)

assert result.variation == subject["assignment"]["variation"], (
f"Flag {flag} failed for subject {subject['subjectKey']}:"
f"expected assignment {subject['assignment']['variation']}, got {result.variation}"
)
assert result.action == subject["assignment"]["action"], (
f"Flag {flag} failed for subject {subject['subjectKey']}:"
f"expected action {subject['assignment']['action']}, got {result.action}"
)
75 changes: 75 additions & 0 deletions python-sdk/python/tests/test_context_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest

from eppo_client.bandit import ContextAttributes


def test_init():
ContextAttributes(numeric_attributes={"a": 12}, categorical_attributes={"b": "s"})


def test_init_unnamed():
ContextAttributes({"a": 12}, {"b": "s"})


@pytest.mark.rust_only
def test_type_check():
with pytest.raises(TypeError):
ContextAttributes(
numeric_attributes={"a": "s"}, categorical_attributes={"b": "s"}
)


def test_bool_as_numeric():
attrs = ContextAttributes(
numeric_attributes={"true": True, "false": False}, categorical_attributes={}
)
assert attrs.numeric_attributes == {"true": 1.0, "false": 0.0}


def test_empty():
attrs = ContextAttributes.empty()


def test_from_dict():
attrs = ContextAttributes.from_dict(
{
"numeric1": 1,
"numeric2": 42.3,
"categorical1": "string",
}
)
assert attrs.numeric_attributes == {"numeric1": 1.0, "numeric2": 42.3}
assert attrs.categorical_attributes == {
"categorical1": "string",
}


# `bool` is a subclass of `int` in Python, so it was incorrectly
# captured as numeric attribute:
# https://linear.app/eppo/issue/FF-3106/
@pytest.mark.rust_only
def test_from_dict_bool():
attrs = ContextAttributes.from_dict(
{
"categorical": True,
}
)
assert attrs.numeric_attributes == {}
assert attrs.categorical_attributes == {
"categorical": "true",
}


@pytest.mark.rust_only
def test_does_not_allow_bad_attributes():
with pytest.raises(TypeError):
attrs = ContextAttributes.from_dict({"custom": {"tested": True}})


# In Rust, context attributes live in Rust land and getter returns a
# copy of attributes.
@pytest.mark.rust_only
def test_attributes_are_frozen():
attrs = ContextAttributes.from_dict({"cat": "string"})
attrs.categorical_attributes["cat"] = "dog"
assert attrs.categorical_attributes == {"cat": "string"}
Loading

0 comments on commit b135f45

Please sign in to comment.