Skip to content

Commit

Permalink
Merge pull request #7337 from roc-lang/specialize-exprs-bools
Browse files Browse the repository at this point in the history
Handle `If`, `Lookup`, and `Bool` in `specialize_types`
  • Loading branch information
rtfeldman authored Dec 12, 2024
2 parents bfdf967 + aaf82dd commit 7495495
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 113 deletions.
2 changes: 1 addition & 1 deletion crates/compiler/specialize_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ pub use mono_ir::{MonoExpr, MonoExprId, MonoExprs};
pub use mono_module::{InternedStrId, Interns};
pub use mono_num::Number;
pub use mono_struct::MonoFieldId;
pub use mono_type::{MonoType, MonoTypeId, MonoTypes};
pub use mono_type::{MonoType, MonoTypeId, MonoTypes, Primitive};
pub use specialize_type::{MonoTypeCache, Problem, RecordFieldIds, TupleElemIds};
53 changes: 42 additions & 11 deletions crates/compiler/specialize_types/src/mono_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,38 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {

MonoExpr::Struct(slice)
}
Expr::If {
cond_var: _,
branch_var,
branches,
final_else,
} => {
let branch_type = mono_from_var(*branch_var);

let mono_final_else = self.to_mono_expr(&final_else.value);
let final_else = self.mono_exprs.add(mono_final_else, final_else.region);

let mut branch_pairs: Vec<((MonoExpr, Region), (MonoExpr, Region))> =
Vec::with_capacity_in(branches.len(), self.arena);

for (cond, body) in branches {
let mono_cond = self.to_mono_expr(&cond.value);
let mono_body = self.to_mono_expr(&body.value);

branch_pairs.push(((mono_cond, cond.region), (mono_body, body.region)));
}

let branches = self.mono_exprs.extend_pairs(branch_pairs.into_iter());

MonoExpr::If {
branch_type,
branches,
final_else,
}
}
Expr::Var(symbol, var) | Expr::ParamsVar { symbol, var, .. } => {
MonoExpr::Lookup(*symbol, mono_from_var(*var))
}
// Expr::Call((fn_var, fn_expr, capture_var, ret_var), args, called_via) => {
// let opt_ret_type = mono_from_var(*var);

