diff --git a/crates/autodiff/src/lib.rs b/crates/autodiff/src/lib.rs index cfde52c..c20a1c9 100644 --- a/crates/autodiff/src/lib.rs +++ b/crates/autodiff/src/lib.rs @@ -177,11 +177,18 @@ impl Autodiff<'_> { self.pack(var, x, dx) } &Expr::Unary { op, arg } => match op { - // boring case + // boring cases Unop::Not => self.code.push(Instr { var, expr: Expr::Unary { op: Unop::Not, arg }, }), + Unop::IMod => self.code.push(Instr { + var, + expr: Expr::Unary { + op: Unop::IMod, + arg, + }, + }), // interesting cases Unop::Neg => { diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index bceba08..898f92c 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -192,6 +192,9 @@ pub enum Unop { // `Bool` -> `Bool` Not, + // `Fin` -> `Fin` + IMod, + // `F64` -> `F64` Neg, Abs, diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 7aa6fc8..7ec1fdd 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -218,6 +218,14 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { match op { Unop::Not => Val::Bool(!x.bool()), + Unop::IMod => { + let n = match self.typemap[self.types[self.def.vars[arg.var()].ty()].ty()] { + Ty::Fin { size } => size, + _ => unreachable!(), + }; + Val::Fin(x.fin() % n) + } + Unop::Neg => val_f64(-x.f64()), Unop::Abs => val_f64(x.f64().abs()), Unop::Sign => val_f64(x.f64().signum()), diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index ba1a901..67929d2 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -572,6 +572,7 @@ impl<'a> Transpose<'a> { match self.f.vars[var.var()] { DUAL => match op { Unop::Not + | Unop::IMod | Unop::Abs | Unop::Sign | Unop::Ceil @@ -601,7 +602,7 @@ impl<'a> Transpose<'a> { }, _ => { let x = match op { - Unop::Not => arg, + Unop::Not | Unop::IMod => arg, Unop::Neg | Unop::Abs | Unop::Sign diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 5b088f6..23cbbd5 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -737,6 +737,15 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { self.get(arg); self.wasm.instruction(&Instruction::I32Eqz); } + Unop::IMod => { + let n = match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Fin { size } => size, + _ => unreachable!(), + }; + self.get(arg); + self.wasm.instruction(&Instruction::I32Const(n as i32)); + self.wasm.instruction(&Instruction::I32RemU); + } Unop::Neg => { self.get(arg); self.wasm.instruction(&Instruction::F64Neg); diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index aa58bd0..8c209f0 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -947,12 +947,10 @@ impl FuncBuilder { pub fn num(&mut self, t: usize, x: f64) -> Result { match self.ty(t)? { Ty::F64 | Ty::T64 => Ok(self.constant(t, rose::Expr::F64 { val: x })), - &Ty::Fin { size } => { + &Ty::Fin { .. } => { let y = x as usize; if y as f64 != x { Err(JsError::new("can't be represented by an unsigned integer")) - } else if y >= size { - Err(JsError::new("out of range")) } else { Ok(self.constant(t, rose::Expr::Fin { val: y })) } @@ -1194,6 +1192,17 @@ impl Block { self.instr(f, t, expr) } + /// Return the variable ID for a new index modulus instruction on `arg`. + /// + /// Assumes `arg` is defined, in scope, and has boolean type. + pub fn imod(&mut self, f: &mut FuncBuilder, t: usize, arg: usize) -> usize { + let expr = rose::Expr::Unary { + op: rose::Unop::IMod, + arg: id::var(arg), + }; + self.instr(f, id::ty(t), expr) + } + /// Return the variable ID for a new absolute value instruction on `arg`. /// /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. diff --git a/crates/web/src/pprint.rs b/crates/web/src/pprint.rs index c616cbc..6b4bee0 100644 --- a/crates/web/src/pprint.rs +++ b/crates/web/src/pprint.rs @@ -158,6 +158,7 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> { Expr::Field { tuple, member } => writeln!(f, "&x{}[{}]", tuple.var(), member.member())?, Expr::Unary { op, arg } => match op { Unop::Not => writeln!(f, "not x{}", arg.var())?, + Unop::IMod => writeln!(f, "x{} mod T{}", arg.var(), self.def.vars[x].ty())?, Unop::Neg => writeln!(f, "-x{}", arg.var())?, Unop::Abs => writeln!(f, "|x{}|", arg.var())?, Unop::Sign => writeln!(f, "sign(x{})", arg.var())?, diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 3c0a135..dc191ef 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -930,6 +930,14 @@ export const not = (p: Bool): Bool => { return newVar(ctx.block.not(ctx.func, boolId(ctx, p))); }; +/** Return the modulus of the abstract index `i`. */ +export const imod = (ty: Nats, i: Nat): Nat => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + const j = ctx.block.imod(ctx.func, t, valId(ctx, t, i)); + return idVal(ctx, t, j) as Nat; +}; + /** Return the conjunction of the abstract booleans `p` and `q`. */ export const and = (p: Bool, q: Bool): Bool => { const ctx = getCtx(); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 77c59d7..885e3d0 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -21,6 +21,7 @@ import { igt, ileq, ilt, + imod, ineq, interp, jvp, @@ -83,10 +84,6 @@ describe("invalid", () => { ); }); - test("out of bounds index", () => { - expect(() => fn([Vec(2, Real)], Real, (v) => v[2])).toThrow("out of range"); - }); - test("access index out of scope", () => { const n = 2; expect(() => @@ -1052,4 +1049,15 @@ describe("valid", () => { expect(g(1, 1)).toBe(2); expect(g(2, 0)).toBe(2); }); + + test("index modulus", async () => { + const f = fn([], Vec(7, 3), () => { + const v = []; + for (let i = 0; i < 7; ++i) v.push(imod(3, i)); + return v; + }); + const expected = [0, 1, 2, 0, 1, 2, 0]; + expect(interp(f)()).toEqual(expected); + expect((await compile(f))()).toEqual(expected); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index e803e2d..ce0476c 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -33,6 +33,7 @@ export { igt, ileq, ilt, + imod, ineq, interp, jvp,