Skip to content
This repository has been archived by the owner on Jun 21, 2024. It is now read-only.

feat(flags): Do token validation and extract distinct id #41

Merged
merged 7 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions feature-flags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ redis = { version = "0.23.3", features = [
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
serde-pickle = { version = "1.1.1"}

[lints]
workspace = true
Expand Down
4 changes: 4 additions & 0 deletions feature-flags/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ pub enum FlagError {
#[error("failed to parse request: {0}")]
RequestParsingError(#[from] serde_json::Error),

#[error("failed to parse redis data: {0}")]
DataParsingError(#[from] serde_pickle::Error),

#[error("Empty distinct_id in request")]
EmptyDistinctId,
#[error("No distinct_id in request")]
Expand All @@ -44,6 +47,7 @@ impl IntoResponse for FlagError {
match self {
FlagError::RequestDecodingError(_)
| FlagError::RequestParsingError(_)
| FlagError::DataParsingError(_)
| FlagError::EmptyDistinctId
| FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()),

Expand Down
2 changes: 1 addition & 1 deletion feature-flags/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use envconfig::Envconfig;

#[derive(Envconfig, Clone)]
pub struct Config {
#[envconfig(default = "127.0.0.1:0")]
#[envconfig(default = "127.0.0.1:3001")]
pub address: SocketAddr,

#[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")]
Expand Down
1 change: 1 addition & 0 deletions feature-flags/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pub mod router;
pub mod server;
pub mod v0_endpoint;
pub mod v0_request;
pub mod team;
54 changes: 25 additions & 29 deletions feature-flags/src/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@ use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use redis::AsyncCommands;
use redis::{AsyncCommands, RedisError};
use tokio::time::timeout;

// average for all commands is <10ms, check grafana
const REDIS_TIMEOUT_MILLISECS: u64 = 10;

/// A simple redis wrapper
/// Copied from capture/src/redis.rs.
/// TODO: Modify this to support hincrby, get, and set commands.
/// TODO: Modify this to support hincrby

#[async_trait]
pub trait Client {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>>;

async fn get(&self, k: String) -> Result<String>;
async fn set(&self, k: String, v: String) -> Result<()>;
}

pub struct RedisClient {
Expand All @@ -40,38 +43,31 @@ impl Client for RedisClient {

Ok(fut?)
}
}

// TODO: Find if there's a better way around this.
#[derive(Clone)]
pub struct MockRedisClient {
zrangebyscore_ret: Vec<String>,
}
async fn get(&self, k: String) -> Result<String> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: because rust exposes all string handling, double check if something very bad will happen if I have a utf-16 character in the team name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eh will come back to this, will remove team name from consideration for now, will handle this on string matching properties.

let mut conn = self.client.get_async_connection().await?;

impl MockRedisClient {
pub fn new() -> MockRedisClient {
MockRedisClient {
zrangebyscore_ret: Vec::new(),
}
}
let results = conn.get(k.clone());
// TODO: Is this safe? Should we be doing something else for error handling here?
let fut: Result<Vec<u8>, RedisError> = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?;

pub fn zrangebyscore_ret(&mut self, ret: Vec<String>) -> Self {
self.zrangebyscore_ret = ret;
// TRICKY: We serialise data to json, then django pickles it.
// Here we deserialize the bytes using serde_pickle, to get the json string.
let string_response: String = serde_pickle::from_slice(&fut?, Default::default())?;

self.clone()
Ok(string_response)
}
}

impl Default for MockRedisClient {
fn default() -> Self {
Self::new()
}
}
async fn set(&self, k: String, v: String) -> Result<()> {
// TRICKY: We serialise data to json, then django pickles it.
// Here we serialize the json string to bytes using serde_pickle.
let bytes = serde_pickle::to_vec(&v, Default::default())?;

#[async_trait]
impl Client for MockRedisClient {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result<Vec<String>> {
Ok(self.zrangebyscore_ret.clone())
let mut conn = self.client.get_async_connection().await?;

let results = conn.set(k, bytes);
let fut = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?;

Ok(fut?)
}
}
}
133 changes: 133 additions & 0 deletions feature-flags/src/team.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::sync::Arc;

use crate::{api::FlagError, redis::Client};

use serde::{Deserialize, Serialize};
use tracing::instrument;


// TRICKY: I'm still not sure where the :1: is coming from.
// The Django prefix is `posthog` only.
// It's from here: https://docs.djangoproject.com/en/4.2/topics/cache/#cache-versioning
// F&!£%% on the bright side we don't use this functionality yet.
// Will rely on integration tests to catch this.
const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:";

// TODO: Check what happens if json has extra stuff, does serde ignore it? Yes
// Make sure we don't serialize and store team data in redis. Let main decide endpoint control this...
// and track misses. Revisit if this becomes an issue.
// because otherwise very annoying to keep this in sync with main django which has a lot of extra fields we need here.
// will lead to inconsistent behaviour.
// This is turning out to be very annoying, because we have django key prefixes to be mindful of as well.
// Wonder if it would be better to make these caches independent? This generates that new problem of CRUD happening in Django,
// which needs to update this cache immediately, so they can't really ever be independent.
// True for both team cache and flags cache. Hmm. Just I guess need to add tests around the key prefixes...
#[derive(Debug, Deserialize, Serialize)]
pub struct Team {
pub id: i64,
pub name: String,
pub api_token: String,
}

impl Team {
/// Validates a token, and returns a team if it exists.
///

#[instrument(skip_all)]
pub async fn from_redis(
client: Arc<dyn Client + Send + Sync>,
token: String,
) -> Result<Team, FlagError> {

// TODO: Instead of failing here, i.e. if not in redis, fallback to pg
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since that dataset is relatively small, an in-process LRU cache would be very useful. We do have it in ingestion's TeamManager for example. Can be added later for sure, but happy to help there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oof tricky, how do you handle invalidations in this case? Team config options / api tokens would get updated out of sync via django. Or is the TTL here so low that it doesn't make a difference in practice 👀 .

Either way, good idea with the low ttl, will look into this!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plugin-server cache has a 2 minute TTL, so tokens are still accepted 2 minutes after rotation, which is good enough in my book. Negative lookups (token is not valid) are cached for 5 minutes too, as it's less probable.

let serialized_team = client
.get(
format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token)
)
.await
.map_err(|e| {
tracing::error!("failed to fetch data: {}", e);
// TODO: Can be other errors if serde_pickle destructuring fails?
FlagError::TokenValidationError
})?;

let team: Team = serde_json::from_str(&serialized_team).map_err(|e| {
tracing::error!("failed to parse data to team: {}", e);
// TODO: Internal error, shouldn't send back to client
FlagError::RequestParsingError(e)
})?;

Ok(team)
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
use anyhow::Error;

use crate::redis::RedisClient;
use rand::{distributions::Alphanumeric, Rng};

use super::*;

fn random_string(prefix: &str, length: usize) -> String {
let suffix: String = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(length)
.map(char::from)
.collect();
format!("{}{}", prefix, suffix)
}

async fn insert_new_team_in_redis(client: Arc<RedisClient>) -> Result<Team, Error> {
let id = rand::thread_rng().gen_range(0..10_000_000);
let token = random_string("phc_", 12);
let team = Team {
id: id,
name: "team".to_string(),
api_token: token,
};

let serialized_team = serde_json::to_string(&team)?;
client
.set(
format!("{TEAM_TOKEN_CACHE_PREFIX}{}", team.api_token.clone()),
serialized_team,
)
.await?;

Ok(team)
}

#[tokio::test]
async fn test_fetch_team_from_redis() {
let client = RedisClient::new("redis://localhost:6379/".to_string())
.expect("Failed to create redis client");
let client = Arc::new(client);

let team = insert_new_team_in_redis(client.clone()).await.unwrap();

let target_token = team.api_token;

let team_from_redis = Team::from_redis(client.clone(), target_token.clone()).await.unwrap();
assert_eq!(
team_from_redis.api_token, target_token
);
assert_eq!(
team_from_redis.id, team.id
);
}

#[tokio::test]
async fn test_fetch_invalid_team_from_redis() {
let client = RedisClient::new("redis://localhost:6379/".to_string())
.expect("Failed to create redis client");
let client = Arc::new(client);

match Team::from_redis(client.clone(), "banana".to_string()).await {
Err(FlagError::TokenValidationError) => (),
_ => panic!("Expected TokenValidationError"),
};
}
}
18 changes: 9 additions & 9 deletions feature-flags/src/v0_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::{
)]
#[debug_handler]
pub async fn flags(
_state: State<router::State>,
state: State<router::State>,
InsecureClientIp(ip): InsecureClientIp,
meta: Query<FlagsQueryParams>,
headers: HeaderMap,
Expand All @@ -59,19 +59,19 @@ pub async fn flags(
.get("content-type")
.map_or("", |v| v.to_str().unwrap_or(""))
{
"application/x-www-form-urlencoded" => {
return Err(FlagError::RequestDecodingError(String::from(
"invalid form data",
)));
"application/json" => {
tracing::Span::current().record("content_type", "application/json");
FlagRequest::from_bytes(body)
}
ct => {
tracing::Span::current().record("content_type", ct);

FlagRequest::from_bytes(body)
return Err(FlagError::RequestDecodingError(format!(
"unsupported content type: {}",
ct
)));
}
}?;

let token = request.extract_and_verify_token()?;
let token = request.extract_and_verify_token(state.redis.clone()).await?;

tracing::Span::current().record("token", &token);

Expand Down
Loading
Loading