Skip to content

Commit

Permalink
fix matching with char (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
Samir-Rashid authored Nov 20, 2024
1 parent 08b3c90 commit a5a6540
Show file tree
Hide file tree
Showing 18 changed files with 165 additions and 16 deletions.
8 changes: 8 additions & 0 deletions crates/flux-desugar/src/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,10 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
Ok((self.genv().alloc_slice(&fhir_args), self.genv().alloc_slice(&constraints)))
}

/// This is the mega desugaring function [`surface::Ty`] -> [`fhir::Ty`].
/// These are both similar representations. The most important difference is that
/// [`fhir::Ty`] has explicit refinement parameters and [`surface::Ty`] does not.
/// Refinements are implicitly scoped in surface.
fn desugar_ty(&mut self, ty: &surface::Ty) -> Result<fhir::Ty<'genv>> {
let node_id = ty.node_id;
let span = ty.span;
Expand Down Expand Up @@ -1358,6 +1362,7 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
}
}

/// Desugar surface literal
fn desugar_lit(&self, span: Span, lit: surface::Lit) -> Result<fhir::Lit> {
match lit.kind {
surface::LitKind::Integer => {
Expand All @@ -1373,6 +1378,9 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
}
surface::LitKind::Bool => Ok(fhir::Lit::Bool(lit.symbol == kw::True)),
surface::LitKind::Str => Ok(fhir::Lit::Str(lit.symbol)),
surface::LitKind::Char => {
Ok(fhir::Lit::Char(lit.symbol.as_str().parse::<char>().unwrap()))
}
_ => Err(self.emit_err(errors::UnexpectedLiteral { span })),
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/flux-fhir-analysis/src/conv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,7 @@ fn conv_lit(lit: fhir::Lit) -> rty::Constant {
fhir::Lit::Real(r) => rty::Constant::Real(rty::Real(r)),
fhir::Lit::Bool(b) => rty::Constant::from(b),
fhir::Lit::Str(s) => rty::Constant::from(s),
fhir::Lit::Char(c) => rty::Constant::from(c),
}
}
fn conv_un_op(op: fhir::UnOp) -> rty::UnOp {
Expand Down
1 change: 1 addition & 0 deletions crates/flux-fhir-analysis/src/wf/sortck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ fn synth_lit(lit: fhir::Lit) -> rty::Sort {
fhir::Lit::Bool(_) => rty::Sort::Bool,
fhir::Lit::Real(_) => rty::Sort::Real,
fhir::Lit::Str(_) => rty::Sort::Str,
fhir::Lit::Char(_) => rty::Sort::Char,
}
}

Expand Down
4 changes: 3 additions & 1 deletion crates/flux-infer/src/fixpoint_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ impl SortEncodingCtxt {
rty::Sort::Real => fixpoint::Sort::Real,
rty::Sort::Bool => fixpoint::Sort::Bool,
rty::Sort::Str => fixpoint::Sort::Str,
rty::Sort::Char => fixpoint::Sort::Int,
rty::Sort::BitVec(size) => fixpoint::Sort::BitVec(Box::new(bv_size_to_fixpoint(*size))),
// There's no way to declare opaque sorts in the fixpoint horn syntax so we encode user
// declared opaque sorts, type parameter sorts, and (unormalizable) type alias sorts as
Expand Down Expand Up @@ -597,6 +598,7 @@ fn const_to_fixpoint(cst: rty::Constant) -> fixpoint::Constant {
rty::Constant::Int(i) => fixpoint::Constant::Numeral(i),
rty::Constant::Real(r) => fixpoint::Constant::Decimal(r),
rty::Constant::Bool(b) => fixpoint::Constant::Boolean(b),
rty::Constant::Char(c) => fixpoint::Constant::Numeral(BigInt::from(u32::from(c))),
rty::Constant::Str(s) => fixpoint::Constant::String(fixpoint::SymStr(s)),
}
}
Expand Down Expand Up @@ -1120,7 +1122,7 @@ impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {
scx: &mut SortEncodingCtxt,
) -> QueryResult<fixpoint::Expr> {
let e = match sort {
rty::Sort::Int | rty::Sort::Real => {
rty::Sort::Int | rty::Sort::Real | rty::Sort::Char => {
fixpoint::Expr::Atom(
rel,
Box::new([self.expr_to_fixpoint(e1, scx)?, self.expr_to_fixpoint(e2, scx)?]),
Expand Down
6 changes: 6 additions & 0 deletions crates/flux-middle/src/big_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ impl From<i32> for BigInt {
}
}

impl From<u32> for BigInt {
fn from(val: u32) -> Self {
BigInt { sign: Sign::NonNegative, val: val as u128 }
}
}

impl liquid_fixpoint::FixpointFmt for BigInt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.sign {
Expand Down
4 changes: 3 additions & 1 deletion crates/flux-middle/src/fhir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,8 @@ pub enum Lit {
Int(i128),
Real(i128),
Bool(bool),
Str(Symbol),
Str(Symbol), // `rustc_span::Symbol` interns a value with the type
Char(char), // all Rust chars are u32s
}

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -1427,6 +1428,7 @@ impl fmt::Debug for Lit {
Lit::Real(r) => write!(f, "{r}real"),
Lit::Bool(b) => write!(f, "{b}"),
Lit::Str(s) => write!(f, "\"{s:?}\""),
Lit::Char(c) => write!(f, "\'{c}\'"),
}
}
}
Expand Down
20 changes: 18 additions & 2 deletions crates/flux-middle/src/rty/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ impl Expr {
}
BaseTy::Uint(_) => ExprKind::Constant(Constant::from(bits)).intern(),
BaseTy::Bool => ExprKind::Constant(Constant::Bool(bits != 0)).intern(),
BaseTy::Char => {
let c = char::from_u32(bits.try_into().unwrap()).unwrap();
ExprKind::Constant(Constant::Char(c)).intern()
}
_ => bug!(),
}
}
Expand Down Expand Up @@ -905,6 +909,7 @@ pub enum Constant {
Real(Real),
Bool(bool),
Str(Symbol),
Char(char),
}

impl Constant {
Expand Down Expand Up @@ -983,7 +988,11 @@ impl Constant {
let b = scalar_to_bits(tcx, scalar, ty)?;
Some(Constant::Bool(b != 0))
}
_ => None,
TyKind::Char => {
let b = scalar_to_bits(tcx, scalar, ty)?;
Some(Constant::Char(char::from_u32(b as u32)?))
}
_ => bug!(),
}
}

Expand Down Expand Up @@ -1039,6 +1048,12 @@ impl From<Symbol> for Constant {
}
}

