Skip to content

Commit

Permalink
feat: bandits
Browse files Browse the repository at this point in the history
  • Loading branch information
rasendubi committed Jul 3, 2024
1 parent c388a2d commit 4169b29
Show file tree
Hide file tree
Showing 18 changed files with 884 additions and 150 deletions.
421 changes: 421 additions & 0 deletions eppo_core/src/bandits/eval.rs

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions eppo_core/src/bandits/event.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

/// Bandit evaluation event that needs to be logged to analytics storage.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditEvent {
pub flag_key: String,
pub bandit_key: String,
pub subject: String,
pub action: String,
pub action_probability: f64,
pub optimality_gap: f64,
pub model_version: String,
pub timestamp: String,
pub subject_numeric_attributes: HashMap<String, f64>,
pub subject_categorical_attributes: HashMap<String, String>,
pub action_numeric_attributes: HashMap<String, f64>,
pub action_categorical_attributes: HashMap<String, String>,
pub meta_data: HashMap<String, String>,
}
7 changes: 7 additions & 0 deletions eppo_core/src/bandits/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod eval;
mod event;
mod models;

pub use eval::BanditResult;
pub use event::BanditEvent;
pub use models::*;
60 changes: 60 additions & 0 deletions eppo_core/src/bandits/models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#![allow(missing_docs)]

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

