diff --git a/Cargo.lock b/Cargo.lock index 70922a357a..8304a993a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -283,6 +283,17 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "derive-where" +version = "1.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146398d62142a0f35248a608f17edf0dde57338354966d6e41d0eb2d16980ccb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + [[package]] name = "diff" version = "0.1.13" @@ -542,6 +553,7 @@ dependencies = [ name = "flux-fixpoint" version = "0.1.0" dependencies = [ + "derive-where", "flux-common", "flux-config", "itertools", diff --git a/crates/flux-fixpoint/Cargo.toml b/crates/flux-fixpoint/Cargo.toml index 753eecfa9d..1cb416a11e 100644 --- a/crates/flux-fixpoint/Cargo.toml +++ b/crates/flux-fixpoint/Cargo.toml @@ -7,6 +7,7 @@ edition.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +derive-where = "1.2.5" flux-common.workspace = true flux-config.workspace = true itertools.workspace = true diff --git a/crates/flux-fixpoint/src/constraint.rs b/crates/flux-fixpoint/src/constraint.rs index 47c0bfc04c..181d61bbe2 100644 --- a/crates/flux-fixpoint/src/constraint.rs +++ b/crates/flux-fixpoint/src/constraint.rs @@ -1,22 +1,22 @@ use std::{ fmt::{self, Write}, - hash::{Hash, Hasher}, sync::LazyLock, }; +use derive_where::derive_where; use flux_common::format::PadAdapter; use itertools::Itertools; -use rustc_index::newtype_index; use rustc_macros::{Decodable, Encodable}; use rustc_span::Symbol; -use crate::big_int::BigInt; +use crate::{big_int::BigInt, StringTypes, Types}; -pub enum Constraint { - Pred(Pred, Option), +#[derive_where(Hash)] +pub enum Constraint { + Pred(Pred, #[derive_where(skip)] Option), Conj(Vec), - Guard(Pred, Box), - ForAll(Name, Sort, Pred, Box), + Guard(Pred, Box), + ForAll(T::Var, Sort, Pred, Box), } #[derive(Clone, Hash)] @@ -49,53 +49,50 @@ pub struct PolyFuncSort { fsort: FuncSort, } -#[derive(Hash, Debug)] -pub enum Pred { +#[derive_where(Hash)] +pub enum Pred { And(Vec), - KVar(KVid, Vec), - Expr(Expr), + KVar(T::KVar, Vec), + Expr(Expr), } -#[derive(Hash, Debug)] -pub enum Expr { - Var(Name), - ConstVar(ConstName), +#[derive_where(Hash)] +pub enum Expr { + Var(T::Var), Constant(Constant), - BinaryOp(BinOp, Box<[Expr; 2]>), - App(Func, Vec), + BinaryOp(BinOp, Box<[Self; 2]>), + App(Func, Vec), UnaryOp(UnOp, Box), - Pair(Box<[Expr; 2]>), - Proj(Box, Proj), - IfThenElse(Box<[Expr; 3]>), + Pair(Box<[Self; 2]>), + Proj(Box, Proj), + IfThenElse(Box<[Self; 3]>), Unit, } -#[derive(Hash, Debug, Clone)] -pub enum Func { - Var(Name), - /// uninterepreted function - Uif(ConstName), +#[derive_where(Hash)] +pub enum Func { + Var(T::Var), /// interpreted (theory) function Itf(Symbol), } -#[derive(Clone, Copy, Hash, Debug)] +#[derive(Clone, Copy, Hash)] pub enum Proj { Fst, Snd, } -#[derive(Hash)] -pub struct Qualifier { +#[derive_where(Hash)] +pub struct Qualifier { pub name: String, - pub args: Vec<(Name, Sort)>, - pub body: Expr, + pub args: Vec<(T::Var, Sort)>, + pub body: Expr, pub global: bool, } -#[derive(Clone, Copy, Debug)] -pub struct Const { - pub name: ConstName, +#[derive(Clone, Copy)] +pub struct Const { + pub name: T::Var, pub val: i128, } @@ -124,33 +121,14 @@ pub enum UnOp { Neg, } -#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, Encodable, Decodable)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encodable, Decodable)] pub enum Constant { Int(BigInt), Real(i128), Bool(bool), } -newtype_index! { - #[debug_format = "$k{}"] - pub struct KVid {} -} - -newtype_index! { - #[debug_format = "a{}"] - pub struct Name { - const NAME0 = 0; - const NAME1 = 1; - const NAME2 = 2; - } -} - -newtype_index! { - #[debug_format = "c{}"] - pub struct ConstName {} -} - -impl Constraint { +impl Constraint { pub const TRUE: Self = Self::Pred(Pred::TRUE, None); /// Returns true if the constraint has at least one concrete RHS ("head") predicates. @@ -164,28 +142,7 @@ impl Constraint { } } -impl Hash for Constraint { - fn hash(&self, state: &mut H) { - let tag = std::mem::discriminant(self); - tag.hash(state); - match self { - Constraint::Pred(p, _) => p.hash(state), - Constraint::Conj(cs) => cs.hash(state), - Constraint::Guard(p, c) => { - p.hash(state); - c.hash(state); - } - Constraint::ForAll(x, t, p, c) => { - x.hash(state); - t.hash(state); - p.hash(state); - c.hash(state); - } - } - } -} - -impl Pred { +impl Pred { pub const TRUE: Self = Pred::Expr(Expr::Constant(Constant::Bool(true))); pub fn is_trivially_true(&self) -> bool { @@ -217,10 +174,7 @@ impl PolyFuncSort { } } -impl fmt::Display for Constraint -where - Tag: fmt::Display, -{ +impl fmt::Display for Constraint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Constraint::Pred(pred, tag) => write!(f, "{}", PredTag(pred, tag)), @@ -241,7 +195,7 @@ where write!(f, "\n)") } Constraint::ForAll(x, sort, body, head) => { - write!(f, "(forall (({x:?} {sort}) {body})")?; + write!(f, "(forall (({x} {sort}) {body})")?; write!(PadAdapter::wrap_fmt(f, 2), "\n{head}")?; write!(f, "\n)") } @@ -249,12 +203,9 @@ where } } -struct PredTag<'a, Tag>(&'a Pred, &'a Option); +struct PredTag<'a, T: Types>(&'a Pred, &'a Option); -impl fmt::Display for PredTag<'_, Tag> -where - Tag: fmt::Display, -{ +impl fmt::Display for PredTag<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let PredTag(pred, tag) = self; match pred { @@ -301,35 +252,24 @@ impl fmt::Display for Sort { Sort::BitVec(size) => write!(f, "(BitVec Size{})", size), Sort::Pair(s1, s2) => write!(f, "(Pair {s1} {s2})"), Sort::Func(sort) => write!(f, "{sort}"), - Sort::App(ctor, ts) => write!(f, "({ctor} {:?})", ts.iter().format(" ")), + Sort::App(ctor, ts) => write!(f, "({ctor} {})", ts.iter().format(" ")), } } } impl fmt::Display for PolyFuncSort { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "(func({}, [{:?}]))", - self.params, - self.fsort.inputs_and_output.iter().format("; ") - ) + write!(f, "(func({}, [{}]))", self.params, self.fsort.inputs_and_output.iter().format("; ")) } } impl fmt::Display for FuncSort { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(func(0, [{:?}]))", self.inputs_and_output.iter().format("; ")) - } -} - -impl fmt::Debug for Sort { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(self, f) + write!(f, "(func(0, [{}]))", self.inputs_and_output.iter().format("; ")) } } -impl fmt::Display for Pred { +impl fmt::Display for Pred { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Pred::And(preds) => { @@ -340,24 +280,24 @@ impl fmt::Display for Pred { } } Pred::KVar(kvid, vars) => { - write!(f, "({kvid:?} {:?})", vars.iter().format(" ")) + write!(f, "(${kvid} {})", vars.iter().format(" ")) } Pred::Expr(expr) => write!(f, "({expr})"), } } } -impl Expr { - pub const ZERO: Expr = Expr::Constant(Constant::ZERO); - pub const ONE: Expr = Expr::Constant(Constant::ONE); - pub fn eq(self, other: Expr) -> Expr { +impl Expr { + pub const ZERO: Expr = Expr::Constant(Constant::ZERO); + pub const ONE: Expr = Expr::Constant(Constant::ONE); + pub fn eq(self, other: Self) -> Self { Expr::BinaryOp(BinOp::Eq, Box::new([self, other])) } } -struct FmtParens<'a>(&'a Expr); +struct FmtParens<'a, T: Types>(&'a Expr); -impl fmt::Display for FmtParens<'_> { +impl fmt::Display for FmtParens<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Fixpoint parser has `=` at two different precedence levels depending on whether it is // used in a sequence of boolean expressions or not. To avoid complexity we parenthesize @@ -371,11 +311,10 @@ impl fmt::Display for FmtParens<'_> { } } -impl fmt::Display for Expr { +impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Expr::Var(x) => write!(f, "{x:?}"), - Expr::ConstVar(x) => write!(f, "{x:?}"), + Expr::Var(x) => write!(f, "{x}"), Expr::Constant(c) => write!(f, "{c}"), Expr::BinaryOp(op, box [e1, e2]) => { write!(f, "{} {op} {}", FmtParens(e1), FmtParens(e2))?; @@ -402,57 +341,56 @@ impl fmt::Display for Expr { } } -impl fmt::Display for Func { +impl fmt::Display for Func { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Func::Var(name) => write!(f, "{name:?}"), - Func::Uif(uif) => write!(f, "{uif:?}"), + Func::Var(name) => write!(f, "{name}"), Func::Itf(itf) => write!(f, "{itf}"), } } } -pub(crate) static DEFAULT_QUALIFIERS: LazyLock> = LazyLock::new(|| { +pub(crate) static DEFAULT_QUALIFIERS: LazyLock>> = LazyLock::new(|| { // ----- // UNARY // ----- // (qualif EqZero ((v int)) (v == 0)) let eqzero = Qualifier { - args: vec![(NAME0, Sort::Int)], - body: Expr::BinaryOp(BinOp::Eq, Box::new([Expr::Var(NAME0), Expr::ZERO])), + args: vec![("v", Sort::Int)], + body: Expr::BinaryOp(BinOp::Eq, Box::new([Expr::Var("v"), Expr::ZERO])), name: String::from("EqZero"), global: true, }; // (qualif GtZero ((v int)) (v > 0)) let gtzero = Qualifier { - args: vec![(NAME0, Sort::Int)], - body: Expr::BinaryOp(BinOp::Gt, Box::new([Expr::Var(NAME0), Expr::ZERO])), + args: vec![("v", Sort::Int)], + body: Expr::BinaryOp(BinOp::Gt, Box::new([Expr::Var("v"), Expr::ZERO])), name: String::from("GtZero"), global: true, }; // (qualif GeZero ((v int)) (v >= 0)) let gezero = Qualifier { - args: vec![(NAME0, Sort::Int)], - body: Expr::BinaryOp(BinOp::Ge, Box::new([Expr::Var(NAME0), Expr::ZERO])), + args: vec![("v", Sort::Int)], + body: Expr::BinaryOp(BinOp::Ge, Box::new([Expr::Var("v"), Expr::ZERO])), name: String::from("GeZero"), global: true, }; // (qualif LtZero ((v int)) (v < 0)) let ltzero = Qualifier { - args: vec![(NAME0, Sort::Int)], - body: Expr::BinaryOp(BinOp::Lt, Box::new([Expr::Var(NAME0), Expr::ZERO])), + args: vec![("v", Sort::Int)], + body: Expr::BinaryOp(BinOp::Lt, Box::new([Expr::Var("v"), Expr::ZERO])), name: String::from("LtZero"), global: true, }; // (qualif LeZero ((v int)) (v <= 0)) let lezero = Qualifier { - args: vec![(NAME0, Sort::Int)], - body: Expr::BinaryOp(BinOp::Le, Box::new([Expr::Var(NAME0), Expr::ZERO])), + args: vec![("v", Sort::Int)], + body: Expr::BinaryOp(BinOp::Le, Box::new([Expr::Var("v"), Expr::ZERO])), name: String::from("LeZero"), global: true, }; @@ -463,90 +401,62 @@ pub(crate) static DEFAULT_QUALIFIERS: LazyLock> = LazyLock::new(| // (qualif Eq ((a int) (b int)) (a == b)) let eq = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int)], - body: Expr::BinaryOp(BinOp::Eq, Box::new([Expr::Var(NAME0), Expr::Var(NAME1)])), + args: vec![("a", Sort::Int), ("b", Sort::Int)], + body: Expr::BinaryOp(BinOp::Eq, Box::new([Expr::Var("a"), Expr::Var("b")])), name: String::from("Eq"), global: true, }; // (qualif Gt ((a int) (b int)) (a > b)) let gt = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int)], - body: Expr::BinaryOp(BinOp::Gt, Box::new([Expr::Var(NAME0), Expr::Var(NAME1)])), + args: vec![("a", Sort::Int), ("b", Sort::Int)], + body: Expr::BinaryOp(BinOp::Gt, Box::new([Expr::Var("a"), Expr::Var("b")])), name: String::from("Gt"), global: true, }; // (qualif Lt ((a int) (b int)) (a < b)) let ge = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int)], - body: Expr::BinaryOp(BinOp::Ge, Box::new([Expr::Var(NAME0), Expr::Var(NAME1)])), + args: vec![("a", Sort::Int), ("b", Sort::Int)], + body: Expr::BinaryOp(BinOp::Ge, Box::new([Expr::Var("a"), Expr::Var("b")])), name: String::from("Ge"), global: true, }; // (qualif Ge ((a int) (b int)) (a >= b)) let lt = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int)], - body: Expr::BinaryOp(BinOp::Lt, Box::new([Expr::Var(NAME0), Expr::Var(NAME1)])), + args: vec![("a", Sort::Int), ("b", Sort::Int)], + body: Expr::BinaryOp(BinOp::Lt, Box::new([Expr::Var("a"), Expr::Var("b")])), name: String::from("Lt"), global: true, }; // (qualif Le ((a int) (b int)) (a <= b)) let le = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int)], - body: Expr::BinaryOp(BinOp::Le, Box::new([Expr::Var(NAME0), Expr::Var(NAME1)])), + args: vec![("a", Sort::Int), ("b", Sort::Int)], + body: Expr::BinaryOp(BinOp::Le, Box::new([Expr::Var("a"), Expr::Var("b")])), name: String::from("Le"), global: true, }; // (qualif Le1 ((a int) (b int)) (a < b - 1)) let le1 = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int)], + args: vec![("a", Sort::Int), ("b", Sort::Int)], body: Expr::BinaryOp( BinOp::Le, Box::new([ - Expr::Var(NAME0), - Expr::BinaryOp(BinOp::Sub, Box::new([Expr::Var(NAME1), Expr::ONE])), + Expr::Var("a"), + Expr::BinaryOp(BinOp::Sub, Box::new([Expr::Var("b"), Expr::ONE])), ]), ), name: String::from("Le1"), global: true, }; - // (qualif Add2 ((a int) (b int) (c int)) (a == b + c)) - let _add2 = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int), (NAME2, Sort::Int)], - body: Expr::BinaryOp( - BinOp::Eq, - Box::new([ - Expr::Var(NAME0), - Expr::BinaryOp(BinOp::Add, Box::new([Expr::Var(NAME1), Expr::Var(NAME2)])), - ]), - ), - name: String::from("Add2"), - global: true, - }; - - // (qualif Sub2 ((a int) (b int) (c int)) (a == b - c)) - let _sub2 = Qualifier { - args: vec![(NAME0, Sort::Int), (NAME1, Sort::Int), (NAME2, Sort::Int)], - body: Expr::BinaryOp( - BinOp::Eq, - Box::new([ - Expr::Var(NAME0), - Expr::BinaryOp(BinOp::Sub, Box::new([Expr::Var(NAME1), Expr::Var(NAME2)])), - ]), - ), - name: String::from("Sub2"), - global: true, - }; - - vec![eqzero, gtzero, gezero, ltzero, lezero, eq, gt, ge, lt, le, le1] //, add2, sub2] + vec![eqzero, gtzero, gezero, ltzero, lezero, eq, gt, ge, lt, le, le1] }); -impl fmt::Display for Qualifier { +impl fmt::Display for Qualifier { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, @@ -554,7 +464,7 @@ impl fmt::Display for Qualifier { self.name, self.args .iter() - .format_with(" ", |(name, sort), f| f(&format_args!("({name:?} {sort})"))), + .format_with(" ", |(name, sort), f| f(&format_args!("({name} {sort})"))), self.body ) } @@ -720,15 +630,3 @@ impl From for Constant { Constant::Bool(b) } } - -impl From for Expr { - fn from(n: Name) -> Self { - Expr::Var(n) - } -} - -impl From for Expr { - fn from(c_n: ConstName) -> Self { - Expr::ConstVar(c_n) - } -} diff --git a/crates/flux-fixpoint/src/lib.rs b/crates/flux-fixpoint/src/lib.rs index 3dceb5daaa..46985d2ad5 100644 --- a/crates/flux-fixpoint/src/lib.rs +++ b/crates/flux-fixpoint/src/lib.rs @@ -1,6 +1,5 @@ #![feature(rustc_private, min_specialization, lazy_cell, box_patterns, let_chains)] -extern crate rustc_index; extern crate rustc_macros; extern crate rustc_serialize; extern crate rustc_span; @@ -18,9 +17,10 @@ use std::{ }; pub use constraint::{ - BinOp, Const, ConstName, Constant, Constraint, Expr, Func, FuncSort, KVid, Name, PolyFuncSort, - Pred, Proj, Qualifier, Sort, SortCtor, UnOp, + BinOp, Const, Constant, Constraint, Expr, Func, FuncSort, PolyFuncSort, Pred, Proj, Qualifier, + Sort, SortCtor, UnOp, }; +use derive_where::derive_where; use flux_common::{cache::QueryCache, format::PadAdapter}; use flux_config as config; use itertools::Itertools; @@ -28,34 +28,67 @@ use serde::{de, Deserialize}; use crate::constraint::DEFAULT_QUALIFIERS; -#[derive(Clone, Debug, Hash)] -pub struct ConstInfo { - pub name: ConstName, +pub trait Symbol: fmt::Display + Hash {} + +impl Symbol for T {} + +pub trait Types { + type KVar: Symbol; + type Var: Symbol; + type Tag: fmt::Display + Hash + FromStr; +} + +#[macro_export] +macro_rules! declare_types { + (type KVar = $kvar:ty; type Var = $var:ty; type Tag = $tag:ty;) => { + pub mod fixpoint_generated { + pub struct FixpointTypes; + pub type Expr = $crate::Expr; + pub type Pred = $crate::Pred; + pub type Func = $crate::Func; + pub type Constraint = $crate::Constraint; + pub type KVar = $crate::KVar; + pub type ConstInfo = $crate::ConstInfo; + pub type Task = $crate::Task; + pub type Qualifier = $crate::Qualifier; + pub use $crate::{PolyFuncSort, Proj, Sort, SortCtor}; + } + + impl $crate::Types for fixpoint_generated::FixpointTypes { + type KVar = $kvar; + type Var = $var; + type Tag = $tag; + } + }; +} + +struct StringTypes; + +impl Types for StringTypes { + type KVar = &'static str; + type Var = &'static str; + type Tag = String; +} + +#[derive_where(Hash)] +pub struct ConstInfo { + pub name: T::Var, pub orig: rustc_span::Symbol, pub sort: Sort, } -pub struct Task { +#[derive_where(Hash)] +pub struct Task { + #[derive_where(skip)] pub comments: Vec, - pub constants: Vec, - pub kvars: Vec, - pub constraint: Constraint, - pub qualifiers: Vec, + pub constants: Vec>, + pub kvars: Vec>, + pub constraint: Constraint, + pub qualifiers: Vec>, pub sorts: Vec, pub scrape_quals: bool, } -impl Hash for Task { - fn hash(&self, state: &mut H) { - self.constants.hash(state); - self.kvars.hash(state); - self.constraint.hash(state); - self.qualifiers.hash(state); - self.sorts.hash(state); - self.scrape_quals.hash(state); - } -} - #[derive(Deserialize, Debug)] #[serde(tag = "tag", content = "contents", bound(deserialize = "Tag: FromStr"))] pub enum FixpointResult { @@ -82,20 +115,20 @@ pub struct Stats { #[derive(Deserialize, Debug)] pub struct CrashInfo(Vec); -#[derive(Debug, Hash)] -pub struct KVar { - kvid: KVid, +#[derive_where(Hash)] +pub struct KVar { + kvid: T::KVar, sorts: Vec, comment: String, } -impl Task { +impl Task { pub fn new( comments: Vec, - constants: Vec, - kvars: Vec, - constraint: Constraint, - qualifiers: Vec, + constants: Vec>, + kvars: Vec>, + constraint: Constraint, + qualifiers: Vec>, sorts: Vec, scrape_quals: bool, ) -> Self { @@ -112,7 +145,7 @@ impl Task { &self, key: String, cache: &mut QueryCache, - ) -> io::Result> { + ) -> io::Result> { let hash = self.hash_with_default(); if config::is_cache_enabled() && cache.is_safe(&key, hash) { @@ -129,7 +162,7 @@ impl Task { result } - fn check(&self) -> io::Result> { + fn check(&self) -> io::Result> { let mut child = Command::new("fixpoint") .arg("-q") .arg("--stdin") @@ -153,13 +186,13 @@ impl Task { } } -impl KVar { - pub fn new(kvid: KVid, sorts: Vec, comment: String) -> Self { +impl KVar { + pub fn new(kvid: T::KVar, sorts: Vec, comment: String) -> Self { Self { kvid, sorts, comment } } } -impl fmt::Display for Task { +impl fmt::Display for Task { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.scrape_quals { writeln!(f, "(fixpoint \"--scrape=both\")")?; @@ -195,11 +228,11 @@ impl fmt::Display for Task { } } -impl fmt::Display for KVar { +impl fmt::Display for KVar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(var {:?} ({})) // {}", + "(var ${} ({})) // {}", self.kvid, self.sorts .iter() @@ -209,13 +242,13 @@ impl fmt::Display for KVar { } } -impl fmt::Display for ConstInfo { +impl fmt::Display for ConstInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(constant {:?} {:?}) // orig: {}", self.name, self.sort, self.orig) + write!(f, "(constant {} {}) // orig: {}", self.name, self.sort, self.orig) } } -impl fmt::Debug for Task { +impl fmt::Debug for Task { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self, f) } diff --git a/crates/flux-refineck/src/fixpoint_encoding.rs b/crates/flux-refineck/src/fixpoint_encoding.rs index 1264fa8c86..8f8bb2ee69 100644 --- a/crates/flux-refineck/src/fixpoint_encoding.rs +++ b/crates/flux-refineck/src/fixpoint_encoding.rs @@ -2,7 +2,6 @@ use std::{hash::Hash, iter}; -use fixpoint::FixpointResult; use flux_common::{ bug, cache::QueryCache, @@ -11,7 +10,8 @@ use flux_common::{ span_bug, }; use flux_config as config; -use flux_fixpoint as fixpoint; +use flux_fixpoint::FixpointResult; +// use flux_fixpoint as fixpoint; use flux_middle::{ fhir::FuncKind, global_env::GlobalEnv, @@ -57,6 +57,64 @@ pub enum KVarEncoding { Conj, } +pub mod fixpoint { + use std::fmt; + + use rustc_index::newtype_index; + + newtype_index! { + pub struct KVid {} + } + + newtype_index! { + pub struct LocalVar {} + } + + newtype_index! { + pub struct GlobalVar {} + } + + #[derive(Hash, Debug, Copy, Clone)] + pub enum Var { + Global(GlobalVar), + Local(LocalVar), + } + + impl From for Var { + fn from(v: GlobalVar) -> Self { + Self::Global(v) + } + } + + impl From for Var { + fn from(v: LocalVar) -> Self { + Self::Local(v) + } + } + + impl fmt::Display for KVid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "k{}", self.as_u32()) + } + } + + impl fmt::Display for Var { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Var::Global(v) => write!(f, "c{}", v.as_u32()), + Var::Local(v) => write!(f, "a{}", v.as_u32()), + } + } + } + + flux_fixpoint::declare_types! { + type KVar = KVid; + type Var = Var; + type Tag = super::TagIdx; + } + pub use fixpoint_generated::*; +} + type KVidMap = UnordMap>; type ConstMap = FxIndexMap; @@ -89,22 +147,22 @@ struct FixpointKVar { /// Environment used to map [`rty::Var`] into [`fixpoint::Name`]. This only supports /// mapping of [`rty::Var::LateBound`] and [`rty::Var::Free`]. struct Env { - name_gen: IndexGen, - fvars: UnordMap, + local_var_gen: IndexGen, + fvars: UnordMap, /// Layers of late bound variables - layers: Vec>, + layers: Vec>, } impl Env { fn new() -> Self { - Self { name_gen: IndexGen::new(), fvars: Default::default(), layers: Vec::new() } + Self { local_var_gen: IndexGen::new(), fvars: Default::default(), layers: Vec::new() } } - fn fresh_name(&self) -> fixpoint::Name { - self.name_gen.fresh() + fn fresh_name(&self) -> fixpoint::LocalVar { + self.local_var_gen.fresh() } - fn insert_fvar_map(&mut self, name: rty::Name) -> fixpoint::Name { + fn insert_fvar_map(&mut self, name: rty::Name) -> fixpoint::LocalVar { let fresh = self.fresh_name(); self.fvars.insert(name, fresh); fresh @@ -114,11 +172,11 @@ impl Env { self.fvars.remove(&name); } - fn get_fvar(&self, name: rty::Name) -> Option { + fn get_fvar(&self, name: rty::Name) -> Option { self.fvars.get(&name).copied() } - fn get_late_bvar(&self, debruijn: DebruijnIndex, idx: u32) -> Option { + fn get_late_bvar(&self, debruijn: DebruijnIndex, idx: u32) -> Option { let depth = self.layers.len().checked_sub(debruijn.as_usize() + 1)?; self.layers[depth].get(idx as usize).copied() } @@ -128,7 +186,7 @@ impl Env { self.layers.push(layer); } - fn last_layer(&self) -> &[fixpoint::Name] { + fn last_layer(&self) -> &[fixpoint::LocalVar] { self.layers.last().unwrap() } } @@ -140,17 +198,27 @@ struct ExprCtxt<'a> { dbg_span: Span, } -#[derive(Debug)] struct ConstInfo { - name: fixpoint::ConstName, + name: fixpoint::GlobalVar, sym: rustc_span::Symbol, sort: fixpoint::Sort, val: Option, } /// An alias for additional bindings introduced when ANF-ing index expressions -/// in the course of conversion to fixpoint. -type Bindings = Vec<(fixpoint::Name, fixpoint::Sort, fixpoint::Expr)>; +/// in the course of encoding into fixpoint. +pub type Bindings = Vec<(fixpoint::LocalVar, fixpoint::Sort, fixpoint::Expr)>; + +pub fn stitch(bindings: Bindings, c: fixpoint::Constraint) -> fixpoint::Constraint { + bindings.into_iter().rev().fold(c, |c, (name, sort, e)| { + fixpoint::Constraint::ForAll( + fixpoint::Var::Local(name), + sort, + fixpoint::Pred::Expr(e), + Box::new(c), + ) + }) +} /// An alias for a list of predicate (conjuncts) and their spans, used to give /// localized errors when refine checking fails. @@ -179,7 +247,7 @@ where pub(crate) fn with_name_map( &mut self, name: rty::Name, - f: impl FnOnce(&mut Self, fixpoint::Name) -> R, + f: impl FnOnce(&mut Self, fixpoint::LocalVar) -> R, ) -> R { let fresh = self.env.insert_fvar_map(name); let r = f(self, fresh); @@ -188,11 +256,11 @@ where } fn assume_const_val( - cstr: fixpoint::Constraint, - const_name: fixpoint::ConstName, + cstr: fixpoint::Constraint, + var: fixpoint::GlobalVar, const_val: Constant, - ) -> fixpoint::Constraint { - let e1 = fixpoint::Expr::from(const_name); + ) -> fixpoint::Constraint { + let e1 = fixpoint::Expr::Var(fixpoint::Var::Global(var)); let e2 = fixpoint::Expr::Constant(const_val); let pred = fixpoint::Pred::Expr(e1.eq(e2)); fixpoint::Constraint::Guard(pred, Box::new(cstr)) @@ -201,7 +269,7 @@ where pub fn check( self, cache: &mut QueryCache, - constraint: fixpoint::Constraint, + constraint: fixpoint::Constraint, config: &CheckerConfig, ) -> QueryResult> { if !constraint.is_concrete() { @@ -236,7 +304,7 @@ where .into_values() .map(|const_info| { fixpoint::ConstInfo { - name: const_info.name, + name: fixpoint::Var::Global(const_info.name), orig: const_info.sym, sort: const_info.sort, } @@ -325,19 +393,20 @@ where let decl = self.kvars.get(kvar.kvid); let all_args = iter::zip(&kvar.args, &decl.sorts) - .map(|(arg, sort)| self.imm(arg, sort, bindings)) + .map(|(arg, sort)| fixpoint::Var::Local(self.imm(arg, sort, bindings))) .collect_vec(); let kvids = &self.kvid_map[&kvar.kvid]; if all_args.is_empty() { let fresh = self.env.fresh_name(); + let var = fixpoint::Var::Local(fresh); bindings.push(( fresh, fixpoint::Sort::Unit, - fixpoint::Expr::eq(fixpoint::Expr::Var(fresh), fixpoint::Expr::Unit), + fixpoint::Expr::eq(fixpoint::Expr::Var(var), fixpoint::Expr::Unit), )); - return fixpoint::Pred::KVar(kvids[0], vec![fresh]); + return fixpoint::Pred::KVar(kvids[0], vec![var]); } let kvars = kvids @@ -386,8 +455,8 @@ where &self, arg: &rty::Expr, sort: &rty::Sort, - bindings: &mut Vec<(fixpoint::Name, fixpoint::Sort, fixpoint::Expr)>, - ) -> fixpoint::Name { + bindings: &mut Vec<(fixpoint::LocalVar, fixpoint::Sort, fixpoint::Expr)>, + ) -> fixpoint::LocalVar { match arg.kind() { rty::ExprKind::Var(rty::Var::Free(name)) => { self.env.get_fvar(*name).unwrap_or_else(|| { @@ -400,7 +469,7 @@ where _ => { let fresh = self.env.fresh_name(); let pred = fixpoint::Expr::eq( - fixpoint::Expr::Var(fresh), + fixpoint::Expr::Var(fresh.into()), self.as_expr_cx().expr_to_fixpoint(arg), ); bindings.push((fresh, sort_to_fixpoint(sort), pred)); @@ -606,7 +675,7 @@ impl<'a> ExprCtxt<'a> { fn expr_to_fixpoint(&self, expr: &rty::Expr) -> fixpoint::Expr { match expr.kind() { - rty::ExprKind::Var(var) => fixpoint::Expr::Var(self.var_to_fixpoint(var)), + rty::ExprKind::Var(var) => fixpoint::Expr::Var(self.var_to_fixpoint(var).into()), rty::ExprKind::Constant(c) => fixpoint::Expr::Constant(*c), rty::ExprKind::BinaryOp(op, e1, e2) => { fixpoint::Expr::BinaryOp( @@ -629,7 +698,7 @@ impl<'a> ExprCtxt<'a> { let const_info = self.const_map.get(&Key::Const(*did)).unwrap_or_else(|| { span_bug!(self.dbg_span, "no entry found in const_map for def_id: `{did:?}`") }); - fixpoint::Expr::ConstVar(const_info.name) + fixpoint::Expr::Var(const_info.name.into()) } rty::ExprKind::App(func, args) => { let func = self.func_to_fixpoint(func); @@ -654,7 +723,7 @@ impl<'a> ExprCtxt<'a> { } } - fn var_to_fixpoint(&self, var: &rty::Var) -> fixpoint::Name { + fn var_to_fixpoint(&self, var: &rty::Var) -> fixpoint::LocalVar { match var { rty::Var::Free(name) => { self.env.get_fvar(*name).unwrap_or_else(|| { @@ -696,7 +765,7 @@ impl<'a> ExprCtxt<'a> { fn func_to_fixpoint(&self, func: &rty::Expr) -> fixpoint::Func { match func.kind() { - rty::ExprKind::Var(var) => fixpoint::Func::Var(self.var_to_fixpoint(var)), + rty::ExprKind::Var(var) => fixpoint::Func::Var(self.var_to_fixpoint(var).into()), rty::ExprKind::GlobalFunc(_, FuncKind::Thy(sym)) => fixpoint::Func::Itf(*sym), rty::ExprKind::GlobalFunc(sym, FuncKind::Uif) => { let cinfo = self.const_map.get(&Key::Uif(*sym)).unwrap_or_else(|| { @@ -705,7 +774,7 @@ impl<'a> ExprCtxt<'a> { "no constant found for uninterpreted function `{sym}` in `const_map`" ) }); - fixpoint::Func::Uif(cinfo.name) + fixpoint::Func::Var(cinfo.name.into()) } rty::ExprKind::GlobalFunc(sym, FuncKind::Def) => { span_bug!(self.dbg_span, "unexpected global function `{sym}`. Function must be normalized away at this point") @@ -725,9 +794,9 @@ fn qualifier_to_fixpoint( let mut env = Env::new(); env.push_layer_with_fresh_names(qualifier.body.vars().len()); - let args: Vec<(fixpoint::Name, fixpoint::Sort)> = + let args: Vec<(fixpoint::Var, fixpoint::Sort)> = iter::zip(env.last_layer(), qualifier.body.vars()) - .map(|(name, var)| (*name, sort_to_fixpoint(var.expect_sort()))) + .map(|(name, var)| ((*name).into(), sort_to_fixpoint(var.expect_sort()))) .collect(); let cx = ExprCtxt::new(&env, const_map, dbg_span); diff --git a/crates/flux-refineck/src/refine_tree.rs b/crates/flux-refineck/src/refine_tree.rs index 0ef9fdbdc1..a90b5f8778 100644 --- a/crates/flux-refineck/src/refine_tree.rs +++ b/crates/flux-refineck/src/refine_tree.rs @@ -5,7 +5,6 @@ use std::{ }; use flux_common::index::{IndexGen, IndexVec}; -use flux_fixpoint as fixpoint; use flux_middle::rty::{ box_args, evars::EVarSol, @@ -18,7 +17,7 @@ use itertools::Itertools; use crate::{ constraint_gen::Tag, - fixpoint_encoding::{sort_to_fixpoint, FixpointCtxt, TagIdx}, + fixpoint_encoding::{fixpoint, sort_to_fixpoint, stitch, FixpointCtxt}, }; /// A *refine*ment *tree* tracks the "tree-like structure" of refinement variables and predicates @@ -129,7 +128,7 @@ impl RefineTree { self.root.borrow_mut().simplify(); } - pub(crate) fn into_fixpoint(self, cx: &mut FixpointCtxt) -> fixpoint::Constraint { + pub(crate) fn into_fixpoint(self, cx: &mut FixpointCtxt) -> fixpoint::Constraint { self.root .borrow() .to_fixpoint(cx) @@ -529,7 +528,7 @@ impl Node { } } - fn to_fixpoint(&self, cx: &mut FixpointCtxt) -> Option> { + fn to_fixpoint(&self, cx: &mut FixpointCtxt) -> Option { match &self.kind { NodeKind::Comment(_) | NodeKind::Conj | NodeKind::ForAll(_, Sort::Loc) => { children_to_fixpoint(cx, &self.children) @@ -537,7 +536,7 @@ impl Node { NodeKind::ForAll(name, sort) => { cx.with_name_map(*name, |cx, fresh| { Some(fixpoint::Constraint::ForAll( - fresh, + fixpoint::Var::Local(fresh), sort_to_fixpoint(sort), fixpoint::Pred::TRUE, Box::new(children_to_fixpoint(cx, &self.children)?), @@ -588,7 +587,7 @@ impl Node { fn children_to_fixpoint( cx: &mut FixpointCtxt, children: &[NodePtr], -) -> Option> { +) -> Option { let mut children = children .iter() .filter_map(|node| node.borrow().to_fixpoint(cx)) @@ -600,15 +599,6 @@ fn children_to_fixpoint( } } -fn stitch( - bindings: Vec<(fixpoint::Name, fixpoint::Sort, fixpoint::Expr)>, - c: fixpoint::Constraint, -) -> fixpoint::Constraint { - bindings.into_iter().rev().fold(c, |c, (name, sort, e)| { - fixpoint::Constraint::ForAll(name, sort, fixpoint::Pred::Expr(e), Box::new(c)) - }) -} - struct ParentsIter { ptr: Option, }