impl From<char> for Constant {
fn from(c: char) -> Self {
Constant::Char(c)
}
}

impl_internable!(ExprKind);
impl_slice_internable!(Expr, KVar);

Expand Down Expand Up @@ -1195,7 +1210,8 @@ mod pretty {
Constant::Int(i) => w!("{i}"),
Constant::Real(r) => w!("{}.0", ^r.0),
Constant::Bool(b) => w!("{b}"),
Constant::Str(sym) => w!("{sym}"),
Constant::Str(sym) => w!("\"{sym}\""),
Constant::Char(c) => write!(f, "\'{c}\'"),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/flux-middle/src/rty/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ impl TypeSuperVisitable for Sort {
| Sort::Bool
| Sort::Real
| Sort::Str
| Sort::Char
| Sort::BitVec(_)
| Sort::Loc
| Sort::Param(_)
Expand Down Expand Up @@ -476,6 +477,7 @@ impl TypeSuperFoldable for Sort {
| Sort::Real
| Sort::Loc
| Sort::Str
| Sort::Char
| Sort::BitVec(_)
| Sort::Param(_)
| Sort::Var(_)
Expand Down
16 changes: 16 additions & 0 deletions crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ pub enum Sort {
Real,
BitVec(BvSize),
Str,
Char,
Loc,
Param(ParamTy),
Tuple(List<Sort>),
Expand Down Expand Up @@ -1271,6 +1272,13 @@ impl Ty {
.unwrap_or_default()
}

/// Whether the type is a `char`
pub fn is_char(&self) -> bool {
self.as_bty_skipping_existentials()
.map(BaseTy::is_char)
.unwrap_or_default()
}

pub fn is_uninit(&self) -> bool {
matches!(self.kind(), TyKind::Uninit)
}
Expand Down Expand Up @@ -1544,6 +1552,14 @@ impl BaseTy {
matches!(self, BaseTy::Adt(adt_def, _) if adt_def.is_box())
}

pub fn is_char(&self) -> bool {
matches!(self, BaseTy::Char)
}

pub fn is_str(&self) -> bool {
matches!(self, BaseTy::Str)
}

pub fn unpack_box(&self) -> Option<(&Ty, &Ty)> {
if let BaseTy::Adt(adt_def, args) = self
&& adt_def.is_box()
Expand Down
1 change: 1 addition & 0 deletions crates/flux-middle/src/rty/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl Pretty for Sort {
Sort::Int => w!("int"),
Sort::Real => w!("real"),
Sort::Str => w!("str"),
Sort::Char => w!("char"),
Sort::BitVec(size) => w!("bitvec({:?})", size),
Sort::Loc => w!("loc"),
Sort::Var(n) => w!("@{}", ^n.index()),
Expand Down
4 changes: 2 additions & 2 deletions crates/flux-middle/src/sort_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> {
let sort = match ty.kind() {
ty::TyKind::Bool => Some(rty::Sort::Bool),
ty::TyKind::Slice(_) | ty::TyKind::Int(_) | ty::TyKind::Uint(_) => Some(rty::Sort::Int),
ty::TyKind::Char => Some(rty::Sort::Char),
ty::TyKind::Str => Some(rty::Sort::Str),
ty::TyKind::Adt(adt_def, args) => {
let mut sort_args = vec![];
Expand All @@ -58,7 +59,6 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> {
self.sort_of_generic_param(generic_param_def.def_id)?
}
ty::TyKind::Float(_)
| ty::TyKind::Char
| ty::TyKind::RawPtr(..)
| ty::TyKind::Ref(..)
| ty::TyKind::Tuple(_)
Expand Down Expand Up @@ -95,6 +95,7 @@ impl rty::BaseTy {
match self {
rty::BaseTy::Int(_) | rty::BaseTy::Uint(_) | rty::BaseTy::Slice(_) => rty::Sort::Int,
rty::BaseTy::Bool => rty::Sort::Bool,
rty::BaseTy::Char => rty::Sort::Char,
rty::BaseTy::Adt(adt_def, args) => adt_def.sort(args),
rty::BaseTy::Param(param_ty) => rty::Sort::Param(*param_ty),
rty::BaseTy::Str => rty::Sort::Str,
Expand All @@ -108,7 +109,6 @@ impl rty::BaseTy {
rty::Sort::Alias(*kind, alias_ty)
}
rty::BaseTy::Float(_)
| rty::BaseTy::Char
| rty::BaseTy::RawPtr(..)
| rty::BaseTy::Ref(..)
| rty::BaseTy::FnPtr(..)
Expand Down
11 changes: 8 additions & 3 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
}
TerminatorKind::SwitchInt { discr, targets } => {
let discr_ty = self.check_operand(infcx, env, terminator_span, discr)?;
if discr_ty.is_integral() || discr_ty.is_bool() {
if discr_ty.is_integral() || discr_ty.is_bool() || discr_ty.is_char() {
Ok(Self::check_if(&discr_ty, targets))
} else {
Ok(Self::check_match(&discr_ty, targets))
Expand Down Expand Up @@ -943,6 +943,8 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
Ok(Guard::Pred(pred))
}

/// Checks conditional branching as in a `match` statement. [`SwitchTargets`] (https://doc.rust-lang.org/nightly/nightly-rustc/stable_mir/mir/struct.SwitchTargets.html) contains a list of branches - the exact bit value which is being compared and the block to jump to. Using the conditionals, each branch can be checked using the new control flow information.
/// See https://github.com/flux-rs/flux/pull/840#discussion_r1786543174
fn check_if(discr_ty: &Ty, targets: &SwitchTargets) -> Vec<(BasicBlock, Guard)> {
let mk = |bits| {
match discr_ty.kind() {
Expand All @@ -953,7 +955,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
idx.clone()
}
}
TyKind::Indexed(bty @ (BaseTy::Int(_) | BaseTy::Uint(_)), idx) => {
TyKind::Indexed(bty @ (BaseTy::Int(_) | BaseTy::Uint(_) | BaseTy::Char), idx) => {
Expr::eq(idx.clone(), Expr::from_bits(bty, bits))
}
_ => tracked_span_bug!("unexpected discr_ty {:?}", discr_ty),
Expand Down Expand Up @@ -1486,7 +1488,10 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
let idx = Expr::constant(rty::Constant::from(*s));
Ok(Ty::mk_ref(ReStatic, Ty::indexed(BaseTy::Str, idx), Mutability::Not))
}
Constant::Char => Ok(Ty::char()),
Constant::Char(c) => {
let idx = Expr::constant(rty::Constant::from(*c));
Ok(Ty::indexed(BaseTy::Char, idx))
}
Constant::Param(param_const, ty) => {
let idx = Expr::const_generic(*param_const);
let ctor = self
Expand Down
27 changes: 26 additions & 1 deletion crates/flux-refineck/src/primops.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
/// This file defines the refinement rules for primitive operations.
/// Flux needs to define how to reason about primitive operations on different
/// [`BaseTy`]s. This is done by defining a set of rules for each operation.
///
/// For example, equality checks depend on whether the `BaseTy` is treated as
/// refineable or opaque.
///
/// ```
/// // Make the rules for `a == b`.
/// fn mk_eq_rules() -> RuleMatcher<2> {
/// primop_rules! {
/// // if the `BaseTy` is refineable, then we can reason about equality.
/// // The specified types in the `if` are refineable and Flux will use
/// // the refined postcondition (`bool[E::eq(a, b)]`) to reason about
/// // the invariants of `==`.
/// fn(a: T, b: T) -> bool[E::eq(a, b)]
/// if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()
///
/// // Otherwise, if the `BaseTy` is opaque, then we can't reason
/// // about equality. Flux only knows that the return type is a boolean,
/// // but the return value is unrefined.
/// fn(a: T, b: T) -> bool
/// }
/// }
/// ```
use std::{hash::Hash, sync::LazyLock};

use flux_common::tracked_span_bug;
Expand Down Expand Up @@ -280,7 +305,7 @@ fn mk_bit_xor_rules() -> RuleMatcher<2> {
fn mk_eq_rules() -> RuleMatcher<2> {
primop_rules! {
fn(a: T, b: T) -> bool[E::eq(a, b)]
if T.is_integral() || T.is_bool()
if T.is_integral() || T.is_bool() || T.is_char() || T.is_str()

fn(a: T, b: T) -> bool
}
Expand Down
4 changes: 3 additions & 1 deletion crates/flux-rustc-bridge/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,8 @@ impl<'sess, 'tcx> MirLoweringCtxt<'_, 'sess, 'tcx> {
.ok_or_else(|| UnsupportedReason::new(format!("unsupported constant `{constant:?}`")))
}

/// A `ScalarInt` is just a set of bits that can represent any scalar value.
/// This can represent all the primitive types with a fixed size.
fn scalar_int_to_constant(
&self,
scalar: rustc_ty::ScalarInt,
Expand All @@ -663,7 +665,7 @@ impl<'sess, 'tcx> MirLoweringCtxt<'_, 'sess, 'tcx> {
TyKind::Float(float_ty) => {
Some(Constant::Float(scalar_to_bits(self.tcx, scalar, ty).unwrap(), *float_ty))
}
TyKind::Char => Some(Constant::Char),
TyKind::Char => Some(Constant::Char(scalar.try_into().unwrap())),
TyKind::Bool => Some(Constant::Bool(scalar.try_to_bool().unwrap())),
TyKind::Tuple(tys) if tys.is_empty() => Some(Constant::Unit),
_ => {
Expand Down
6 changes: 3 additions & 3 deletions crates/flux-rustc-bridge/src/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ pub enum StatementKind {
Nop,
}

/// Corresponds to <https://doc.rust-lang.org/beta/nightly-rustc/rustc_middle/mir/enum.Rvalue.html>
pub enum Rvalue {
Use(Operand),
Repeat(Operand, Const),
Expand Down Expand Up @@ -341,8 +342,7 @@ pub enum Constant {
Float(u128, FloatTy),
Bool(bool),
Str(Symbol),
/// We only support opaque chars, so no data stored here for now
Char,
Char(char),
Unit,
Param(ParamConst, Ty),
/// General catch-all for constants of a given Ty
Expand Down Expand Up @@ -750,7 +750,7 @@ impl fmt::Debug for Constant {
Constant::Bool(b) => write!(f, "{b}"),
Constant::Unit => write!(f, "()"),
Constant::Str(s) => write!(f, "\"{s:?}\""),
Constant::Char => write!(f, "\"<opaque char>\""),
Constant::Char(c) => write!(f, "\'{c}\'"),
Constant::Opaque(ty) => write!(f, "<opaque {:?}>", ty),
Constant::Param(p, _) => write!(f, "{:?}", p),
}
Expand Down
14 changes: 14 additions & 0 deletions tests/tests/pos/structs/enum-match-04.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// match with different types
fn test00(flag: char) {
match flag {
'l' => {},
_ => {},
}
}

fn test01(flag: f32) {
match flag {
0.0 => {},
_ => {},
}
}
9 changes: 9 additions & 0 deletions tests/tests/pos/surface/char02.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#[flux::sig(fn() -> char['a'])]
pub fn char00() -> char {
'a'
}

#[flux::sig(fn(c: char{v: 'a' <= v && v <= 'z'}) -> bool[true])]
pub fn lowercase(c: char) -> bool {
'c' == 'c'
}
Loading

0 comments on commit a5a6540

Please sign in to comment.