Expand Down Expand Up @@ -258,7 +290,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// })
// }
// }
// Expr::Var(symbol, var) => Some(MonoExpr::Lookup(*symbol, mono_from_var(*var)?)),
// Expr::LetNonRec(def, loc) => {
// let expr = self.to_mono_expr(def.loc_expr.value, stmts)?;
// let todo = (); // TODO if this is an underscore pattern and we're doing a fn call, convert it to Stmt::CallVoid
Expand All @@ -276,7 +307,7 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// todo!("split up the pattern into various Assign statements.");
// }
// Expr::LetRec(vec, loc, illegal_cycle_mark) => todo!(),
_ => todo!(),
_ => todo!("{:?}", can_expr),
// Expr::List {
// elem_var,
// loc_elems,
Expand All @@ -298,12 +329,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// branches_cond_var,
// exhaustive,
// } => todo!(),
// Expr::If {
// cond_var,
// branch_var,
// branches,
// final_else,
// } => todo!(),
// Expr::Call(_, vec, called_via) => todo!(),
// Expr::RunLowLevel { op, args, ret_var } => todo!(),
// Expr::ForeignCall {
Expand Down Expand Up @@ -405,7 +430,7 @@ fn to_num(primitive: Primitive, val: IntValue, problems: &mut impl Push<Problem>
})),
Primitive::U128 => MonoExpr::Number(Number::U128(val.as_u128())),
Primitive::I128 => MonoExpr::Number(Number::I128(val.as_i128())),
Primitive::Str | Primitive::Crash => {
Primitive::Str | Primitive::Crash | Primitive::Bool => {
let problem = Problem::NumSpecializedToWrongType(Some(MonoType::Primitive(primitive)));
problems.push(problem);
MonoExpr::CompilerBug(problem)
Expand All @@ -432,7 +457,8 @@ fn to_frac(primitive: Primitive, val: f64, problems: &mut impl Push<Problem>) ->
| Primitive::U128
| Primitive::I128
| Primitive::Str
| Primitive::Crash => {
| Primitive::Crash
| Primitive::Bool => {
let problem = Problem::NumSpecializedToWrongType(Some(MonoType::Primitive(primitive)));
problems.push(problem);
MonoExpr::CompilerBug(problem)
Expand All @@ -455,7 +481,12 @@ fn char_to_int(primitive: Primitive, ch: char, problems: &mut impl Push<Problem>
Primitive::I128 => MonoExpr::Number(Number::I128(ch as i128)),
Primitive::I16 => MonoExpr::Number(Number::I16(ch as i16)),
Primitive::I8 => MonoExpr::Number(Number::I8(ch as i8)),
Primitive::Str | Primitive::Dec | Primitive::F32 | Primitive::F64 | Primitive::Crash => {
Primitive::Str
| Primitive::Dec
| Primitive::F32
| Primitive::F64
| Primitive::Crash
| Primitive::Bool => {
let problem = Problem::CharSpecializedToWrongType(Some(MonoType::Primitive(primitive)));
problems.push(problem);
MonoExpr::CompilerBug(problem)
Expand Down
136 changes: 49 additions & 87 deletions crates/compiler/specialize_types/src/mono_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use roc_can::expr::Recursive;
use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol;
use roc_region::all::Region;
use soa::{Index, NonEmptySlice, Slice, Slice2, Slice3};
use soa::{Index, NonEmptySlice, PairSlice, Slice, Slice2, Slice3};
use std::iter;

#[derive(Clone, Copy, Debug, PartialEq)]
Expand Down Expand Up @@ -146,6 +146,50 @@ impl MonoExprs {

Slice::new(start as u32, len as u16)
}

pub fn iter_pair_slice(
&self,
exprs: PairSlice<MonoExpr>,
) -> impl Iterator<Item = (&MonoExpr, &MonoExpr)> {
exprs.indices_iter().map(|(index_a, index_b)| {
debug_assert!(
self.exprs.len() > index_a && self.exprs.len() > index_b,
"A Slice index was not found in MonoExprs. This should never happen!"
);

// Safety: we should only ever hand out MonoExprId slices that are valid indices into here.
unsafe {
(
self.exprs.get_unchecked(index_a),
self.exprs.get_unchecked(index_b),
)
}
})
}

pub fn extend_pairs(
&mut self,
exprs: impl Iterator<Item = ((MonoExpr, Region), (MonoExpr, Region))>,
) -> PairSlice<MonoExpr> {
let start = self.exprs.len();

let additional = exprs.size_hint().0 * 2;
self.exprs.reserve(additional);
self.regions.reserve(additional);

let mut pairs = 0;

for ((expr_a, region_a), (expr_b, region_b)) in exprs {
self.exprs.push(expr_a);
self.exprs.push(expr_b);
self.regions.push(region_a);
self.regions.push(region_b);

pairs += 1;
}

PairSlice::new(start as u32, pairs)
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand All @@ -159,82 +203,6 @@ impl MonoExprId {
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MonoStmtId {
inner: Index<MonoStmt>,
}

impl MonoStmtId {
pub(crate) unsafe fn new_unchecked(inner: Index<MonoStmt>) -> Self {
Self { inner }
}
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MonoStmt {
/// Assign to a variable.
Assign(IdentId, MonoExprId),
AssignRec(IdentId, MonoExprId),

/// Introduce a variable, e.g. `var foo_` (we'll MonoStmt::Assign to it later.)
Declare(IdentId),

/// The `return` statement
Return(MonoExprId),

/// The "crash" keyword. Importantly, during code gen we must mark this as "nothing happens after this"
Crash {
msg: MonoExprId,
/// The type of the `crash` expression (which will have unified to whatever's around it)
expr_type: MonoTypeId,
},

Expect {
condition: MonoExprId,
/// If the expectation fails, we print the values of all the named variables
/// in the final expr. These are those values.
lookups_in_cond: Slice2<MonoTypeId, IdentId>,
},

Dbg {
source_location: InternedStrId,
source: InternedStrId,
expr: MonoExprId,
expr_type: MonoTypeId,
name: IdentId,
},

// Call a function that has no return value (or which we are discarding due to an underscore pattern).
CallVoid {
fn_type: MonoTypeId,
fn_expr: MonoExprId,
args: Slice2<MonoTypeId, MonoExprId>,
/// This is the type of the closure based only on canonical IR info,
/// not considering what other closures might later influence it.
/// Lambda set specialization may change this type later!
capture_type: MonoTypeId,
},

// Branching
When {
/// The actual condition of the when expression.
cond: MonoExprId,
cond_type: MonoTypeId,
/// Type of each branch (and therefore the type of the entire `when` expression)
branch_type: MonoTypeId,
/// Note: if the branches weren't exhaustive, we will have already generated a default
/// branch which crashes if it's reached. (The compiler will have reported an error already;
/// this is for if you want to run anyway.)
branches: NonEmptySlice<WhenBranch>,
},
If {
/// Type of each branch (and therefore the type of the entire `if` expression)
branch_type: MonoTypeId,
branches: Slice<(MonoStmtId, MonoStmtId)>,
final_else: Option<MonoTypeId>,
},
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MonoExpr {
Str(InternedStrId),
Expand Down Expand Up @@ -324,21 +292,15 @@ pub enum MonoExpr {
args: Slice2<MonoTypeId, MonoExprId>,
},

Block {
stmts: Slice<MonoStmtId>,
final_expr: MonoExprId,
If {
branch_type: MonoTypeId,
branches: PairSlice<MonoExpr>,
final_else: MonoExprId,
},

CompilerBug(Problem),
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct WhenBranch {
pub patterns: Slice<MonoPatternId>,
pub body: Slice<MonoStmtId>,
pub guard: Option<MonoExprId>,
}

#[derive(Clone, Copy, Debug)]
pub enum MonoPattern {
Identifier(IdentId),
Expand Down
6 changes: 6 additions & 0 deletions crates/compiler/specialize_types/src/mono_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ impl MonoTypeId {
inner: Index::new(14),
};

pub const BOOL: Self = Self {
inner: Index::new(15),
};

pub const DEFAULT_INT: Self = Self::I64; // TODO change this to I128
pub const DEFAULT_FRAC: Self = Self::DEC;

Expand Down Expand Up @@ -100,6 +104,7 @@ impl MonoTypes {
MonoType::Primitive(Primitive::F32),
MonoType::Primitive(Primitive::F64),
MonoType::Primitive(Primitive::Dec),
MonoType::Primitive(Primitive::Bool),
],
ids: Vec::new(),
slices: Vec::new(),
Expand Down Expand Up @@ -232,6 +237,7 @@ pub enum Primitive {
F32,
F64,
Dec,
Bool,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down
21 changes: 20 additions & 1 deletion crates/compiler/specialize_types/src/specialize_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ impl<'a, 'c, 'd, 'e, 'f, 'm, 'p, P: Push<Problem>> Env<'a, 'c, 'd, 'e, 'f, 'm, '
// .flat_map(|var_index| self.lower_var( subs, subs[var_index]));

// let arg = new_args.next();
} else if symbol == Symbol::BOOL_BOOL {
MonoTypeId::BOOL
} else {
todo!("implement lower_builtin for symbol {symbol:?} - or, if all the builtins are already in here, report a compiler bug instead of panicking like this.");
}
Expand Down Expand Up @@ -476,7 +478,24 @@ fn number_args_to_mono_id(
}
}
}
Content::RangedNumber(_numeric_range) => todo!(),
Content::RangedNumber(range) => {
use roc_types::num::NumericRange::*;

return match *range {
IntAtLeastSigned(int_lit_width) => {
int_lit_width_to_mono_type_id(int_lit_width)
}
IntAtLeastEitherSign(int_lit_width) => {
int_lit_width_to_mono_type_id(int_lit_width)
}
NumAtLeastSigned(int_lit_width) => {
int_lit_width_to_mono_type_id(int_lit_width)
}
NumAtLeastEitherSign(int_lit_width) => {
int_lit_width_to_mono_type_id(int_lit_width)
}
};
}
_ => {
// This is an invalid number type, so break out of
// the alias-unrolling loop in order to return an error.
Expand Down
Loading

0 comments on commit 7495495

Please sign in to comment.