Skip to content

Commit

Permalink
Add very suspicious imod unop
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Jan 24, 2024
1 parent a0bb36c commit 7fd88d7
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 9 deletions.
9 changes: 8 additions & 1 deletion crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,18 @@ impl Autodiff<'_> {
self.pack(var, x, dx)
}
&Expr::Unary { op, arg } => match op {
// boring case
// boring cases
Unop::Not => self.code.push(Instr {
var,
expr: Expr::Unary { op: Unop::Not, arg },
}),
Unop::IMod => self.code.push(Instr {
var,
expr: Expr::Unary {
op: Unop::IMod,
arg,
},
}),

// interesting cases
Unop::Neg => {
Expand Down
3 changes: 3 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ pub enum Unop {
// `Bool` -> `Bool`
Not,

// `Fin` -> `Fin`
IMod,

// `F64` -> `F64`
Neg,
Abs,
Expand Down
8 changes: 8 additions & 0 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
match op {
Unop::Not => Val::Bool(!x.bool()),

Unop::IMod => {
let n = match self.typemap[self.types[self.def.vars[arg.var()].ty()].ty()] {
Ty::Fin { size } => size,
_ => unreachable!(),
};
Val::Fin(x.fin() % n)
}

Unop::Neg => val_f64(-x.f64()),
Unop::Abs => val_f64(x.f64().abs()),
Unop::Sign => val_f64(x.f64().signum()),
Expand Down
3 changes: 2 additions & 1 deletion crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ impl<'a> Transpose<'a> {
match self.f.vars[var.var()] {
DUAL => match op {
Unop::Not
| Unop::IMod
| Unop::Abs
| Unop::Sign
| Unop::Ceil
Expand Down Expand Up @@ -601,7 +602,7 @@ impl<'a> Transpose<'a> {
},
_ => {
let x = match op {
Unop::Not => arg,
Unop::Not | Unop::IMod => arg,
Unop::Neg
| Unop::Abs
| Unop::Sign
Expand Down
9 changes: 9 additions & 0 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,15 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
self.get(arg);
self.wasm.instruction(&Instruction::I32Eqz);
}
Unop::IMod => {
let n = match self.def.types[self.def.vars[instr.var.var()].ty()] {
Ty::Fin { size } => size,
_ => unreachable!(),
};
self.get(arg);
self.wasm.instruction(&Instruction::I32Const(n as i32));
self.wasm.instruction(&Instruction::I32RemU);
}
Unop::Neg => {
self.get(arg);
self.wasm.instruction(&Instruction::F64Neg);
Expand Down
15 changes: 12 additions & 3 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,12 +947,10 @@ impl FuncBuilder {
pub fn num(&mut self, t: usize, x: f64) -> Result<usize, JsError> {
match self.ty(t)? {
Ty::F64 | Ty::T64 => Ok(self.constant(t, rose::Expr::F64 { val: x })),
&Ty::Fin { size } => {
&Ty::Fin { .. } => {
let y = x as usize;
if y as f64 != x {
Err(JsError::new("can't be represented by an unsigned integer"))
} else if y >= size {
Err(JsError::new("out of range"))
} else {
Ok(self.constant(t, rose::Expr::Fin { val: y }))
}
Expand Down Expand Up @@ -1194,6 +1192,17 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new index modulus instruction on `arg`.
///
/// Assumes `arg` is defined, in scope, and has boolean type.
pub fn imod(&mut self, f: &mut FuncBuilder, t: usize, arg: usize) -> usize {
let expr = rose::Expr::Unary {
op: rose::Unop::IMod,
arg: id::var(arg),
};
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new absolute value instruction on `arg`.
///
/// Assumes `arg` is defined, in scope, and has 64-bit floating point type.
Expand Down
1 change: 1 addition & 0 deletions crates/web/src/pprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> {
Expr::Field { tuple, member } => writeln!(f, "&x{}[{}]", tuple.var(), member.member())?,
Expr::Unary { op, arg } => match op {
Unop::Not => writeln!(f, "not x{}", arg.var())?,
Unop::IMod => writeln!(f, "x{} mod T{}", arg.var(), self.def.vars[x].ty())?,
Unop::Neg => writeln!(f, "-x{}", arg.var())?,
Unop::Abs => writeln!(f, "|x{}|", arg.var())?,
Unop::Sign => writeln!(f, "sign(x{})", arg.var())?,
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,14 @@ export const not = (p: Bool): Bool => {
return newVar(ctx.block.not(ctx.func, boolId(ctx, p)));
};

/** Return the modulus of the abstract index `i`. */
export const imod = (ty: Nats, i: Nat): Nat => {
const ctx = getCtx();
const t = tyId(ctx, ty);
const j = ctx.block.imod(ctx.func, t, valId(ctx, t, i));
return idVal(ctx, t, j) as Nat;
};

/** Return the conjunction of the abstract booleans `p` and `q`. */
export const and = (p: Bool, q: Bool): Bool => {
const ctx = getCtx();
Expand Down
16 changes: 12 additions & 4 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
igt,
ileq,
ilt,
imod,
ineq,
interp,
jvp,
Expand Down Expand Up @@ -83,10 +84,6 @@ describe("invalid", () => {
);
});

test("out of bounds index", () => {
expect(() => fn([Vec(2, Real)], Real, (v) => v[2])).toThrow("out of range");
});

test("access index out of scope", () => {
const n = 2;
expect(() =>
Expand Down Expand Up @@ -1052,4 +1049,15 @@ describe("valid", () => {
expect(g(1, 1)).toBe(2);
expect(g(2, 0)).toBe(2);
});

test("index modulus", async () => {
const f = fn([], Vec(7, 3), () => {
const v = [];
for (let i = 0; i < 7; ++i) v.push(imod(3, i));
return v;
});
const expected = [0, 1, 2, 0, 1, 2, 0];
expect(interp(f)()).toEqual(expected);
expect((await compile(f))()).toEqual(expected);
});
});
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export {
igt,
ileq,
ilt,
imod,
ineq,
interp,
jvp,
Expand Down

0 comments on commit 7fd88d7

Please sign in to comment.