Skip to content
This repository has been archived by the owner on Jun 21, 2024. It is now read-only.

feat(flags): Match flags on rollout percentage #45

Merged
merged 18 commits into from
Jun 10, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions feature-flags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
serde-pickle = { version = "1.1.1"}
sha1 = "0.10.6"

[lints]
workspace = true
Expand Down
81 changes: 46 additions & 35 deletions feature-flags/src/flag_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use std::sync::Arc;
use tracing::instrument;

Expand All @@ -13,44 +13,30 @@ pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_";

// TODO: Hmm, revisit when dealing with groups, but seems like
// ideal to just treat it as a u8 and do our own validation on top
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize)]
pub enum GroupTypeIndex {}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OperatorType {
#[serde(rename = "exact")]
Exact,
#[serde(rename = "is_not")]
IsNot,
#[serde(rename = "icontains")]
Icontains,
#[serde(rename = "not_icontains")]
NotIcontains,
#[serde(rename = "regex")]
Regex,
#[serde(rename = "not_regex")]
NotRegex,
#[serde(rename = "gt")]
Gt,
#[serde(rename = "lt")]
Lt,
#[serde(rename = "gte")]
Gte,
#[serde(rename = "lte")]
Lte,
#[serde(rename = "is_set")]
IsSet,
#[serde(rename = "is_not_set")]
IsNotSet,
#[serde(rename = "is_date_exact")]
IsDateExact,
#[serde(rename = "is_date_after")]
IsDateAfter,
#[serde(rename = "is_date_before")]
IsDateBefore,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct PropertyFilter {
pub key: String,
pub value: serde_json::Value,
Expand All @@ -60,28 +46,28 @@ pub struct PropertyFilter {
pub group_type_index: Option<u8>,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct FlagGroupType {
pub properties: Option<Vec<PropertyFilter>>,
pub rollout_percentage: Option<f32>,
pub rollout_percentage: Option<f64>,
pub variant: Option<String>,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct MultivariateFlagVariant {
pub key: String,
pub name: Option<String>,
pub rollout_percentage: f32,
pub rollout_percentage: f64,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct MultivariateFlagOptions {
pub variants: Vec<MultivariateFlagVariant>,
}

// TODO: test name with https://www.fileformat.info/info/charset/UTF-16/list.htm values, like '𝖕𝖗𝖔𝖕𝖊𝖗𝖙𝖞': `𝓿𝓪𝓵𝓾𝓮`

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct FlagFilters {
pub groups: Vec<FlagGroupType>,
pub multivariate: Option<MultivariateFlagOptions>,
Expand All @@ -90,7 +76,7 @@ pub struct FlagFilters {
pub super_groups: Option<Vec<FlagGroupType>>,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct FeatureFlag {
pub id: i64,
pub team_id: i64,
Expand All @@ -105,15 +91,31 @@ pub struct FeatureFlag {
pub ensure_experience_continuity: bool,
}

#[derive(Debug, Deserialize, Serialize)]
impl FeatureFlag {
pub fn get_group_type_index(&self) -> Option<u8> {
self.filters.aggregation_group_type_index
}

pub fn get_conditions(&self) -> &Vec<FlagGroupType> {
&self.filters.groups
}

pub fn get_variants(&self) -> Vec<MultivariateFlagVariant> {
self.filters
.multivariate
.clone()
.map_or(vec![], |m| m.variants)
}
}

#[derive(Debug, Deserialize)]

pub struct FeatureFlagList {
pub flags: Vec<FeatureFlag>,
}

impl FeatureFlagList {
/// Returns feature flags given a team_id

/// Returns feature flags from redis given a team_id
#[instrument(skip_all)]
pub async fn from_redis(
client: Arc<dyn Client + Send + Sync>,
Expand All @@ -126,6 +128,8 @@ impl FeatureFlagList {
.map_err(|e| match e {
CustomRedisError::NotFound => FlagError::TokenValidationError,
CustomRedisError::PickleError(_) => {
// TODO: Implement From trait for FlagError so we don't need to map
// CustomRedisError ourselves
tracing::error!("failed to fetch data: {}", e);
println!("failed to fetch data: {}", e);
FlagError::DataParsingError
Expand All @@ -150,8 +154,6 @@ impl FeatureFlagList {

#[cfg(test)]
mod tests {
use rand::Rng;

use super::*;
use crate::test_utils::{
insert_flags_for_team_in_redis, insert_new_team_in_redis, setup_redis_client,
Expand All @@ -161,21 +163,30 @@ mod tests {
async fn test_fetch_flags_from_redis() {
let client = setup_redis_client(None);

let team = insert_new_team_in_redis(client.clone()).await.unwrap();
let team = insert_new_team_in_redis(client.clone())
.await
.expect("Failed to insert team");

insert_flags_for_team_in_redis(client.clone(), team.id, None)
.await
.expect("Failed to insert flags");

let flags_from_redis = FeatureFlagList::from_redis(client.clone(), team.id)
.await
.unwrap();
.expect("Failed to fetch flags from redis");
assert_eq!(flags_from_redis.flags.len(), 1);
let flag = flags_from_redis.flags.get(0).unwrap();
let flag = flags_from_redis.flags.get(0).expect("Empty flags in redis");
assert_eq!(flag.key, "flag1");
assert_eq!(flag.team_id, team.id);
assert_eq!(flag.filters.groups.len(), 1);
assert_eq!(flag.filters.groups[0].properties.as_ref().unwrap().len(), 1);
assert_eq!(
flag.filters.groups[0]
.properties
.as_ref()
.expect("Properties don't exist on flag")
.len(),
1
);
}

#[tokio::test]
Expand Down
161 changes: 161 additions & 0 deletions feature-flags/src/flag_matching.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use crate::flag_definitions::{FeatureFlag, FlagGroupType};
use sha1::{Digest, Sha1};
use std::fmt::Write;

#[derive(Debug, PartialEq, Eq)]
pub struct FeatureFlagMatch {
pub matches: bool,
pub variant: Option<String>,
//reason
//condition_index
//payload
}

// TODO: Rework FeatureFlagMatcher - python has a pretty awkward interface, where we pass in all flags, and then again
// the flag to match. I don't think there's any reason anymore to store the flags in the matcher, since we can just
// pass the flag to match directly to the get_match method. This will also make the matcher more stateless.
// Potentially, we could also make the matcher a long-lived object, with caching for group keys and such.
// It just takes in the flag and distinct_id and returns the match...
// Or, make this fully stateless
// and have a separate cache struct for caching group keys, cohort definitions, etc. - and check size, if we can keep it in memory
// for all teams. If not, we can have a LRU cache, or a cache that stores only the most recent N keys.
// But, this can be a future refactor, for now just focusing on getting the basic matcher working, write lots and lots of tests
// and then we can easily refactor stuff around.
#[derive(Debug)]
pub struct FeatureFlagMatcher {
// pub flags: Vec<FeatureFlag>,
pub distinct_id: String,
}

const LONG_SCALE: u64 = 0xfffffffffffffff;

impl FeatureFlagMatcher {
pub fn new(distinct_id: String) -> Self {
FeatureFlagMatcher {
// flags,
distinct_id,
}
}

pub fn get_match(&self, feature_flag: &FeatureFlag) -> FeatureFlagMatch {
if self.hashed_identifier(feature_flag).is_none() {
return FeatureFlagMatch {
matches: false,
variant: None,
};
}

// TODO: super groups for early access
// TODO: Variant overrides condition sort

for (index, condition) in feature_flag.get_conditions().iter().enumerate() {
let (is_match, _evaluation_reason) =
self.is_condition_match(feature_flag, condition, index);

if is_match {
// TODO: This is a bit awkward, we should handle overrides only when variants exist.
let variant = match condition.variant.clone() {
Some(variant_override) => {
if feature_flag
.get_variants()
.iter()
.any(|v| v.key == variant_override)
{
Some(variant_override)
} else {
self.get_matching_variant(feature_flag)
}
}
None => self.get_matching_variant(feature_flag),
};

// let payload = self.get_matching_payload(is_match, variant, feature_flag);
return FeatureFlagMatch {
matches: true,
variant,
};
}
}
FeatureFlagMatch {
matches: false,
variant: None,
}
}

pub fn is_condition_match(
&self,
feature_flag: &FeatureFlag,
condition: &FlagGroupType,
_index: usize,
) -> (bool, String) {
let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0);
let mut condition_match = true;
if condition.properties.is_some() {
// TODO: Handle matching conditions
if !condition.properties.as_ref().unwrap().is_empty() {
condition_match = false;
}
}

if !condition_match {
return (false, "NO_CONDITION_MATCH".to_string());
} else if rollout_percentage == 100.0 {
// TODO: Check floating point schenanigans if any
return (true, "CONDITION_MATCH".to_string());
}

if self.get_hash(feature_flag, "") > (rollout_percentage / 100.0) {
return (false, "OUT_OF_ROLLOUT_BOUND".to_string());
}

(true, "CONDITION_MATCH".to_string())
}

pub fn hashed_identifier(&self, feature_flag: &FeatureFlag) -> Option<String> {
if feature_flag.get_group_type_index().is_none() {
// TODO: Use hash key overrides for experience continuity
Some(self.distinct_id.clone())
} else {
// TODO: Handle getting group key
Some("".to_string())
}
}

/// This function takes a identifier and a feature flag key and returns a float between 0 and 1.
/// Given the same identifier and key, it'll always return the same float. These floats are
/// uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic
/// we can do _hash(key, identifier) < 0.2
pub fn get_hash(&self, feature_flag: &FeatureFlag, salt: &str) -> f64 {
// check if hashed_identifier is None
let hashed_identifier = self
.hashed_identifier(feature_flag)
.expect("hashed_identifier is None when computing hash");
let hash_key = format!("{}.{}{}", feature_flag.key, hashed_identifier, salt);
let mut hasher = Sha1::new();
hasher.update(hash_key.as_bytes());
let result = hasher.finalize();
// :TRICKY: Convert the first 15 characters of the digest to a hexadecimal string
// not sure if this is correct, padding each byte as 2 characters
let hex_str: String = result.iter().fold(String::new(), |mut acc, byte| {
let _ = write!(acc, "{:02x}", byte);
acc
})[..15]
.to_string();
let hash_val = u64::from_str_radix(&hex_str, 16).unwrap();

hash_val as f64 / LONG_SCALE as f64
}

pub fn get_matching_variant(&self, feature_flag: &FeatureFlag) -> Option<String> {
let hash = self.get_hash(feature_flag, "variant");
let mut total_percentage = 0.0;

for variant in feature_flag.get_variants() {
total_percentage += variant.rollout_percentage / 100.0;
if hash < total_percentage {
return Some(variant.key.clone());
}
}
None
}
}
1 change: 1 addition & 0 deletions feature-flags/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod api;
pub mod config;
pub mod flag_definitions;
pub mod flag_matching;
pub mod redis;
pub mod router;
pub mod server;
Expand Down
1 change: 1 addition & 0 deletions feature-flags/src/team.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl Team {
}
})?;

// TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups
let team: Team = serde_json::from_str(&serialized_team).map_err(|e| {
tracing::error!("failed to parse data to team: {}", e);
FlagError::DataParsingError
Expand Down
Loading
Loading