Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Per ingredient sync table #650

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::{any::Any, fmt, ptr::NonNull};
use crate::{
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
cycle::{CycleRecoveryAction, CycleRecoveryStrategy},
function::sync::{ClaimResult, SyncTable},
ingredient::fmt_index,
key::DatabaseKeyIndex,
plumbing::MemoIngredientMap,
salsa_struct::SalsaStructInDb,
table::sync::ClaimResult,
table::Table,
views::DatabaseDownCaster,
zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
Expand All @@ -32,6 +32,7 @@ mod lru;
mod maybe_changed_after;
mod memo;
mod specify;
mod sync;

pub trait Configuration: Any {
const DEBUG_NAME: &'static str;
Expand Down Expand Up @@ -119,6 +120,8 @@ pub struct IngredientImpl<C: Configuration> {
/// instances that this downcaster was derived from.
view_caster: DatabaseDownCaster<C::DbView>,

sync_table: SyncTable,

/// When `fetch` and friends executes, they return a reference to the
/// value stored in the memo that is extended to live as long as the `&self`
/// reference we start with. This means that whenever we remove something
Expand Down Expand Up @@ -157,6 +160,7 @@ where
lru: lru::Lru::new(lru),
deleted_entries: Default::default(),
view_caster,
sync_table: SyncTable::new(index),
}
}

Expand Down Expand Up @@ -252,12 +256,7 @@ where
/// Attempts to claim `key_index`, returning `false` if a cycle occurs.
fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool {
let zalsa = db.zalsa();
match zalsa.sync_table_for(key_index).claim(
db,
zalsa,
self.database_key_index(key_index),
self.memo_ingredient_index(zalsa, key_index),
) {
match self.sync_table.try_claim(db, zalsa, key_index) {
ClaimResult::Retry | ClaimResult::Claimed(_) => true,
ClaimResult::Cycle => false,
}
Expand Down
16 changes: 5 additions & 11 deletions src/function/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::{memo::Memo, Configuration, IngredientImpl, VerifyResult};
use crate::function::sync::ClaimResult;
use crate::zalsa::MemoIngredientIndex;
use crate::{
accumulator::accumulated_map::InputAccumulatedValues,
runtime::StampedValue,
table::sync::ClaimResult,
zalsa::{Zalsa, ZalsaDatabase},
zalsa_local::QueryRevisions,
AsDynDatabase as _, Id,
Expand Down Expand Up @@ -103,17 +103,11 @@ where
id: Id,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<&'db Memo<C::Output<'db>>> {
let database_key_index = self.database_key_index(id);

// Try to claim this query: if someone else has claimed it already, go back and start again.
let _claim_guard = match zalsa.sync_table_for(id).claim(
db,
zalsa,
database_key_index,
memo_ingredient_index,
) {
let _claim_guard = match self.sync_table.try_claim(db, zalsa, id) {
ClaimResult::Retry => return None,
ClaimResult::Cycle => {
let database_key_index = self.database_key_index(id);
// check if there's a provisional value for this query
// Note we don't `validate_may_be_provisional` the memo here as we want to reuse an
// existing provisional memo if it exists
Expand All @@ -129,7 +123,7 @@ where
}
// no provisional value; create/insert/return initial provisional value
return self
.initial_value(db, database_key_index.key_index())
.initial_value(db, id)
.map(|initial_value| {
tracing::debug!(
"hit cycle at {database_key_index:#?}, \
Expand Down Expand Up @@ -160,7 +154,7 @@ where
};

// Push the query on the stack.
let active_query = db.zalsa_local().push_query(database_key_index);
let active_query = db.zalsa_local().push_query(self.database_key_index(id));

// Now that we've claimed the item, check again to see if there's a "hot" value.
let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
Expand Down
9 changes: 2 additions & 7 deletions src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
accumulator::accumulated_map::InputAccumulatedValues,
cycle::{CycleHeads, CycleRecoveryStrategy},
function::sync::ClaimResult,
key::DatabaseKeyIndex,
table::sync::ClaimResult,
zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase},
zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin},
AsDynDatabase as _, Id, Revision,
Expand Down Expand Up @@ -98,12 +98,7 @@ where
) -> Option<VerifyResult> {
let database_key_index = self.database_key_index(key_index);

let _claim_guard = match zalsa.sync_table_for(key_index).claim(
db,
zalsa,
database_key_index,
memo_ingredient_index,
) {
let _claim_guard = match self.sync_table.try_claim(db, zalsa, key_index) {
ClaimResult::Retry => return None,
ClaimResult::Cycle => match C::CYCLE_STRATEGY {
CycleRecoveryStrategy::Panic => panic!(
Expand Down
98 changes: 50 additions & 48 deletions src/table/sync.rs → src/function/sync.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
use std::thread::ThreadId;

use parking_lot::Mutex;
use rustc_hash::FxHashMap;

use crate::{
key::DatabaseKeyIndex,
runtime::{BlockResult, WaitResult},
zalsa::{MemoIngredientIndex, Zalsa},
Database,
zalsa::Zalsa,
Database, Id, IngredientIndex,
};

use super::util;

/// Tracks the keys that are currently being processed; used to coordinate between
/// worker threads.
#[derive(Default)]
pub(crate) struct SyncTable {
syncs: Mutex<Vec<Option<SyncState>>>,
syncs: Mutex<FxHashMap<Id, SyncState>>,
ingredient: IngredientIndex,
}

pub(crate) enum ClaimResult<'a> {
Retry,
Cycle,
Claimed(ClaimGuard<'a>),
}

struct SyncState {
Expand All @@ -26,59 +31,56 @@ struct SyncState {
anyone_waiting: bool,
}

pub(crate) enum ClaimResult<'a> {
Retry,
Cycle,
Claimed(ClaimGuard<'a>),
}

impl SyncTable {
#[inline]
pub(crate) fn claim<'me>(
pub(crate) fn new(ingredient: IngredientIndex) -> Self {
Self {
syncs: Default::default(),
ingredient,
}
}

pub(crate) fn try_claim<'me>(
&'me self,
db: &'me (impl ?Sized + Database),
zalsa: &'me Zalsa,
database_key_index: DatabaseKeyIndex,
memo_ingredient_index: MemoIngredientIndex,
key_index: Id,
) -> ClaimResult<'me> {
let mut syncs = self.syncs.lock();
let thread_id = std::thread::current().id();

util::ensure_vec_len(&mut syncs, memo_ingredient_index.as_usize() + 1);

match &mut syncs[memo_ingredient_index.as_usize()] {
None => {
syncs[memo_ingredient_index.as_usize()] = Some(SyncState {
id: thread_id,
anyone_waiting: false,
});
ClaimResult::Claimed(ClaimGuard {
database_key_index,
memo_ingredient_index,
zalsa,
sync_table: self,
_padding: false,
})
}
Some(SyncState {
id: other_id,
anyone_waiting,
}) => {
let mut write = self.syncs.lock();
match write.entry(key_index) {
std::collections::hash_map::Entry::Occupied(occupied_entry) => {
let &mut SyncState {
id,
ref mut anyone_waiting,
} = occupied_entry.into_mut();
// NB: `Ordering::Relaxed` is sufficient here,
// as there are no loads that are "gated" on this
// value. Everything that is written is also protected
// by a lock that must be acquired. The role of this
// boolean is to decide *whether* to acquire the lock,
// not to gate future atomic reads.
*anyone_waiting = true;
match zalsa
.runtime()
.block_on(db, database_key_index, *other_id, syncs)
{
match zalsa.runtime().block_on(
db,
DatabaseKeyIndex::new(self.ingredient, key_index),
id,
write,
) {
BlockResult::Completed => ClaimResult::Retry,
BlockResult::Cycle => ClaimResult::Cycle,
}
}
std::collections::hash_map::Entry::Vacant(vacant_entry) => {
vacant_entry.insert(SyncState {
id: std::thread::current().id(),
anyone_waiting: false,
});
ClaimResult::Claimed(ClaimGuard {
key_index,
zalsa,
sync_table: self,
_padding: false,
})
}
}
}
}
Expand All @@ -87,8 +89,7 @@ impl SyncTable {
/// released when this value is dropped.
#[must_use]
pub(crate) struct ClaimGuard<'me> {
database_key_index: DatabaseKeyIndex,
memo_ingredient_index: MemoIngredientIndex,
key_index: Id,
zalsa: &'me Zalsa,
sync_table: &'me SyncTable,
// Reduce the size of ClaimResult by making more niches available in ClaimGuard; this fits into
Expand All @@ -100,12 +101,13 @@ impl ClaimGuard<'_> {
fn remove_from_map_and_unblock_queries(&self) {
let mut syncs = self.sync_table.syncs.lock();

let SyncState { anyone_waiting, .. } =
syncs[self.memo_ingredient_index.as_usize()].take().unwrap();
let SyncState { anyone_waiting, .. } = syncs.remove(&self.key_index).unwrap();

drop(syncs);

if anyone_waiting {
self.zalsa.runtime().unblock_queries_blocked_on(
self.database_key_index,
DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index),
if std::thread::panicking() {
WaitResult::Panicked
} else {
Expand Down
11 changes: 1 addition & 10 deletions src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
input::singleton::{Singleton, SingletonChoice},
key::DatabaseKeyIndex,
plumbing::{Jar, Stamp},
table::{memo::MemoTable, sync::SyncTable, Slot, Table},
table::{memo::MemoTable, Slot, Table},
zalsa::{IngredientIndex, Zalsa},
zalsa_local::QueryOrigin,
Database, Durability, Id, Revision, Runtime,
Expand Down Expand Up @@ -107,7 +107,6 @@ impl<C: Configuration> IngredientImpl<C> {
fields,
stamps,
memos: Default::default(),
syncs: Default::default(),
})
});

Expand Down Expand Up @@ -286,9 +285,6 @@ where

/// Memos
memos: MemoTable,

/// Syncs
syncs: SyncTable,
}

impl<C> Value<C>
Expand Down Expand Up @@ -322,9 +318,4 @@ where
fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable {
&mut self.memos
}

#[inline]
unsafe fn syncs(&self, _current_revision: Revision) -> &SyncTable {
&self.syncs
}
}
8 changes: 0 additions & 8 deletions src/interned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::ingredient::fmt_index;
use crate::plumbing::{IngredientIndices, Jar};
use crate::revision::AtomicRevision;
use crate::table::memo::MemoTable;
use crate::table::sync::SyncTable;
use crate::table::Slot;
use crate::zalsa::{IngredientIndex, Zalsa};
use crate::zalsa_local::QueryOrigin;
Expand Down Expand Up @@ -73,7 +72,6 @@ where
{
fields: C::Fields<'static>,
memos: MemoTable,
syncs: SyncTable,

/// The revision the value was first interned in.
first_interned_at: Revision,
Expand Down Expand Up @@ -293,7 +291,6 @@ where
let id = zalsa_local.allocate(table, self.ingredient_index, |id| Value::<C> {
fields: unsafe { self.to_internal_data(assemble(id, key)) },
memos: Default::default(),
syncs: Default::default(),
durability: AtomicU8::new(durability.as_u8()),
// Record the revision we are interning in.
first_interned_at: current_revision,
Expand Down Expand Up @@ -481,11 +478,6 @@ where
fn memos_mut(&mut self) -> &mut MemoTable {
&mut self.memos
}

#[inline]
unsafe fn syncs(&self, _current_revision: Revision) -> &crate::table::sync::SyncTable {
&self.syncs
}
}

/// A trait for types that hash and compare like `O`.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl Runtime {
other_id: ThreadId,
query_mutex_guard: QueryMutexGuard,
) -> BlockResult {
let mut dg = self.dependency_graph.lock();
let dg = self.dependency_graph.lock();
let thread_id = std::thread::current().id();

if dg.depends_on(other_id, thread_id) {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/dependency_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl DependencyGraph {
/// True if `from_id` depends on `to_id`.
///
/// (i.e., there is a path from `from_id` to `to_id` in the graph.)
pub(super) fn depends_on(&mut self, from_id: ThreadId, to_id: ThreadId) -> bool {
pub(super) fn depends_on(&self, from_id: ThreadId, to_id: ThreadId) -> bool {
let mut p = from_id;
while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) {
if q == to_id {
Expand Down
Loading