Skip to content

Commit

Permalink
Implement resolution for local Vcs
Browse files Browse the repository at this point in the history
  • Loading branch information
bgw committed Jul 29, 2024
1 parent ba7d88f commit 4c2a735
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 41 deletions.
55 changes: 53 additions & 2 deletions crates/turbo-tasks-memory/tests/local_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct Wrapper(u32);
struct TransparentWrapper(u32);

#[tokio::test]
async fn store_and_read() {
async fn test_store_and_read() {
run(&REGISTRATION, async {
let a: Vc<u32> = Vc::local_cell(42);
assert_eq!(*a.await.unwrap(), 42);
Expand All @@ -27,7 +27,7 @@ async fn store_and_read() {
}

#[tokio::test]
async fn store_and_read_generic() {
async fn test_store_and_read_generic() {
run(&REGISTRATION, async {
// `Vc<Vec<Vc<T>>>` is stored as `Vc<Vec<Vc<()>>>` and requires special
// transmute handling
Expand All @@ -43,3 +43,54 @@ async fn store_and_read_generic() {
})
.await
}

#[turbo_tasks::function]
async fn returns_resolved_local_vc() -> Vc<u32> {
Vc::<u32>::local_cell(42).resolve().await.unwrap()
}

#[tokio::test]
async fn test_return_resolved() {
run(&REGISTRATION, async {
assert_eq!(*returns_resolved_local_vc().await.unwrap(), 42);
})
.await
}

#[turbo_tasks::value(eq = "manual")]
#[derive(Default)]
struct Untracked {
#[turbo_tasks(debug_ignore, trace_ignore)]
#[serde(skip)]
cell: Vc<u32>,
}

impl PartialEq for Untracked {
fn eq(&self, other: &Self) -> bool {
std::ptr::eq(self as *const _, other as *const _)
}
}

impl Eq for Untracked {}

#[turbo_tasks::function]
async fn get_untracked_local_cell() -> Vc<Untracked> {
Untracked {
cell: Vc::local_cell(42),
}
.cell()
}

#[tokio::test]
#[should_panic(expected = "Local Vcs must only be accessed within their own task")]
async fn test_panics_on_local_cell_escape() {
run(&REGISTRATION, async {
get_untracked_local_cell()
.await
.unwrap()
.cell
.await
.unwrap();
})
.await
}
45 changes: 44 additions & 1 deletion crates/turbo-tasks/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
manager::TurboTasksBackendApi,
raw_vc::CellId,
registry,
task::shared_reference::TypedSharedReference,
trait_helpers::{get_trait_method, has_trait, traits},
triomphe_utils::unchecked_sidecast_triomphe_arc,
FunctionId, RawVc, ReadRef, SharedReference, TaskId, TaskIdProvider, TaskIdSet, TraitRef,
Expand Down Expand Up @@ -370,7 +371,7 @@ impl TypedCellContent {
.1
.0
.ok_or_else(|| anyhow!("Cell is empty"))?
.typed(self.0);
.into_typed(self.0);
Ok(
// Safety: It is a TypedSharedReference
TraitRef::new(shared_reference),
Expand All @@ -382,12 +383,54 @@ impl TypedCellContent {
}
}

impl From<TypedSharedReference> for TypedCellContent {
fn from(value: TypedSharedReference) -> Self {
TypedCellContent(value.0, CellContent(Some(value.1)))
}
}

impl TryFrom<TypedCellContent> for TypedSharedReference {
type Error = TypedCellContent;

fn try_from(content: TypedCellContent) -> Result<Self, TypedCellContent> {
if let TypedCellContent(type_id, CellContent(Some(shared_reference))) = content {
Ok(TypedSharedReference(type_id, shared_reference))
} else {
Err(content)
}
}
}

impl CellContent {
pub fn into_typed(self, type_id: ValueTypeId) -> TypedCellContent {
TypedCellContent(type_id, self)
}
}

impl From<SharedReference> for CellContent {
fn from(value: SharedReference) -> Self {
CellContent(Some(value))
}
}

impl From<Option<SharedReference>> for CellContent {
fn from(value: Option<SharedReference>) -> Self {
CellContent(value)
}
}

impl TryFrom<CellContent> for SharedReference {
type Error = CellContent;

fn try_from(content: CellContent) -> Result<Self, CellContent> {
if let CellContent(Some(shared_reference)) = content {
Ok(shared_reference)
} else {
Err(content)
}
}
}

pub type TaskCollectiblesMap = AutoMap<RawVc, i32, BuildHasherDefault<FxHasher>, 1>;

pub trait Backend: Sync + Send {
Expand Down
21 changes: 17 additions & 4 deletions crates/turbo-tasks/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use crate::{
magic_any::MagicAny,
raw_vc::{CellId, RawVc},
registry,
task::shared_reference::TypedSharedReference,
trace::TraceRawVcs,
trait_helpers::get_trait_method,
util::StaticOrArc,
Expand Down Expand Up @@ -274,7 +275,7 @@ struct CurrentTaskState {

/// Cells for locally allocated Vcs (`RawVc::LocalCell`). This is freed
/// (along with `CurrentTaskState`) when the task finishes executing.
local_cells: Vec<TypedCellContent>,
local_cells: Vec<TypedSharedReference>,
}

impl CurrentTaskState {
Expand Down Expand Up @@ -1542,7 +1543,12 @@ pub(crate) async fn read_task_cell(
}
}

#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// A reference to a task's cell with methods that allow updating the contents
/// of the cell.
///
/// Mutations should not outside of the task that that owns this cell. Doing so
/// is a logic error, and may lead to incorrect caching behavior.
#[derive(Clone, Copy, Serialize, Deserialize)]
pub struct CurrentCellRef {
current_task: TaskId,
index: CellId,
Expand Down Expand Up @@ -1715,7 +1721,7 @@ pub fn find_cell_by_type(ty: ValueTypeId) -> CurrentCellRef {
})
}

pub(crate) fn create_local_cell(value: TypedCellContent) -> (ExecutionId, LocalCellId) {
pub(crate) fn create_local_cell(value: TypedSharedReference) -> (ExecutionId, LocalCellId) {
CURRENT_TASK_STATE.with(|cell| {
let CurrentTaskState {
execution_id,
Expand All @@ -1738,11 +1744,18 @@ pub(crate) fn create_local_cell(value: TypedCellContent) -> (ExecutionId, LocalC
})
}

/// Returns the contents of the given local cell. Panics if a local cell is
/// attempted to be accessed outside of its task.
///
/// Returns [`TypedSharedReference`] instead of [`TypedCellContent`] because
/// local cells are always filled. The returned value can be cheaply converted
/// with `.into()`.
///
/// Panics if the ExecutionId does not match the expected value.
pub(crate) fn read_local_cell(
execution_id: ExecutionId,
local_cell_id: LocalCellId,
) -> TypedCellContent {
) -> TypedSharedReference {
CURRENT_TASK_STATE.with(|cell| {
let CurrentTaskState {
execution_id: expected_execution_id,
Expand Down
4 changes: 2 additions & 2 deletions crates/turbo-tasks/src/persisted_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct SerializableTaskCell(Option<Option<TypedSharedReference>>);
impl From<SerializableTaskCell> for TaskCell {
fn from(val: SerializableTaskCell) -> Self {
match val.0 {
Some(d) => TaskCell::Content(CellContent(d.map(|d| d.untyped().1))),
Some(d) => TaskCell::Content(d.map(TypedSharedReference::into_untyped).into()),
None => TaskCell::NeedComputation,
}
}
Expand All @@ -56,7 +56,7 @@ impl Serialize for TaskCells {
for (cell_id, cell) in &self.0 {
let task_cell = SerializableTaskCell(match cell {
TaskCell::Content(CellContent(opt)) => {
Some(opt.as_ref().map(|d| d.typed(cell_id.type_id)))
Some(opt.clone().map(|d| d.into_typed(cell_id.type_id)))
}
TaskCell::NeedComputation => None,
});
Expand Down
31 changes: 12 additions & 19 deletions crates/turbo-tasks/src/raw_vc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,26 +164,15 @@ impl RawVc {

/// See [`crate::Vc::resolve`].
pub(crate) async fn resolve(self) -> Result<RawVc> {
let tt = turbo_tasks();
let mut current = self;
let mut notified = false;
loop {
match current {
RawVc::TaskOutput(task) => {
if !notified {
tt.notify_scheduled_tasks();
notified = true;
}
current = read_task_output(&*tt, task, false).await?;
}
RawVc::TaskCell(_, _) => return Ok(current),
RawVc::LocalCell(_, _) => todo!(),
}
}
self.resolve_inner(/* strongly_consistent */ false).await
}

/// See [`crate::Vc::resolve_strongly_consistent`].
pub(crate) async fn resolve_strongly_consistent(self) -> Result<RawVc> {
self.resolve_inner(/* strongly_consistent */ true).await
}

pub(crate) async fn resolve_inner(self, strongly_consistent: bool) -> Result<RawVc> {
let tt = turbo_tasks();
let mut current = self;
let mut notified = false;
Expand All @@ -194,10 +183,14 @@ impl RawVc {
tt.notify_scheduled_tasks();
notified = true;
}
current = read_task_output(&*tt, task, true).await?;
current = read_task_output(&*tt, task, strongly_consistent).await?;
}
RawVc::TaskCell(_, _) => return Ok(current),
RawVc::LocalCell(_, _) => todo!(),
RawVc::LocalCell(execution_id, local_cell_id) => {
let shared_reference = read_local_cell(execution_id, local_cell_id);
let value_type = get_value_type(shared_reference.0);
return Ok((value_type.raw_cell)(shared_reference));
}
}
}
}
Expand Down Expand Up @@ -355,7 +348,7 @@ impl Future for ReadRawVcFuture {
}
}
RawVc::LocalCell(execution_id, local_cell_id) => {
return Poll::Ready(Ok(read_local_cell(execution_id, local_cell_id)));
return Poll::Ready(Ok(read_local_cell(execution_id, local_cell_id).into()));
}
};
// SAFETY: listener is from previous pinned this
Expand Down
2 changes: 1 addition & 1 deletion crates/turbo-tasks/src/read_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ where
};
Vc {
node: <T::CellMode as VcCellMode<T>>::raw_cell(
SharedReference::new(value).typed(type_id),
SharedReference::new(value).into_typed(type_id),
),
_t: PhantomData,
}
Expand Down
21 changes: 17 additions & 4 deletions crates/turbo-tasks/src/task/shared_reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{
any::Any,
fmt::{Debug, Display},
hash::Hash,
ops::Deref,
};

use anyhow::Result;
Expand Down Expand Up @@ -36,14 +37,26 @@ impl SharedReference {
}
}

pub(crate) fn typed(&self, type_id: ValueTypeId) -> TypedSharedReference {
TypedSharedReference(type_id, self.clone())
pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
self.0.downcast_ref()
}

pub fn into_typed(self, type_id: ValueTypeId) -> TypedSharedReference {
TypedSharedReference(type_id, self)
}
}

impl TypedSharedReference {
pub(crate) fn untyped(&self) -> (ValueTypeId, SharedReference) {
(self.0, self.1.clone())
pub fn into_untyped(self) -> SharedReference {
self.1
}
}

impl Deref for TypedSharedReference {
type Target = SharedReference;

fn deref(&self) -> &Self::Target {
&self.1
}
}

Expand Down
Loading

0 comments on commit 4c2a735

Please sign in to comment.