diff --git a/crates/cubecl-core/src/frontend/barrier.rs b/crates/cubecl-core/src/frontend/barrier.rs index c12a4afc..bea2af7e 100644 --- a/crates/cubecl-core/src/frontend/barrier.rs +++ b/crates/cubecl-core/src/frontend/barrier.rs @@ -50,9 +50,11 @@ pub struct BarrierExpand { _c: PhantomData, } +pub struct BarrierLevel(InnerBarrierLevel); + #[derive(Clone)] /// Defines how many units must reach the barrier to allow continuation -pub enum BarrierLevel { +enum InnerBarrierLevel { /// Only waits for the unit who declared this barrier. /// This may be useful for waiting upon async data loading Unit, @@ -63,7 +65,7 @@ pub enum BarrierLevel { impl BarrierLevel { /// Creates a Unit barrier level pub fn unit() -> Self { - Self::Unit + BarrierLevel(InnerBarrierLevel::Unit) } /// Creates a Cube barrier level @@ -71,23 +73,23 @@ impl BarrierLevel { /// The field elected_unit is the UNIT_POS of the unit that will /// perform the underlying initialization. Typically, 0 should work pub fn cube(elected_unit: u32) -> Self { - Self::Cube(elected_unit) + BarrierLevel(InnerBarrierLevel::Cube(elected_unit)) } pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel { - Self::Unit + BarrierLevel(InnerBarrierLevel::Unit) } pub fn __expand_cube(_scope: &mut Scope, elected_unit: u32) -> Self { - Self::Cube(elected_unit) + BarrierLevel(InnerBarrierLevel::Cube(elected_unit)) } } -impl From for cubecl_ir::BarrierLevel { - fn from(val: BarrierLevel) -> Self { +impl From for cubecl_ir::BarrierLevel { + fn from(val: InnerBarrierLevel) -> Self { match val { - BarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit, - BarrierLevel::Cube(elected_unit) => cubecl_ir::BarrierLevel::Cube(elected_unit), + InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit, + InnerBarrierLevel::Cube(elected_unit) => cubecl_ir::BarrierLevel::Cube(elected_unit), } } } @@ -116,7 +118,7 @@ impl Barrier { pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand { let elem = C::as_elem(scope); - let variable = scope.create_barrier(Item::new(elem), level.into()); + let variable = scope.create_barrier(Item::new(elem), level.0.into()); BarrierExpand { elem: variable, _c: PhantomData,