diff --git a/.gitignore b/.gitignore index 59a52478..85876ee4 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ mock_server/target *.rock *.nix +/crates/ratings_new/Cargo.lock +/crates/ratings/Cargo.lock + /proto/*.rs venv/ build/ diff --git a/Cargo.lock b/Cargo.lock index 93dde920..721bf36f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2417,15 +2417,20 @@ dependencies = [ name = "ratings_new" version = "0.1.0" dependencies = [ + "anyhow", "dotenvy", "envy", + "futures", + "http 1.1.0", "jsonwebtoken", "prost 0.13.3", "prost-types 0.13.3", + "rand", "reqwest", "secrecy", "serde", "serde_json", + "sha2", "simple_test_case", "sqlx", "strum", @@ -2435,6 +2440,7 @@ dependencies = [ "tonic 0.12.3", "tonic-build", "tonic-reflection 0.12.3", + "tower 0.5.1", "tracing", "tracing-subscriber", ] @@ -3706,6 +3712,16 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.18" @@ -3716,12 +3732,15 @@ dependencies = [ "nu-ansi-term", "once_cell", "regex", + "serde", + "serde_json", "sharded-slab", "smallvec", "thread_local", "tracing", "tracing-core", "tracing-log", + "tracing-serde", ] [[package]] diff --git a/crates/ratings_new/Cargo.lock b/crates/ratings_new/Cargo.lock deleted file mode 100644 index 66194ef8..00000000 --- a/crates/ratings_new/Cargo.lock +++ /dev/null @@ -1,7 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "ratings" -version = "0.1.0" diff --git a/crates/ratings_new/Cargo.toml b/crates/ratings_new/Cargo.toml index c352518b..4ba01871 100644 --- a/crates/ratings_new/Cargo.toml +++ b/crates/ratings_new/Cargo.toml @@ -14,6 +14,7 @@ db_tests = [] [dependencies] dotenvy = "0.15" envy = "0.4" +http = "1.1.0" jsonwebtoken = "9.2" prost = "0.13.3" prost-types = "0.13.3" @@ -28,10 +29,15 @@ time = "0.3" tokio = { version = "1.40.0", features = ["full"] } tonic = "0.12.2" tonic-reflection = "0.12.2" +tower = "0.5.1" tracing = "0.1.40" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json"] } [dev-dependencies] +anyhow = "1.0.93" +futures = "0.3" +rand = "0.8" +sha2 = "0.10" simple_test_case = "1.2.0" [build-dependencies] diff --git a/crates/ratings_new/build.rs b/crates/ratings_new/build.rs index de0e8d87..58fa5da6 100644 --- a/crates/ratings_new/build.rs +++ b/crates/ratings_new/build.rs @@ -11,10 +11,10 @@ fn init_proto() -> Result<(), Box> { ); let files = &[ - "../../proto/ratings_features_app.proto", - "../../proto/ratings_features_chart.proto", - "../../proto/ratings_features_user.proto", - "../../proto/ratings_features_common.proto", + "proto/ratings_features_app.proto", + "proto/ratings_features_chart.proto", + "proto/ratings_features_user.proto", + "proto/ratings_features_common.proto", ]; tonic_build::configure() @@ -27,7 +27,7 @@ fn init_proto() -> Result<(), Box> { "Category", r#"#[strum(serialize_all = "kebab_case", ascii_case_insensitive)]"#, ) - .compile(files, &["../../proto"])?; + .compile(files, &["proto"])?; Ok(()) } diff --git a/crates/ratings_new/proto/ratings_features_app.proto b/crates/ratings_new/proto/ratings_features_app.proto new file mode 100644 index 00000000..ac40cffb --- /dev/null +++ b/crates/ratings_new/proto/ratings_features_app.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package ratings.features.app; + +import "ratings_features_common.proto"; + +service App { + rpc GetRating (GetRatingRequest) returns (GetRatingResponse) {} +} + +message GetRatingRequest { + string snap_id = 1; +} + +message GetRatingResponse { + ratings.features.common.Rating rating = 1; +} diff --git a/crates/ratings_new/proto/ratings_features_chart.proto b/crates/ratings_new/proto/ratings_features_chart.proto new file mode 100644 index 00000000..0c03db6e --- /dev/null +++ b/crates/ratings_new/proto/ratings_features_chart.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package ratings.features.chart; + +import "ratings_features_common.proto"; + +service Chart { + rpc GetChart (GetChartRequest) returns (GetChartResponse) {} +} + +message GetChartRequest { + Timeframe timeframe = 1; + optional Category category = 2; +} + +message GetChartResponse { + Timeframe timeframe = 1; + repeated ChartData ordered_chart_data = 2; + optional Category category = 3; +} + +message ChartData { + float raw_rating = 1; + ratings.features.common.Rating rating = 2; +} + +enum Timeframe { + TIMEFRAME_UNSPECIFIED = 0; + TIMEFRAME_WEEK = 1; + TIMEFRAME_MONTH = 2; +} + +// The categories that can be selected, these +// are taken directly from `curl -sS -X GET --unix-socket /run/snapd.socket "http://localhost/v2/categories"` +// On 2024-02-03, it may need to be kept in sync. +enum Category { + ART_AND_DESIGN = 0; + BOOK_AND_REFERENCE = 1; + DEVELOPMENT = 2; + DEVICES_AND_IOT = 3; + EDUCATION = 4; + ENTERTAINMENT = 5; + FEATURED = 6; + FINANCE = 7; + GAMES = 8; + HEALTH_AND_FITNESS = 9; + MUSIC_AND_AUDIO = 10; + NEWS_AND_WEATHER = 11; + PERSONALISATION = 12; + PHOTO_AND_VIDEO = 13; + PRODUCTIVITY = 14; + SCIENCE = 15; + SECURITY = 16; + SERVER_AND_CLOUD = 17; + SOCIAL = 18; + UTILITIES = 19; +} diff --git a/crates/ratings_new/proto/ratings_features_common.proto b/crates/ratings_new/proto/ratings_features_common.proto new file mode 100644 index 00000000..84a53d5e --- /dev/null +++ b/crates/ratings_new/proto/ratings_features_common.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package ratings.features.common; + +message Rating { + string snap_id = 1; + uint64 total_votes = 2; + RatingsBand ratings_band = 3; +} + +enum RatingsBand { + VERY_GOOD = 0; + GOOD = 1; + NEUTRAL = 2; + POOR = 3; + VERY_POOR = 4; + INSUFFICIENT_VOTES = 5; +} diff --git a/crates/ratings_new/proto/ratings_features_user.proto b/crates/ratings_new/proto/ratings_features_user.proto new file mode 100644 index 00000000..893133cf --- /dev/null +++ b/crates/ratings_new/proto/ratings_features_user.proto @@ -0,0 +1,53 @@ +syntax = "proto3"; + +package ratings.features.user; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +service User { + rpc Authenticate (AuthenticateRequest) returns (AuthenticateResponse) {} + + rpc Delete (google.protobuf.Empty) returns (google.protobuf.Empty) {} + rpc Vote (VoteRequest) returns (google.protobuf.Empty) {} + rpc ListMyVotes (ListMyVotesRequest) returns (ListMyVotesResponse) {} + rpc GetSnapVotes(GetSnapVotesRequest) returns (GetSnapVotesResponse) {} +} + +message AuthenticateRequest { + // sha256([$user:$machineId]) + string id = 1; +} + +message AuthenticateResponse { + string token = 1; +} + +message ListMyVotesRequest { + string snap_id_filter = 1; +} + +message ListMyVotesResponse { + repeated Vote votes = 1; +} + +message GetSnapVotesRequest { + string snap_id = 1; +} + +message GetSnapVotesResponse { + repeated Vote votes = 1; +} + +message Vote { + string snap_id = 1; + int32 snap_revision = 2; + bool vote_up = 3; + google.protobuf.Timestamp timestamp = 4; +} + +message VoteRequest { + string snap_id = 1; + int32 snap_revision = 2; + bool vote_up = 3; +} diff --git a/crates/ratings_new/src/config.rs b/crates/ratings_new/src/config.rs index 18d9bb13..3b4dfb0a 100644 --- a/crates/ratings_new/src/config.rs +++ b/crates/ratings_new/src/config.rs @@ -6,20 +6,14 @@ use serde::Deserialize; /// Configuration for the general app center ratings backend service. #[derive(Deserialize, Debug, Clone)] pub struct Config { - /// Environment variables to use - pub env: String, /// The host configuration pub host: String, - /// The JWT secret value - pub jwt_secret: SecretString, - /// Log level to use - pub log_level: String, - /// The service name - pub name: String, /// The port to run on pub port: u16, /// The URI of the postgres database pub postgres_uri: String, + /// The JWT secret value + pub jwt_secret: SecretString, /// The base URI for snapcraft.io pub snapcraft_io_uri: String, } diff --git a/crates/ratings_new/src/context.rs b/crates/ratings_new/src/context.rs index 2da3842a..375f570d 100644 --- a/crates/ratings_new/src/context.rs +++ b/crates/ratings_new/src/context.rs @@ -1,28 +1,10 @@ //! Application level context & state -use crate::config::Config; -use jsonwebtoken::{EncodingKey, Header}; -use secrecy::{ExposeSecret, SecretString}; -use serde::{Deserialize, Serialize}; +use crate::{ + config::Config, + jwt::{Error, JwtEncoder}, +}; use std::{collections::HashMap, sync::Arc}; -use time::{Duration, OffsetDateTime}; use tokio::sync::{Mutex, Notify}; -use tracing::error; - -/// How many days until JWT info expires -static JWT_EXPIRY_DAYS: i64 = 1; - -/// Errors that can happen while encoding and signing tokens with JWT. -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("jwt: error decoding secret: {0}")] - DecodeSecretError(#[from] jsonwebtoken::errors::Error), - - #[error(transparent)] - Envy(#[from] envy::Error), - - #[error("jwt: an error occurred, but the reason was erased for security reasons")] - Erased, -} pub struct Context { pub config: Config, @@ -33,55 +15,14 @@ pub struct Context { } impl Context { - pub fn new(config: &Config) -> Result { + pub fn new(config: Config) -> Result { + let jwt_encoder = JwtEncoder::from_secret(&config.jwt_secret)?; + Ok(Self { - config: Config::load()?, - jwt_encoder: JwtEncoder::from_secret(&config.jwt_secret)?, + config, + jwt_encoder, http_client: reqwest::Client::new(), category_updates: Default::default(), }) } } - -/// Information representating a claim on a specific subject at a specific time -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Claims { - /// The subject - pub sub: String, - /// The expiration time - pub exp: usize, -} - -impl Claims { - /// Creates a new claim with the current datetime for the subject given by `sub`. - pub fn new(sub: String) -> Self { - let exp = OffsetDateTime::now_utc() + Duration::days(JWT_EXPIRY_DAYS); - let exp = exp.unix_timestamp() as usize; - - Self { sub, exp } - } -} - -pub struct JwtEncoder { - encoding_key: EncodingKey, -} - -impl JwtEncoder { - pub fn from_secret(secret: &SecretString) -> Result { - let encoding_key = EncodingKey::from_base64_secret(secret.expose_secret())?; - - Ok(Self { encoding_key }) - } - - pub fn encode(&self, sub: String) -> Result { - let claims = Claims::new(sub); - - match jsonwebtoken::encode(&Header::default(), &claims, &self.encoding_key) { - Ok(s) => Ok(s), - Err(e) => { - error!("unable to encode jwt: {e}"); - Err(Error::Erased) - } - } - } -} diff --git a/crates/ratings_new/src/db/mod.rs b/crates/ratings_new/src/db/mod.rs index 6df72e3f..edc79637 100644 --- a/crates/ratings_new/src/db/mod.rs +++ b/crates/ratings_new/src/db/mod.rs @@ -1,5 +1,5 @@ use crate::Config; -use sqlx::{postgres::PgPoolOptions, PgPool}; +use sqlx::{postgres::PgPoolOptions, Connection, PgPool}; use thiserror::Error; use tokio::sync::OnceCell; use tracing::info; @@ -75,6 +75,10 @@ pub async fn get_pool() -> Result<&'static PgPool> { Ok(pool) } +pub async fn check_db_conn() -> Result<()> { + conn!().ping().await.map_err(Into::into) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/ratings_new/src/db/vote.rs b/crates/ratings_new/src/db/vote.rs index 699648b1..2fb53a40 100644 --- a/crates/ratings_new/src/db/vote.rs +++ b/crates/ratings_new/src/db/vote.rs @@ -161,9 +161,10 @@ impl VoteSummary { r" WHERE votes.snap_id IN ( SELECT snap_categories.snap_id FROM snap_categories - WHERE snap_categories.category = $1)", + WHERE snap_categories.category = ", ) - .push_bind(category); + .push_bind(category) + .push(")"); } builder.push(" GROUP BY votes.snap_id"); diff --git a/crates/ratings_new/src/grpc/app.rs b/crates/ratings_new/src/grpc/app.rs index efe60705..88c7093c 100644 --- a/crates/ratings_new/src/grpc/app.rs +++ b/crates/ratings_new/src/grpc/app.rs @@ -18,12 +18,8 @@ use tracing::error; pub struct RatingService; impl RatingService { - /// The paths which are accessible without authentication, if any - pub const PUBLIC_PATHS: [&'static str; 0] = []; - - /// Converts this service into its corresponding server - pub fn to_server(self) -> AppServer { - AppServer::new(self) + pub fn new_server() -> AppServer { + AppServer::new(RatingService) } } diff --git a/crates/ratings_new/src/grpc/charts.rs b/crates/ratings_new/src/grpc/charts.rs index e7683795..0fd8d64d 100644 --- a/crates/ratings_new/src/grpc/charts.rs +++ b/crates/ratings_new/src/grpc/charts.rs @@ -17,12 +17,8 @@ use tracing::error; pub struct ChartService; impl ChartService { - /// The paths which are accessible without authentication, if any - pub const PUBLIC_PATHS: [&'static str; 0] = []; - - /// Converts this service into its corresponding server - pub fn to_server(self) -> ChartServer { - ChartServer::new(self) + pub fn new_server() -> ChartServer { + ChartServer::new(ChartService) } } @@ -92,6 +88,16 @@ impl From for PbRating { } } +impl From for Rating { + fn from(r: PbRating) -> Self { + Self { + snap_id: r.snap_id, + total_votes: r.total_votes, + ratings_band: RatingsBand::from_repr(r.ratings_band).unwrap(), + } + } +} + impl From for PbRatingsBand { fn from(rb: RatingsBand) -> Self { match rb { diff --git a/crates/ratings_new/src/grpc/mod.rs b/crates/ratings_new/src/grpc/mod.rs index b438099f..29b2ef66 100644 --- a/crates/ratings_new/src/grpc/mod.rs +++ b/crates/ratings_new/src/grpc/mod.rs @@ -1,12 +1,32 @@ -use crate::db; -use tonic::Status; +use crate::{db, jwt::JwtVerifier, middleware::AuthLayer, Context}; +use std::net::SocketAddr; +use tonic::{transport::Server, Status}; -pub mod app; -pub mod charts; -pub mod user; +mod app; +mod charts; +mod user; + +use app::RatingService; +use charts::ChartService; +use user::UserService; impl From for Status { fn from(value: db::Error) -> Self { Status::internal(value.to_string()) } } + +pub async fn run_server(ctx: Context) -> Result<(), Box> { + let verifier = JwtVerifier::from_secret(&ctx.config.jwt_secret)?; + let addr: SocketAddr = ctx.config.socket().parse()?; + + Server::builder() + .layer(AuthLayer::new(verifier)) + .add_service(RatingService::new_server()) + .add_service(ChartService::new_server()) + .add_service(UserService::new_server(ctx)) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/crates/ratings_new/src/grpc/user.rs b/crates/ratings_new/src/grpc/user.rs index 381d3984..d00203cd 100644 --- a/crates/ratings_new/src/grpc/user.rs +++ b/crates/ratings_new/src/grpc/user.rs @@ -1,7 +1,7 @@ use crate::{ conn, - context::Claims, db::{User, Vote}, + jwt::Claims, proto::user::{ user_server::{self, UserServer}, AuthenticateRequest, AuthenticateResponse, GetSnapVotesRequest, GetSnapVotesResponse, @@ -25,15 +25,8 @@ pub struct UserService { } impl UserService { - /// The paths which are accessible without authentication, if any - pub const PUBLIC_PATHS: [&'static str; 2] = [ - "ratings.features.user.User/Register", - "ratings.features.user.User/Authenticate", - ]; - - /// Converts this service into its corresponding server - pub fn to_server(self) -> UserServer { - UserServer::new(self) + pub fn new_server(ctx: Context) -> UserServer { + UserServer::new(Self { ctx: Arc::new(ctx) }) } } diff --git a/crates/ratings_new/src/jwt.rs b/crates/ratings_new/src/jwt.rs new file mode 100644 index 00000000..8380ba50 --- /dev/null +++ b/crates/ratings_new/src/jwt.rs @@ -0,0 +1,108 @@ +use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; +use secrecy::{ExposeSecret, SecretString}; +use serde::{Deserialize, Serialize}; +use time::{Duration, OffsetDateTime}; +use tonic::Status; +use tracing::error; + +/// How many days until JWT info expires +static JWT_EXPIRY_DAYS: i64 = 1; + +/// Errors that can happen while encoding and signing tokens with JWT. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("jwt: error decoding secret: {0}")] + DecodeSecretError(#[from] jsonwebtoken::errors::Error), + + #[error(transparent)] + Envy(#[from] envy::Error), + + #[error("jwt: an error occurred, but the reason was erased for security reasons")] + Erased, + + #[error("jwt: invalid shape")] + InvalidShape, + + #[error("jwt: invalid authz token")] + InvalidHeader, + + #[error(transparent)] + TonicStatus(#[from] Status), +} + +impl From for Status { + fn from(err: Error) -> Self { + match err { + Error::DecodeSecretError(_) => Status::unauthenticated("invalid JWT token"), + Error::InvalidHeader => Status::unauthenticated("invalid authz header"), + Error::TonicStatus(status) => status, + _ => Status::internal("Internal Server Error"), + } + } +} + +/// Information representating a claim on a specific subject at a specific time +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Claims { + /// The subject + pub sub: String, + /// The expiration time + pub exp: usize, +} + +impl Claims { + /// Creates a new claim with the current datetime for the subject given by `sub`. + pub fn new(sub: String) -> Self { + let exp = OffsetDateTime::now_utc() + Duration::days(JWT_EXPIRY_DAYS); + let exp = exp.unix_timestamp() as usize; + + Self { sub, exp } + } +} + +pub struct JwtEncoder { + encoding_key: EncodingKey, +} + +impl JwtEncoder { + pub fn from_secret(secret: &SecretString) -> Result { + let encoding_key = EncodingKey::from_base64_secret(secret.expose_secret())?; + + Ok(Self { encoding_key }) + } + + pub fn encode(&self, sub: String) -> Result { + let claims = Claims::new(sub); + + match jsonwebtoken::encode(&Header::default(), &claims, &self.encoding_key) { + Ok(s) => Ok(s), + Err(e) => { + error!("unable to encode jwt: {e}"); + Err(Error::Erased) + } + } + } +} + +#[derive(Clone)] +pub struct JwtVerifier { + decoding_key: DecodingKey, +} + +impl JwtVerifier { + /// Creates a new verifier from the given secret. + pub fn from_secret(secret: &SecretString) -> Result { + let decoding_key = DecodingKey::from_base64_secret(secret.expose_secret())?; + + Ok(Self { decoding_key }) + } + + pub fn decode(&self, token: &str) -> Result { + jsonwebtoken::decode::(token, &self.decoding_key, &Validation::default()) + .map(|t| t.claims) + .map_err(|e| { + error!("{e:?}"); + Error::InvalidShape + }) + } +} diff --git a/crates/ratings_new/src/lib.rs b/crates/ratings_new/src/lib.rs index 8407d6a3..5bbe5031 100644 --- a/crates/ratings_new/src/lib.rs +++ b/crates/ratings_new/src/lib.rs @@ -2,6 +2,7 @@ pub mod config; pub mod context; pub mod db; pub mod grpc; +pub mod jwt; pub mod middleware; pub mod proto; pub mod ratings; diff --git a/crates/ratings_new/src/main.rs b/crates/ratings_new/src/main.rs index e7a11a96..0b501ce0 100644 --- a/crates/ratings_new/src/main.rs +++ b/crates/ratings_new/src/main.rs @@ -1,3 +1,32 @@ -fn main() { - println!("Hello, world!"); +use ratings_new::{db::check_db_conn, grpc::run_server, Config, Context}; +use std::io::stdout; +use tracing::{info, subscriber::set_global_default}; +use tracing_subscriber::{EnvFilter, FmtSubscriber}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(stdout) + .json() + .flatten_event(true) + .with_span_list(true) + .with_current_span(false) + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .finish(); + + set_global_default(subscriber).expect("unable to set a global tracing subscriber"); + + info!("loading application context"); + let ctx = Context::new(Config::load()?)?; + + info!("checking DB connectivity"); + check_db_conn().await?; // Ensure that the migrations run before server start + + info!("starting server"); + run_server(ctx).await?; + + Ok(()) } diff --git a/crates/ratings_new/src/middleware.rs b/crates/ratings_new/src/middleware.rs index 8b137891..c26f938d 100644 --- a/crates/ratings_new/src/middleware.rs +++ b/crates/ratings_new/src/middleware.rs @@ -1 +1,103 @@ +//! A custom Tower [Layer] for validating jwt tokens and attaching the decoded claim to incoming +//! requests. +use crate::jwt::JwtVerifier; +use http::{Request, Response}; +use std::{ + error::Error, + future::Future, + mem::replace, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tonic::Status; +use tower::{Layer, Service}; +type BoxFuture = Pin + Send + 'static>>; +type BoxError = Box; + +/// The paths which are accessible without authentication +pub const PUBLIC_PATHS: [&str; 1] = ["ratings.features.user.User/Authenticate"]; + +#[derive(Clone)] +pub struct AuthLayer { + verifier: Arc, +} + +impl AuthLayer { + pub fn new(verifier: JwtVerifier) -> Self { + Self { + verifier: Arc::new(verifier), + } + } +} + +impl Layer for AuthLayer { + type Service = AuthMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + AuthMiddleware { + inner, + verifier: self.verifier.clone(), + } + } +} + +#[derive(Clone)] +pub struct AuthMiddleware { + inner: S, + verifier: Arc, +} + +// Helper for constructing the boxed errors we need to return from the Layer implementation below +macro_rules! unauthenticated { + ($msg:expr) => { + Box::pin(async move { Err(Box::new(Status::unauthenticated($msg)) as BoxError) }) + }; +} + +// The implementation here is based on the example provided by Tonic but with some type aliases and +// simplifying of a few of the generics to tailor things to our use case. +// +// https://github.com/hyperium/tonic/blob/master/examples/src/tower/server.rs +impl Service> for AuthMiddleware +where + S: Service, Response = Response, Error = BoxError> + Clone + Send + 'static, + S::Future: Send + 'static, + T: Send + 'static, +{ + type Response = Response; + type Error = BoxError; + type Future = BoxFuture>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + let clone = self.inner.clone(); + let mut inner = replace(&mut self.inner, clone); + + if !PUBLIC_PATHS.iter().any(|s| req.uri().path().ends_with(s)) { + let header = match req.headers().get("authorization") { + Some(h) => h.to_str().unwrap(), + None => return unauthenticated!("missing auth header"), + }; + + let parts: Vec<&str> = header.split_whitespace().collect(); + if parts.len() != 2 { + return unauthenticated!("malformed auth header"); + } + + match self.verifier.decode(parts[1]) { + Ok(claims) => { + req.extensions_mut().insert(claims); + } + Err(_) => return unauthenticated!("invalid auth header"), + } + } + + Box::pin(async move { inner.call(req).await }) + } +} diff --git a/crates/ratings_new/src/ratings/rating.rs b/crates/ratings_new/src/ratings/rating.rs index edc70f0f..7f62314f 100644 --- a/crates/ratings_new/src/ratings/rating.rs +++ b/crates/ratings_new/src/ratings/rating.rs @@ -6,8 +6,8 @@ const INSUFFICIENT_VOTES_QUANTITY: i64 = 25; /// A descriptive mapping of a number of ratings to a general indicator of "how good" /// an app can be said to be. -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] -#[allow(missing_docs)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, strum::FromRepr)] +#[repr(i32)] pub enum RatingsBand { VeryGood = 0, Good = 1, diff --git a/crates/ratings_new/tests/authentication.rs b/crates/ratings_new/tests/authentication.rs new file mode 100644 index 00000000..94f282aa --- /dev/null +++ b/crates/ratings_new/tests/authentication.rs @@ -0,0 +1,32 @@ +pub mod common; + +use common::TestHelper; +use simple_test_case::test_case; + +#[test_case("notarealhash"; "short")] +#[test_case("abcdefghijkabcdefghijkabcdefghijkabcdefghijkabcdefghijkabcdefgh"; "one char too short")] +#[test_case("abcdefghijkabcdefghijkabcdefghijkabcdefghijkabcdefghijkabcdefghijk"; "one char too long")] +#[tokio::test] +async fn invalid_client_hashes_are_rejected(bad_hash: &str) -> anyhow::Result<()> { + let t = TestHelper::new(); + + let res = t.authenticate(bad_hash.to_string()).await; + assert!(res.is_err(), "{res:?}"); + + Ok(()) +} + +#[tokio::test] +async fn valid_client_hashes_can_authenticate_multiple_times() -> anyhow::Result<()> { + let t = TestHelper::new(); + let client_hash = t.random_sha_256(); + + let token1 = t.authenticate(client_hash.clone()).await?; + let token2 = t.authenticate(client_hash.clone()).await?; + + assert_eq!(token1, token2); + t.assert_valid_jwt(&token1); + t.assert_valid_jwt(&token2); + + Ok(()) +} diff --git a/crates/ratings_new/tests/chart.rs b/crates/ratings_new/tests/chart.rs new file mode 100644 index 00000000..d8b0400c --- /dev/null +++ b/crates/ratings_new/tests/chart.rs @@ -0,0 +1,82 @@ +// NOTE: this is not at all ideal, but in order to get tests around charts to work we need +// to be able to control the set of snaps that are in a given category. If tests start +// failing when you are adding new test cases then double check that you are not +// making use of any of the Categories that the tests in this file rely on. +pub mod common; + +use common::{Category, TestHelper}; +use rand::{thread_rng, Rng}; +use simple_test_case::test_case; + +// !! This test expects to be the only one making use of the "Development" category +#[tokio::test] +async fn category_chart_returns_expected_top_snap() -> anyhow::Result<()> { + let t = TestHelper::new(); + + // Generate a random set of snaps within the given category + for _ in 0..25 { + let client = t.clone(); + let (upvotes, downvotes) = random_votes(25, 50, 15, 35); + client + .test_snap_with_initial_votes(1, upvotes, downvotes, &[Category::Development]) + .await?; + } + + // A snap that should be returned as the top snap for the category + let snap_id = t + .test_snap_with_initial_votes(1, 50, 0, &[Category::Development]) + .await?; + + let user_token = t.authenticate(t.random_sha_256()).await?; + let mut data = t + .get_chart(Some(Category::Development), &user_token) + .await?; + + let top_snap = data[0].rating.take().expect("to have rating for top snap"); + assert_eq!(top_snap.snap_id, snap_id, "{top_snap:?}"); + + Ok(()) +} + +#[test_case(&[(0, 25), (10, 15), (25, 0)], &[2,1,0], Category::DevicesAndIot; "Creation order is reverse rating order")] +#[test_case(&[(27, 0), (25, 0), (26, 0)], &[0,2,1], Category::NewsAndWeather; "More positive votes is weighted higher")] +#[tokio::test] +async fn category_chart_returns_expected_order( + snap_votes: &[(u64, u64)], + expected_order: &[usize], + category: Category, +) -> anyhow::Result<()> { + let t = TestHelper::new(); + let mut ids = Vec::with_capacity(snap_votes.len()); + + for &(upvotes, downvotes) in snap_votes.iter() { + let id = t + .test_snap_with_initial_votes(1, upvotes, downvotes, &[category]) + .await?; + ids.push(id); + } + + let user_token = t.authenticate(t.random_sha_256()).await?; + let data = t.get_chart(Some(category), &user_token).await?; + + let chart_indicies: Vec = data + .into_iter() + .map(|c| { + ids.iter() + .position(|id| id == &c.rating.as_ref().unwrap().snap_id) + .unwrap() + }) + .collect(); + assert_eq!(&chart_indicies, expected_order); + + Ok(()) +} + +fn random_votes(min_vote: usize, max_vote: usize, min_up: usize, max_up: usize) -> (u64, u64) { + let mut rng = thread_rng(); + let upvotes = rng.gen_range(min_up..max_up); + let min_vote = Ord::max(upvotes, min_vote); + let votes = rng.gen_range(min_vote..=max_vote); + + (upvotes as u64, (votes - upvotes) as u64) +} diff --git a/crates/ratings_new/tests/clear-db.sql b/crates/ratings_new/tests/clear-db.sql new file mode 100644 index 00000000..58264efa --- /dev/null +++ b/crates/ratings_new/tests/clear-db.sql @@ -0,0 +1,3 @@ +DELETE FROM snap_categories; +DELETE FROM users; +DELETE FROM votes; diff --git a/crates/ratings_new/tests/common/mod.rs b/crates/ratings_new/tests/common/mod.rs new file mode 100644 index 00000000..8ac7e292 --- /dev/null +++ b/crates/ratings_new/tests/common/mod.rs @@ -0,0 +1,259 @@ +use anyhow::anyhow; +use futures::future::join_all; +use rand::{distributions::Alphanumeric, Rng}; +use ratings_new::{ + jwt::JwtVerifier, + proto::{ + app::{app_client::AppClient, GetRatingRequest}, + chart::{chart_client::ChartClient, ChartData, GetChartRequest, Timeframe}, + user::{ + user_client::UserClient, AuthenticateRequest, GetSnapVotesRequest, Vote, VoteRequest, + }, + }, + ratings::Rating, +}; +use reqwest::Client; +use secrecy::SecretString; +use serde::Deserialize; +use sha2::{Digest, Sha256}; +use std::fmt::Write; +use tonic::{ + metadata::MetadataValue, + transport::{Channel, Endpoint}, + Request, +}; + +// re-export to simplify setting up test data in the test files +pub use ratings_new::db::Category; + +// NOTE: these are set by the 'tests' Makefile target +const MOCK_ADMIN_URL: Option<&str> = option_env!("MOCK_ADMIN_URL"); +const HOST: Option<&str> = option_env!("HOST"); +const PORT: Option<&str> = option_env!("PORT"); + +macro_rules! client { + ($client:ident, $channel:expr, $token:expr) => { + $client::with_interceptor($channel, move |mut req: Request<()>| { + let header: MetadataValue<_> = format!("Bearer {}", $token).parse().unwrap(); + req.metadata_mut().insert("authorization", header); + + Ok(req) + }) + }; +} + +fn rnd_string(len: usize) -> String { + let rng = rand::thread_rng(); + rng.sample_iter(&Alphanumeric) + .take(len) + .map(char::from) + .collect() +} + +#[derive(Debug, Default, Clone)] +pub struct TestHelper { + server_url: String, + mock_admin_url: &'static str, + client: Client, +} + +impl TestHelper { + pub fn new() -> Self { + Self { + server_url: format!( + "http://{}:{}/", + HOST.expect("the integration tests need to be run using make integration-test"), + PORT.expect("the integration tests need to be run using make integration-test") + ), + mock_admin_url: MOCK_ADMIN_URL.unwrap(), + client: Client::new(), + } + } + + pub fn assert_valid_jwt(&self, value: &str) { + dotenvy::dotenv().ok(); + let JwtConfig { jwt_secret } = envy::prefixed("APP_").from_env::().unwrap(); + let verifier = JwtVerifier::from_secret(&jwt_secret).expect("unable to init JwtVerifier"); + + assert!( + verifier.decode(value).is_ok(), + "value should be a valid jwt" + ); + + // serde structs + #[derive(Deserialize)] + struct JwtConfig { + jwt_secret: SecretString, + } + } + + /// NOTE: total needs to be above 25 in order to generate a rating + pub async fn test_snap_with_initial_votes( + &self, + revision: i32, + upvotes: u64, + downvotes: u64, + categories: &[Category], + ) -> anyhow::Result { + let snap_id = self.random_id(); + let str_categories: Vec = categories.iter().map(|c| c.to_string()).collect(); + self.client + .post(format!("{}/{snap_id}", self.mock_admin_url)) + .body(str_categories.join(",")) + .send() + .await?; + + if upvotes > 0 { + self.generate_votes(&snap_id, revision, true, upvotes) + .await?; + } + if downvotes > 0 { + self.generate_votes(&snap_id, revision, false, downvotes) + .await?; + } + + Ok(snap_id) + } + + pub fn random_sha_256(&self) -> String { + let data = rnd_string(100); + let mut hasher = Sha256::new(); + hasher.update(data); + + hasher + .finalize() + .iter() + .fold(String::new(), |mut output, b| { + // This ignores the error without the overhead of unwrap/expect, + // This is okay because writing to a string can't fail (barring OOM which won't happen) + let _ = write!(output, "{b:02x}"); + output + }) + } + + pub fn random_id(&self) -> String { + rnd_string(32) + } + + async fn register_and_vote( + &self, + snap_id: &str, + snap_revision: i32, + vote_up: bool, + ) -> anyhow::Result<()> { + let id: String = self.random_sha_256(); + let token = self.authenticate(id.clone()).await?; + self.vote(snap_id, snap_revision, vote_up, &token).await?; + + Ok(()) + } + + pub async fn generate_votes( + &self, + snap_id: &str, + snap_revision: i32, + vote_up: bool, + count: u64, + ) -> anyhow::Result<()> { + let mut tasks = Vec::with_capacity(count as usize); + + for _ in 0..count { + let snap_id = snap_id.to_string(); + let client = self.clone(); + + tasks.push(tokio::spawn(async move { + client + .register_and_vote(&snap_id, snap_revision, vote_up) + .await + })); + } + + for res in join_all(tasks).await { + // Unwrapping twice as the join itself can error as well as the + // underlying call to register_and_vote. + // This is here so that tests panic in test generation if there + // are any issues rather than carrying on with malformed data + res.unwrap().unwrap(); + } + + Ok(()) + } + + async fn channel(&self) -> Channel { + Endpoint::from_shared(self.server_url.clone()) + .expect("failed to create Endpoint") + .connect() + .await + .expect("failed to connect") + } + + pub async fn get_rating(&self, id: &str, token: &str) -> anyhow::Result { + let resp = client!(AppClient, self.channel().await, token) + .get_rating(GetRatingRequest { + snap_id: id.to_string(), + }) + .await? + .into_inner(); + + resp.rating + .map(Into::into) + .ok_or(anyhow!("no rating for {id}")) + } + + pub async fn get_chart( + &self, + category: Option, + token: &str, + ) -> anyhow::Result> { + let resp = client!(ChartClient, self.channel().await, token) + .get_chart(GetChartRequest { + timeframe: Timeframe::Unspecified.into(), + category: category.map(|v| v as i32), + }) + .await? + .into_inner(); + + Ok(resp.ordered_chart_data) + } + + pub async fn vote( + &self, + snap_id: &str, + snap_revision: i32, + vote_up: bool, + token: &str, + ) -> anyhow::Result<()> { + client!(UserClient, self.channel().await, token) + .vote(VoteRequest { + snap_id: snap_id.to_string(), + snap_revision, + vote_up, + }) + .await?; + + Ok(()) + } + + pub async fn get_snap_votes( + &self, + token: &str, + request: GetSnapVotesRequest, + ) -> anyhow::Result> { + let resp = client!(UserClient, self.channel().await, token) + .get_snap_votes(request) + .await? + .into_inner(); + + Ok(resp.votes) + } + + pub async fn authenticate(&self, id: String) -> anyhow::Result { + let resp = UserClient::connect(self.server_url.clone()) + .await? + .authenticate(AuthenticateRequest { id }) + .await? + .into_inner(); + + Ok(resp.token) + } +} diff --git a/crates/ratings_new/tests/voting.rs b/crates/ratings_new/tests/voting.rs new file mode 100644 index 00000000..2827ef62 --- /dev/null +++ b/crates/ratings_new/tests/voting.rs @@ -0,0 +1,100 @@ +pub mod common; + +use common::{Category, TestHelper}; +use ratings_new::ratings::RatingsBand::{self, *}; +use simple_test_case::test_case; + +#[test_case(true; "up vote")] +#[test_case(false; "down vote")] +#[tokio::test] +async fn voting_increases_vote_count(vote_up: bool) -> anyhow::Result<()> { + let t = TestHelper::new(); + + let user_token = t.authenticate(t.random_sha_256()).await?; + let snap_revision = 1; + let snap_id = t + .test_snap_with_initial_votes(snap_revision, 3, 2, &[Category::Social]) + .await?; + + let initial_rating = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(initial_rating.total_votes, 5, "initial total votes"); + + // Vote with a user who has not previously voted for this snap + t.vote(&snap_id, snap_revision, vote_up, &user_token) + .await?; + + let rating = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(rating.total_votes, 6, "total votes: vote_up={vote_up}"); + + Ok(()) +} + +#[test_case(true; "up to down vote")] +#[test_case(false; "down to up vote")] +#[tokio::test] +async fn changing_your_vote_doesnt_alter_total(initial_up: bool) -> anyhow::Result<()> { + let t = TestHelper::new(); + + let user_token = t.authenticate(t.random_sha_256()).await?; + let snap_revision = 1; + let snap_id = t + .test_snap_with_initial_votes(snap_revision, 3, 2, &[Category::Social]) + .await?; + + let initial_rating = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(initial_rating.total_votes, 5, "initial total votes"); + + // Vote with a user who has not previously voted for this snap + t.vote(&snap_id, snap_revision, initial_up, &user_token) + .await?; + + let rating = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(rating.total_votes, 6, "total votes"); + + // That user changing their vote shouldn't alter the total + t.vote(&snap_id, snap_revision, !initial_up, &user_token) + .await?; + + let rating = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(rating.total_votes, 6, "total votes"); + + Ok(()) +} + +// The ratings bands details are found in ../src/features/common/entities.rs and the following +// test expects the break points for the value of the confidence interval: +// +// 0.80 < r - VeryGood +// 0.55 < r <= 0.80 - Good +// 0.45 < r <= 0.55 - Neutral +// 0.20 < r <= 0.45 - Poor +// r <= 0.20 - VeryPoor +// +// NOTE: In order to generate a rating we need to have at least 25 votes +#[test_case(true, Neutral, Good; "neutral to good")] +#[test_case(false, Neutral, Poor; "neutral to poor")] +#[tokio::test] +async fn voting_updates_ratings_band( + vote_up: bool, + initial_band: RatingsBand, + new_band: RatingsBand, +) -> anyhow::Result<()> { + let t = TestHelper::new(); + + let user_token = t.authenticate(t.random_sha_256()).await?; + let snap_revision = 1; + let snap_id = t + .test_snap_with_initial_votes(snap_revision, 60, 40, &[Category::Games]) + .await?; + + let r = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(r.ratings_band, initial_band, "initial band"); + + t.generate_votes(&snap_id, snap_revision, vote_up, 50) + .await?; + + let r = t.get_rating(&snap_id, &user_token).await?; + assert_eq!(r.ratings_band, new_band, "new band"); + + Ok(()) +} diff --git a/docker-compose.yml b/docker-compose.yml index 70d9d21e..35ee70c2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,7 +23,7 @@ services: ports: - 8080:8080 environment: - RUST_LOG: "info,hyper=error" + RUST_LOG: "ratings_new=debug,hyper=error" APP_LOG_LEVEL: "info" APP_ENV: "dev" APP_HOST: "0.0.0.0" @@ -36,7 +36,8 @@ services: APP_ADMIN_USER: "shadow" APP_ADMIN_PASSWORD: "maria" volumes: - - .:/app + # - .:/app + - ./crates/ratings_new:/app - cargo-cache:/usr/local/cargo/registry - target-cache:/app/target entrypoint: "cargo watch -i 'tests/**' -x run"