type Timestamp = chrono::DateTime<chrono::Utc>;

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditResponse {
pub bandits: HashMap<String, BanditConfiguration>,
pub updated_at: Timestamp,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditConfiguration {
pub bandit_key: String,
pub model_name: String,
pub model_version: String,
pub model_data: BanditModelData,
pub updated_at: Timestamp,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditModelData {
pub gamma: f64,
pub default_action_score: f64,
pub action_probability_floor: f64,
pub coefficients: HashMap<String, BanditCoefficients>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditCoefficients {
pub action_key: String,
pub intercept: f64,
pub subject_numeric_coefficients: Vec<BanditNumericAttributeCoefficient>,
pub subject_categorical_coefficients: Vec<BanditCategoricalAttributeCoefficient>,
pub action_numeric_coefficients: Vec<BanditNumericAttributeCoefficient>,
pub action_categorical_coefficients: Vec<BanditCategoricalAttributeCoefficient>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditNumericAttributeCoefficient {
pub attribute_key: String,
pub coefficient: f64,
pub missing_value_coefficient: f64,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct BanditCategoricalAttributeCoefficient {
pub attribute_key: String,
pub value_coefficients: HashMap<String, f64>,
pub missing_value_coefficient: f64,
}
65 changes: 61 additions & 4 deletions eppo_core/src/configuration.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,66 @@
use std::sync::Arc;
use std::collections::HashMap;

use crate::ufc::UniversalFlagConfig;
use crate::{
bandits::{BanditConfiguration, BanditResponse},
ufc::{BanditVariation, UniversalFlagConfig},
};

/// Remote configuration for the eppo client. It's a central piece that defines client behavior.
#[derive(Default, Clone)]
pub struct Configuration {
/// UFC configuration.
pub ufc: Option<Arc<UniversalFlagConfig>>,
/// Flags configuration.
pub flags: Option<UniversalFlagConfig>,
/// Bandits configuration.
pub bandits: Option<BanditResponse>,
/// Mapping from flag key to flag variation value to bandit variation.
pub flag_to_bandit_associations:
HashMap</* flag_key: */ String, HashMap</* variation_key: */ String, BanditVariation>>,
}

impl Configuration {
/// Create a new configuration from server responses.
pub fn new(
config: Option<UniversalFlagConfig>,
bandits: Option<BanditResponse>,
) -> Configuration {
let flag_to_bandit_associations = config
.as_ref()
.map(get_flag_to_bandit_associations)
.unwrap_or_default();
Configuration {
flags: config,
bandits,
flag_to_bandit_associations,
}
}

/// Return a bandit variant for the specified flag key and string flag variation.
pub(crate) fn get_bandit_key<'a>(&'a self, flag_key: &str, variation: &str) -> Option<&'a str> {
self.flag_to_bandit_associations
.get(flag_key)
.and_then(|x| x.get(variation))
.map(|variation| variation.key.as_str())
}

/// Return bandit configuration for the given key.
///
/// Returns `None` if bandits are missing for bandit does not exist.
pub(crate) fn get_bandit<'a>(&'a self, bandit_key: &str) -> Option<&'a BanditConfiguration> {
self.bandits.as_ref()?.bandits.get(bandit_key)
}
}

fn get_flag_to_bandit_associations(
config: &UniversalFlagConfig,
) -> HashMap<String, HashMap<String, BanditVariation>> {
config
.bandits
.iter()
.flat_map(|(_, bandits)| bandits.iter())
.fold(HashMap::new(), |mut acc, variation| {
acc.entry(variation.flag_key.clone())
.or_default()
.insert(variation.variation_value.clone(), variation.clone());
acc
})
}
51 changes: 46 additions & 5 deletions eppo_core/src/configuration_fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use std::sync::Arc;

use reqwest::{StatusCode, Url};

use crate::{ufc::UniversalFlagConfig, Configuration, Error, Result};
use crate::{bandits::BanditResponse, ufc::UniversalFlagConfig, Configuration, Error, Result};

#[derive(Debug, PartialEq, Eq)]
pub struct ConfigurationFetcherConfig {
pub base_url: String,
pub api_key: String,
/// SDK name. Usually, language name.
/// SDK name. (Usually, language name.)
pub sdk_name: String,
/// Version of SDK.
pub sdk_version: String,
Expand All @@ -17,6 +18,7 @@ pub struct ConfigurationFetcherConfig {
pub const DEFAULT_BASE_URL: &'static str = "https://fscdn.eppo.cloud/api";

const UFC_ENDPOINT: &'static str = "/flag-config/v1/config";
const BANDIT_ENDPOINT: &'static str = "/flag-config/v1/bandits";

/// A client that fetches Eppo configuration from the server.
pub struct ConfigurationFetcher {
Expand Down Expand Up @@ -46,9 +48,14 @@ impl ConfigurationFetcher {

let ufc = self.fetch_ufc_configuration()?;

Ok(Configuration {
ufc: Some(Arc::new(ufc)),
})
let bandits = if ufc.bandits.is_empty() {
// We don't need bandits configuration if there are no bandits.
None
} else {
Some(self.fetch_bandits_configuration()?)
};

Ok(Configuration::new(Some(ufc), bandits))
}

fn fetch_ufc_configuration(&mut self) -> Result<UniversalFlagConfig> {
Expand Down Expand Up @@ -84,4 +91,38 @@ impl ConfigurationFetcher {

Ok(configuration)
}

fn fetch_bandits_configuration(&mut self) -> Result<BanditResponse> {
let url = Url::parse_with_params(
&format!("{}{}", self.config.base_url, BANDIT_ENDPOINT),
&[
("apiKey", &*self.config.api_key),
("sdkName", &*self.config.sdk_name),
("sdkVersion", &*self.config.sdk_version),
("coreVersion", env!("CARGO_PKG_VERSION")),
],
)
.map_err(|err| Error::InvalidBaseUrl(err))?;

log::debug!(target: "eppo", "fetching UFC configuration");
let response = self.client.get(url).send()?;

let response = response.error_for_status().map_err(|err| {
if err.status() == Some(StatusCode::UNAUTHORIZED) {
log::warn!(target: "eppo", "client is not authorized. Check your API key");
self.unauthorized = true;
return Error::Unauthorized;
} else {
log::warn!(target: "eppo", "received non-200 response while fetching new configuration: {:?}", err);
return Error::from(err);

}
})?;

let configuration = response.json()?;

log::debug!(target: "eppo", "successfully fetched UFC configuration");

Ok(configuration)
}
}
19 changes: 11 additions & 8 deletions eppo_core/src/configuration_store.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! A thread-safe in-memory storage for currently active configuration. [`ConfigurationStore`]
//! provides a concurrent access for readers (e.g., flag evaluation) and writers (e.g., periodic
//! configuration fetcher).
use std::sync::RwLock;
use std::sync::{Arc, RwLock};

use crate::Configuration;

Expand All @@ -11,15 +11,15 @@ use crate::Configuration;
/// `Configuration` itself is always immutable and can only be replaced fully.
#[derive(Default)]
pub struct ConfigurationStore {
configuration: RwLock<Configuration>,
configuration: RwLock<Arc<Configuration>>,
}

impl ConfigurationStore {
pub fn new() -> Self {
ConfigurationStore::default()
}

pub fn get_configuration(&self) -> Configuration {
pub fn get_configuration(&self) -> Arc<Configuration> {
// self.configuration.read() should always return Ok(). Err() is possible only if the lock
// is poisoned (writer panicked while holding the lock), which should never happen.
let configuration = self
Expand All @@ -32,6 +32,7 @@ impl ConfigurationStore {

/// Set new configuration.
pub fn set_configuration(&self, config: Configuration) {
let config = Arc::new(config);
let mut configuration_slot = self
.configuration
.write()
Expand All @@ -55,15 +56,17 @@ mod tests {
{
let store = store.clone();
let _ = std::thread::spawn(move || {
store.set_configuration(Configuration {
ufc: Some(Arc::new(UniversalFlagConfig {
store.set_configuration(Configuration::new(
Some(UniversalFlagConfig {
flags: HashMap::new(),
})),
});
bandits: HashMap::new(),
}),
None,
))
})
.join();
}

assert!(store.get_configuration().ufc.is_some());
assert!(store.get_configuration().flags.is_some());
}
}
74 changes: 74 additions & 0 deletions eppo_core/src/context_attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{AttributeValue, Attributes};

/// `ContextAttributes` are subject or action attributes split by their semantics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContextAttributes {
/// Numeric attributes are quantitative (e.g., real numbers) and define a scale.
///
/// Not all numbers are numeric attributes. If a number is used to represent an enumeration or
/// on/off values, it is a categorical attribute.
pub numeric: HashMap<String, f64>,
/// Categorical attributes are attributes that have a finite set of values that are not directly
/// comparable (i.e., enumeration).
pub categorical: HashMap<String, String>,
}

impl From<Attributes> for ContextAttributes {
fn from(value: Attributes) -> Self {
ContextAttributes::from_iter(value)
}
}

impl<K, V> FromIterator<(K, V)> for ContextAttributes
where
K: ToOwned<Owned = String>,
V: ToOwned<Owned = AttributeValue>,
{
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
iter.into_iter()
.fold(ContextAttributes::default(), |mut acc, (key, value)| {
match value.to_owned() {
AttributeValue::String(value) => {
acc.categorical.insert(key.to_owned(), value);
}
AttributeValue::Number(value) => {
acc.numeric.insert(key.to_owned(), value);
}
AttributeValue::Boolean(value) => {
// TBD: shall we ignore boolean attributes instead?
//
// One argument for including it here is that this basically guarantees that
// assignment evaluation inside bandit evaluation works the same way as if
// `get_assignment()` was called with generic `Attributes`.
//
// We can go a step further and remove `AttributeValue::Boolean` altogether,
// forcing it to be converted to a string before any evaluation.
acc.categorical.insert(key.to_owned(), value.to_string());
}
AttributeValue::Null => {
// Nulls are missing values and are ignored.
}
}
acc
})
}
}

impl ContextAttributes {
/// Convert contextual attributes to generic `Attributes`.
pub fn to_generic_attributes(&self) -> Attributes {
let mut result = HashMap::with_capacity(self.numeric.len() + self.categorical.capacity());
for (key, value) in self.numeric.iter() {
result.insert(key.clone(), AttributeValue::Number(*value));
}
for (key, value) in self.categorical.iter() {
result.insert(key.clone(), AttributeValue::String(value.clone()));
}
result
}
}
3 changes: 3 additions & 0 deletions eppo_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#![warn(rustdoc::missing_crate_level_docs)]
#![warn(missing_docs)]

pub mod bandits;
pub mod configuration_fetcher;
pub mod configuration_store;
pub mod poller_thread;
Expand All @@ -23,8 +24,10 @@ pub mod ufc;

mod attributes;
mod configuration;
mod context_attributes;
mod error;

pub use attributes::{AttributeValue, Attributes};
pub use configuration::Configuration;
pub use context_attributes::ContextAttributes;
pub use error::{Error, Result};
Loading

0 comments on commit 4169b29

Please sign in to comment.