diff --git a/turbopack/crates/turbo-tasks-memory/tests/local_cell.rs b/turbopack/crates/turbo-tasks-memory/tests/local_cell.rs index 9274c789c6404a..7ca3fa7103d94e 100644 --- a/turbopack/crates/turbo-tasks-memory/tests/local_cell.rs +++ b/turbopack/crates/turbo-tasks-memory/tests/local_cell.rs @@ -1,6 +1,6 @@ #![feature(arbitrary_self_types)] -use turbo_tasks::Vc; +use turbo_tasks::{debug::ValueDebug, test_helpers::current_task_for_testing, ValueDefault, Vc}; use turbo_tasks_testing::{register, run, Registration}; static REGISTRATION: Registration = register!(); @@ -57,6 +57,57 @@ async fn test_return_resolved() { .await } +#[turbo_tasks::value_trait] +trait UnimplementedTrait {} + +#[tokio::test] +async fn test_try_resolve_sidecast() { + run(®ISTRATION, async { + let trait_vc: Vc> = Vc::upcast(Vc::::local_cell(42)); + + // `u32` is both a `ValueDebug` and a `ValueDefault`, so this sidecast is valid + let sidecast_vc = Vc::try_resolve_sidecast::>(trait_vc) + .await + .unwrap(); + assert!(sidecast_vc.is_some()); + + // `u32` is not an `UnimplementedTrait` though, so this should return None + let wrongly_sidecast_vc = Vc::try_resolve_sidecast::>(trait_vc) + .await + .unwrap(); + assert!(wrongly_sidecast_vc.is_none()); + }) + .await +} + +#[tokio::test] +async fn test_try_resolve_downcast_type() { + run(®ISTRATION, async { + let trait_vc: Vc> = Vc::upcast(Vc::::local_cell(42)); + + let downcast_vc: Vc = Vc::try_resolve_downcast_type(trait_vc) + .await + .unwrap() + .unwrap(); + assert_eq!(*downcast_vc.await.unwrap(), 42); + + let wrongly_downcast_vc: Option> = + Vc::try_resolve_downcast_type(trait_vc).await.unwrap(); + assert!(wrongly_downcast_vc.is_none()); + }) + .await +} + +#[tokio::test] +async fn test_get_task_id() { + run(®ISTRATION, async { + // the task id as reported by the RawVc + let vc_task_id = Vc::into_raw(Vc::<()>::local_cell(())).get_task_id(); + assert_eq!(vc_task_id, current_task_for_testing()); + }) + .await +} + #[turbo_tasks::value(eq = "manual")] #[derive(Default)] struct Untracked { @@ -83,7 +134,7 @@ async fn get_untracked_local_cell() -> Vc { #[tokio::test] #[should_panic(expected = "Local Vcs must only be accessed within their own task")] -async fn test_panics_on_local_cell_escape() { +async fn test_panics_on_local_cell_escape_read() { run(®ISTRATION, async { get_untracked_local_cell() .await @@ -94,3 +145,12 @@ async fn test_panics_on_local_cell_escape() { }) .await } + +#[tokio::test] +#[should_panic(expected = "Local Vcs must only be accessed within their own task")] +async fn test_panics_on_local_cell_escape_get_task_id() { + run(®ISTRATION, async { + Vc::into_raw(get_untracked_local_cell().await.unwrap().cell).get_task_id(); + }) + .await +} diff --git a/turbopack/crates/turbo-tasks/src/manager.rs b/turbopack/crates/turbo-tasks/src/manager.rs index 3feea84080492f..80d2a18c3e30a3 100644 --- a/turbopack/crates/turbo-tasks/src/manager.rs +++ b/turbopack/crates/turbo-tasks/src/manager.rs @@ -1755,7 +1755,8 @@ pub(crate) fn create_local_cell(value: TypedSharedReference) -> (ExecutionId, Lo /// local cells are always filled. The returned value can be cheaply converted /// with `.into()`. /// -/// Panics if the ExecutionId does not match the expected value. +/// Panics if the [`ExecutionId`] does not match the current task's +/// `execution_id`. pub(crate) fn read_local_cell( execution_id: ExecutionId, local_cell_id: LocalCellId, @@ -1766,12 +1767,28 @@ pub(crate) fn read_local_cell( local_cells, .. } = &*cell.borrow(); - assert_eq!( - execution_id, *expected_execution_id, - "This Vc is local. Local Vcs must only be accessed within their own task. Resolve the \ - Vc to convert it into a non-local version." - ); + assert_eq_local_cell(execution_id, *expected_execution_id); // local cell ids are one-indexed (they use NonZeroU32) local_cells[(*local_cell_id as usize) - 1].clone() }) } + +/// Panics if the [`ExecutionId`] does not match the current task's +/// `execution_id`. +pub(crate) fn assert_execution_id(execution_id: ExecutionId) { + CURRENT_TASK_STATE.with(|cell| { + let CurrentTaskState { + execution_id: expected_execution_id, + .. + } = &*cell.borrow(); + assert_eq_local_cell(execution_id, *expected_execution_id); + }) +} + +fn assert_eq_local_cell(actual: ExecutionId, expected: ExecutionId) { + assert_eq!( + actual, expected, + "This Vc is local. Local Vcs must only be accessed within their own task. Resolve the Vc \ + to convert it into a non-local version." + ); +} diff --git a/turbopack/crates/turbo-tasks/src/raw_vc.rs b/turbopack/crates/turbo-tasks/src/raw_vc.rs index 9dc748bb282f42..a0f4c3470b8982 100644 --- a/turbopack/crates/turbo-tasks/src/raw_vc.rs +++ b/turbopack/crates/turbo-tasks/src/raw_vc.rs @@ -16,9 +16,12 @@ use crate::{ backend::{CellContent, TypedCellContent}, event::EventListener, id::{ExecutionId, LocalCellId}, - manager::{read_local_cell, read_task_cell, read_task_output, TurboTasksApi}, + manager::{ + assert_execution_id, current_task, read_local_cell, read_task_cell, read_task_output, + TurboTasksApi, + }, registry::{self, get_value_type}, - turbo_tasks, CollectiblesSource, TaskId, TraitTypeId, ValueTypeId, Vc, VcValueTrait, + turbo_tasks, CollectiblesSource, TaskId, TraitTypeId, ValueType, ValueTypeId, Vc, VcValueTrait, }; #[derive(Error, Debug)] @@ -100,38 +103,31 @@ impl RawVc { self, trait_type: TraitTypeId, ) -> Result, ResolveTypeError> { - let tt = turbo_tasks(); - tt.notify_scheduled_tasks(); - let mut current = self; - loop { - match current { - RawVc::TaskOutput(task) => { - current = read_task_output(&*tt, task, false) - .await - .map_err(|source| ResolveTypeError::TaskError { source })?; - } - RawVc::TaskCell(task, index) => { - let content = read_task_cell(&*tt, task, index) - .await - .map_err(|source| ResolveTypeError::ReadError { source })?; - if let TypedCellContent(value_type, CellContent(Some(_))) = content { - if get_value_type(value_type).has_trait(&trait_type) { - return Ok(Some(RawVc::TaskCell(task, index))); - } else { - return Ok(None); - } - } else { - return Err(ResolveTypeError::NoContent); - } - } - RawVc::LocalCell(_, _) => todo!(), - } - } + self.resolve_type_inner(|value_type_id| { + let value_type = get_value_type(value_type_id); + (value_type.has_trait(&trait_type), Some(value_type)) + }) + .await } pub(crate) async fn resolve_value( self, value_type: ValueTypeId, + ) -> Result, ResolveTypeError> { + self.resolve_type_inner(|cell_value_type| (cell_value_type == value_type, None)) + .await + } + + /// Helper for `resolve_trait` and `resolve_value`. + /// + /// After finding a cell, returns `Ok(Some(...))` when `conditional` returns + /// `true`, and `Ok(None)` when `conditional` returns `false`. + /// + /// As an optimization, `conditional` may return the `&'static ValueType` to + /// avoid a potential extra lookup later. + async fn resolve_type_inner( + self, + conditional: impl FnOnce(ValueTypeId) -> (bool, Option<&'static ValueType>), ) -> Result, ResolveTypeError> { let tt = turbo_tasks(); tt.notify_scheduled_tasks(); @@ -147,17 +143,29 @@ impl RawVc { let content = read_task_cell(&*tt, task, index) .await .map_err(|source| ResolveTypeError::ReadError { source })?; - if let TypedCellContent(cell_value_type, CellContent(Some(_))) = content { - if cell_value_type == value_type { - return Ok(Some(RawVc::TaskCell(task, index))); + if let TypedCellContent(value_type, CellContent(Some(_))) = content { + return Ok(if conditional(value_type).0 { + Some(RawVc::TaskCell(task, index)) } else { - return Ok(None); - } + None + }); } else { return Err(ResolveTypeError::NoContent); } } - RawVc::LocalCell(_, _) => todo!(), + RawVc::LocalCell(execution_id, local_cell_id) => { + let shared_reference = read_local_cell(execution_id, local_cell_id); + return Ok( + if let (true, value_type) = conditional(shared_reference.0) { + // re-use the `ValueType` lookup from `conditional`, if it exists + let value_type = + value_type.unwrap_or_else(|| get_value_type(shared_reference.0)); + Some((value_type.raw_cell)(shared_reference)) + } else { + None + }, + ); + } } } } @@ -172,7 +180,7 @@ impl RawVc { self.resolve_inner(/* strongly_consistent */ true).await } - pub(crate) async fn resolve_inner(self, strongly_consistent: bool) -> Result { + async fn resolve_inner(self, strongly_consistent: bool) -> Result { let tt = turbo_tasks(); let mut current = self; let mut notified = false; @@ -203,7 +211,10 @@ impl RawVc { pub fn get_task_id(&self) -> TaskId { match self { RawVc::TaskOutput(t) | RawVc::TaskCell(t, _) => *t, - RawVc::LocalCell(_, _) => todo!(), + RawVc::LocalCell(execution_id, _) => { + assert_execution_id(*execution_id); + current_task("RawVc::get_task_id") + } } } }