Skip to content

Commit

Permalink
fix: first pass at fixing the behaviour of update_category (now updat…
Browse files Browse the repository at this point in the history
…e_categories)
  • Loading branch information
sminez committed Oct 11, 2024
1 parent 8518eac commit 9622a7e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 27 deletions.
93 changes: 72 additions & 21 deletions src/features/user/infrastructure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use snapd::{
api::{convenience::SnapNameFromId, find::FindSnapByName},
SnapdClient,
};
use sqlx::{Acquire, Executor, Row};
use sqlx::{Postgres, QueryBuilder, Row};
use std::sync::Arc;
use tokio::sync::Notify;
use tracing::error;

use crate::{
Expand Down Expand Up @@ -200,8 +202,15 @@ async fn snapd_categories_by_snap_id(
.collect())
}

/// Update the category (we do this every time we get a vote for the time being)
pub(crate) async fn update_category(app_ctx: &AppContext, snap_id: &str) -> Result<(), UserError> {
/// Update the categories for a given snap.
///
/// In the case where we do not have categories, we need to fetch them and store them in the DB.
/// This is racey without coordination so we check to see if any other tasks are currently attempting
/// this and block on them completing if they are, if not then we set up the Notify and they block on us.
pub(crate) async fn update_categories(
app_ctx: &AppContext,
snap_id: &str,
) -> Result<(), UserError> {
let mut pool = app_ctx
.infrastructure()
.repository()
Expand All @@ -211,30 +220,72 @@ pub(crate) async fn update_category(app_ctx: &AppContext, snap_id: &str) -> Resu
UserError::Unknown
})?;

let snapd_client = &app_ctx.infrastructure().snapd_client;
// Take the mutex first so we don't race between checking the current table state and updating
let mut guard = app_ctx.infrastructure().category_updates.lock().await;

let (n_rows,): (i64,) =
sqlx::query_as("SELECT COUNT(*) FROM snap_categories WHERE snap_id = $1;")
.bind(snap_id)
.fetch_one(&mut *pool)
.await
.map_err(|error| {
error!("{error:?}");
UserError::FailedToCastVote
})?;

// If we have categories for the requested snap in place already then skip updating.
// Eventually we will need to update and refresh categories over time but the assumption for now is
// that snap categories do not change frequently so we do not need to eagerly update them.
if n_rows > 0 {
return Ok(());
}

let (notifier, should_wait) = match guard.get(&snap_id.to_string()) {
Some(notifier) => (notifier.clone(), true),
None => (Arc::new(Notify::new()), false),
};

if should_wait {
// Another task is updating the categories for this snap so wait for it to complete and then
// return: https://docs.rs/tokio/latest/tokio/sync/struct.Notify.html#method.notified
drop(guard);
notifier.notified().await;
return Ok(());
}

// At this point we can release the mutex for other calls to update_categories to proceed while
// we update the DB state for the snap_id we are interested in. Any calls between now and when
// we complete the update will block on the notifier we insert here.
guard.insert(snap_id.to_string(), notifier.clone());
drop(guard);

let snapd_client = &app_ctx.infrastructure().snapd_client;
let categories = snapd_categories_by_snap_id(snapd_client, snap_id).await?;

// Do a transaction because bulk querying doesn't seem to work cleanly
let mut tx = pool.begin().await?;
// The trailing space after the query here is important as the builder will append directly to
// the string provided.
let mut query_builder: QueryBuilder<Postgres> =
QueryBuilder::new("INSERT INTO snap_categories(snap_id, category) ");

// Reset the categories since we're refreshing all of them
tx.execute(
sqlx::query("DELETE FROM snap_categories WHERE snap_categories.snap_id = $1;")
.bind(snap_id),
)
.await?;
query_builder.push_values(categories, |mut b, category| {
b.push_bind(snap_id).push_bind(category);
});

for category in categories.iter() {
tx.execute(
sqlx::query("INSERT INTO snap_categories (snap_id, category) VALUES ($1, $2); ")
.bind(snap_id)
.bind(category),
)
.await?;
}
query_builder
.build()
.execute(&mut *pool)
.await
.map_err(|error| {
error!("{error:?}");
UserError::FailedToCastVote
})?;

// Grab the mutex around the category_updates so any incoming tasks block behind us and then
// notify all blocked tasks before removing the Notify from the map.
let mut guard = app_ctx.infrastructure().category_updates.lock().await;
notifier.notify_waiters();
guard.remove(&snap_id.to_string());

tx.commit().await?;
Ok(())
}

Expand Down
12 changes: 7 additions & 5 deletions src/features/user/use_cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
},
};

use super::infrastructure::update_category;
use super::infrastructure::update_categories;

/// Create a [`User`] entry, or note that the user has recently been seen, within the current
/// [`AppContext`].
Expand All @@ -36,11 +36,12 @@ pub async fn delete_user(app_ctx: &AppContext, client_hash: &str) -> Result<(),
#[allow(unused_must_use)]
pub async fn vote(app_ctx: &AppContext, vote: Vote) -> Result<(), UserError> {
// Ignore but log warning, it's not fatal
update_category(app_ctx, &vote.snap_id)
update_categories(app_ctx, &vote.snap_id)
.await
.inspect_err(|e| warn!("{}", e));
let result = save_vote_to_db(app_ctx, vote).await;
result?;

save_vote_to_db(app_ctx, vote).await?;

Ok(())
}

Expand All @@ -54,9 +55,10 @@ pub async fn get_snap_votes(
client_hash: String,
) -> Result<Vec<Vote>, UserError> {
// Ignore but log warning, it's not fatal
update_category(app_ctx, &snap_id)
update_categories(app_ctx, &snap_id)
.await
.inspect_err(|e| warn!("{}", e));

get_snap_votes_by_client_hash(app_ctx, snap_id, client_hash).await
}

Expand Down
7 changes: 6 additions & 1 deletion src/utils/infrastructure.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
//! Utilities and structs for creating server infrastructure (database, etc).
use std::{
collections::HashMap,
error::Error,
fmt::{Debug, Formatter},
sync::Arc,
};

use snapd::SnapdClient;
use sqlx::{pool::PoolConnection, postgres::PgPoolOptions, PgPool, Postgres};
use tokio::sync::OnceCell;
use tokio::sync::{Mutex, Notify, OnceCell};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{reload::Handle, Registry};

Expand All @@ -29,6 +30,9 @@ pub struct Infrastructure {
pub log_reload_handle: &'static Handle<LevelFilter, Registry>,
/// The utility which lets us encode user tokens with our JWT credentials
pub jwt_encoder: Arc<JwtEncoder>,
/// In progress category updates that we need to block on
/// FIXME: The logic for this should really live here but it's all DB related.
pub category_updates: Arc<Mutex<HashMap<String, Arc<Notify>>>>,
}

impl Infrastructure {
Expand All @@ -52,6 +56,7 @@ impl Infrastructure {
jwt_encoder,
snapd_client: Default::default(),
log_reload_handle: reload_handle,
category_updates: Default::default(),
})
}

Expand Down

0 comments on commit 9622a7e

Please sign in to comment.