Skip to content

Commit

Permalink
encapsulate enum
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd committed Feb 25, 2025
1 parent 56708d2 commit acf3642
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions crates/cubecl-core/src/frontend/barrier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ pub struct BarrierExpand<C: CubePrimitive> {
_c: PhantomData<C>,
}

pub struct BarrierLevel(InnerBarrierLevel);

#[derive(Clone)]
/// Defines how many units must reach the barrier to allow continuation
pub enum BarrierLevel {
enum InnerBarrierLevel {
/// Only waits for the unit who declared this barrier.
/// This may be useful for waiting upon async data loading
Unit,
Expand All @@ -63,31 +65,31 @@ pub enum BarrierLevel {
impl BarrierLevel {
/// Creates a Unit barrier level
pub fn unit() -> Self {
Self::Unit
BarrierLevel(InnerBarrierLevel::Unit)
}

/// Creates a Cube barrier level
///
/// The field elected_unit is the UNIT_POS of the unit that will
/// perform the underlying initialization. Typically, 0 should work
pub fn cube(elected_unit: u32) -> Self {
Self::Cube(elected_unit)
BarrierLevel(InnerBarrierLevel::Cube(elected_unit))
}

pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
Self::Unit
BarrierLevel(InnerBarrierLevel::Unit)
}

pub fn __expand_cube(_scope: &mut Scope, elected_unit: u32) -> Self {
Self::Cube(elected_unit)
BarrierLevel(InnerBarrierLevel::Cube(elected_unit))
}
}

impl From<BarrierLevel> for cubecl_ir::BarrierLevel {
fn from(val: BarrierLevel) -> Self {
impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
fn from(val: InnerBarrierLevel) -> Self {
match val {
BarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
BarrierLevel::Cube(elected_unit) => cubecl_ir::BarrierLevel::Cube(elected_unit),
InnerBarrierLevel::Unit => cubecl_ir::BarrierLevel::Unit,
InnerBarrierLevel::Cube(elected_unit) => cubecl_ir::BarrierLevel::Cube(elected_unit),
}
}
}
Expand Down Expand Up @@ -116,7 +118,7 @@ impl<C: CubePrimitive> Barrier<C> {
pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand<C> {
let elem = C::as_elem(scope);

let variable = scope.create_barrier(Item::new(elem), level.into());
let variable = scope.create_barrier(Item::new(elem), level.0.into());
BarrierExpand {
elem: variable,
_c: PhantomData,
Expand Down

0 comments on commit acf3642

Please sign in to comment.