diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 68ac572..b65d836 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1100,6 +1100,25 @@ impl FuncBuilder { sig.push(types[def.vars[def.ret.var()].ty()].ty()); sig } + + /// Return the type IDs for `left` and `right`, checking that they are defined and in scope. + fn get_lr(&self, left: usize, right: usize) -> Result<(id::Ty, id::Ty), JsError> { + let x = self + .vars + .get(left) + .ok_or_else(|| JsError::new("left is undefined"))?; + let y = self + .vars + .get(right) + .ok_or_else(|| JsError::new("right is undefined"))?; + if let Extra::Expired = x.extra { + return Err(JsError::new("left is out of scope")); + } + if let Extra::Expired = y.extra { + return Err(JsError::new("right is out of scope")); + } + Ok((x.t, y.t)) + } } /// A block under construction. @@ -1162,8 +1181,6 @@ impl Block { self.instr(f, id::ty(t), rose::Expr::Member { tuple, member }) } - // unary - /// Return the variable ID for a new boolean negation instruction on `arg`. /// /// Assumes `arg` is defined, in scope, and has boolean type. @@ -1176,18 +1193,6 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new floating-point negation instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = f.vars[arg].t; - let expr = rose::Expr::Unary { - op: rose::Unop::Neg, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - /// Return the variable ID for a new absolute value instruction on `arg`. /// /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. @@ -1260,10 +1265,6 @@ impl Block { self.instr(f, t, expr) } - // end of unary - - // binary - /// Return the variable ID for a new logical conjunction instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have boolean type. @@ -1394,60 +1395,6 @@ impl Block { self.instr(f, t, expr) } - /// Return the variable ID for a new addition instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn add(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = f.vars[left].t; - let expr = rose::Expr::Binary { - op: rose::Binop::Add, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new subtraction instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn sub(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = f.vars[left].t; - let expr = rose::Expr::Binary { - op: rose::Binop::Sub, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new multiplication instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn mul(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = f.vars[left].t; - let expr = rose::Expr::Binary { - op: rose::Binop::Mul, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new division instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn div(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = f.vars[left].t; - let expr = rose::Expr::Binary { - op: rose::Binop::Div, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - // end of binary - /// Return the variable ID for a new instruction using `cond` to choose `then` or `els`. /// /// Assumes `cond`, `then`, and `els` are defined and in scope, that `cond` has boolean type, @@ -1539,4 +1486,114 @@ impl Block { let expr = rose::Expr::Resolve { var: id::var(var) }; self.instr(f, id::ty(t), expr) } + + /// Return the variable ID for a new floating-point negation instruction on `arg`. + /// + /// `Err` if `arg` is undefined or out of scope, or if its type is not `F64` or `T64`. + pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> Result { + let x = f + .vars + .get(arg) + .ok_or_else(|| JsError::new("arg is undefined"))?; + if let Extra::Expired = x.extra { + return Err(JsError::new("arg is out of scope")); + } + let t = x.t; + if !(t.ty() == f.ty_f64() || t.ty() == f.ty_t64()) { + return Err(JsError::new("arg has invalid type")); + } + let expr = rose::Expr::Unary { + op: rose::Unop::Neg, + arg: id::var(arg), + }; + Ok(self.instr(f, t, expr)) + } + + /// Return the variable ID for a new addition instruction on `left` and `right`. + /// + /// `Err` if `left` or `right` is undefined or out of scope, or if their types are not either + /// both `F64` or both `T64`. + pub fn add( + &mut self, + f: &mut FuncBuilder, + left: usize, + right: usize, + ) -> Result { + let (t1, t2) = f.get_lr(left, right)?; + if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2 == t1) { + return Err(JsError::new("left and right have invalid types")); + } + let expr = rose::Expr::Binary { + op: rose::Binop::Add, + left: id::var(left), + right: id::var(right), + }; + Ok(self.instr(f, t1, expr)) + } + + /// Return the variable ID for a new subtraction instruction on `left` and `right`. + /// + /// `Err` if `left` or `right` is undefined or out of scope, or if their types are not either + /// both `F64` or both `T64`. + pub fn sub( + &mut self, + f: &mut FuncBuilder, + left: usize, + right: usize, + ) -> Result { + let (t1, t2) = f.get_lr(left, right)?; + if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2 == t1) { + return Err(JsError::new("left and right have invalid types")); + } + let expr = rose::Expr::Binary { + op: rose::Binop::Sub, + left: id::var(left), + right: id::var(right), + }; + Ok(self.instr(f, t1, expr)) + } + + /// Return the variable ID for a new multiplication instruction on `left` and `right`. + /// + /// `Err` if `left` or `right` is undefined or out of scope, or if `left`'s type is not `F64` or + /// `T64`, or if `right`'s type is not `F64`. + pub fn mul( + &mut self, + f: &mut FuncBuilder, + left: usize, + right: usize, + ) -> Result { + let (t1, t2) = f.get_lr(left, right)?; + if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2.ty() == f.ty_f64()) { + return Err(JsError::new("left and right have invalid types")); + } + let expr = rose::Expr::Binary { + op: rose::Binop::Mul, + left: id::var(left), + right: id::var(right), + }; + Ok(self.instr(f, t1, expr)) + } + + /// Return the variable ID for a new division instruction on `left` and `right`. + /// + /// `Err` if `left` or `right` is undefined or out of scope, or if `left`'s type is not `F64` or + /// `T64`, or if `right`'s type is not `F64`. + pub fn div( + &mut self, + f: &mut FuncBuilder, + left: usize, + right: usize, + ) -> Result { + let (t1, t2) = f.get_lr(left, right)?; + if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2.ty() == f.ty_f64()) { + return Err(JsError::new("left and right have invalid types")); + } + let expr = rose::Expr::Binary { + op: rose::Binop::Div, + left: id::var(left), + right: id::var(right), + }; + Ok(self.instr(f, t1, expr)) + } } diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index c8192eb..f8ef761 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -8,7 +8,6 @@ import { fn, inner, mul, - mulLin, neg, opaque, vjp, @@ -48,13 +47,13 @@ fn f0 = <>{ exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { const y = exp(x); - return { re: y, du: mulLin(dx, y) }; + return { re: y, du: mul(dx, y) }; }); sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: sin(x), du: mulLin(dx, cos(x)) }; + return { re: sin(x), du: mul(dx, cos(x)) }; }); cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: cos(x), du: mulLin(dx, neg(sin(x))) }; + return { re: cos(x), du: mul(dx, neg(sin(x))) }; }); const Complex = { re: Real, im: Real } as const; diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index f421750..36f3a07 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -80,8 +80,14 @@ export type Bool = boolean | Var; /** An abstract 64-bit floating point number. */ export type Real = number | Var; +/** The zero tangent. */ +const zeroSymbol = Symbol("zero"); + +/** The zero tangent. */ +type Zero = typeof zeroSymbol; + /** An abstract 64-bit floating point tangent number. */ -export type Tan = number | Var; +export type Tan = Zero | Var; /** An abstract natural number, which can be used to index into a vector. */ type Nat = number | symbol; @@ -209,7 +215,7 @@ const valId = (ctx: Context, t: number, x: unknown): number => { } } } - } else throw Error(`invalid value: ${x}`); + } else throw Error("invalid value"); map.set(x, id); return id; @@ -540,7 +546,7 @@ const pack = (f: Fn, t: number, x: unknown): RawVal => { } return { Tuple: vals }; } - } else throw Error(`invalid value: ${x}`); + } else throw Error("invalid value"); }; /** Translate a concrete value from the interpreter's raw format. */ @@ -944,12 +950,6 @@ export const select = ( const realId = (ctx: Context, x: Real): number => valId(ctx, ctx.func.tyF64(), x); -/** Return the negative of the abstract number `x`. */ -export const neg = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.neg(ctx.func, realId(ctx, x))); -}; - /** Return the absolute value of the abstract number `x`. */ export const abs = (x: Real): Real => { const ctx = getCtx(); @@ -986,30 +986,6 @@ export const sqrt = (x: Real): Real => { return newVar(ctx.block.sqrt(ctx.func, realId(ctx, x))); }; -/** Return the abstract number `x` plus the abstract number `y`. */ -export const add = (x: Real, y: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.add(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return the abstract number `x` minus the abstract number `y`. */ -export const sub = (x: Real, y: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.sub(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return the abstract number `x` times the abstract number `y`. */ -export const mul = (x: Real, y: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.mul(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return the abstract number `x` divided by the abstract number `y`. */ -export const div = (x: Real, y: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.div(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - /** Return an abstract boolean for if `x` is not equal to `y`. */ export const neq = (x: Real, y: Real): Bool => { const ctx = getCtx(); @@ -1071,35 +1047,68 @@ export const vec = ( return idVal(ctx, t, id) as Vec>; }; -/** Return the variable ID for the abstract floating point tangent `x`. */ -const tanId = (ctx: Context, x: Tan): number => valId(ctx, ctx.func.tyT64(), x); +/** Return the variable ID for the abstract number or tangent `x`. */ +const numId = (ctx: Context, x: Real | Tan): number => { + if (typeof x === "object") return (x as any)[variable]; + let t = x === zeroSymbol ? ctx.func.tyT64() : ctx.func.tyF64(); + const map = typeMap(ctx, t); + let id = map.get(x); + if (id !== undefined) return id; // constant, so can't be out of scope + if (!(typeof x === "number")) throw Error("invalid value"); + id = ctx.func.num(t, x); + map.set(x, id); + return id; +}; -/** Return the negative of the abstract tangent `x`. */ -export const negLin = (x: Tan): Tan => { +/** Return the zero tangent. */ +export const zero = (): Tan => { const ctx = getCtx(); - return newVar(ctx.block.neg(ctx.func, tanId(ctx, x))); + const t = ctx.func.tyT64(); + typeMap(ctx, t).set(zeroSymbol, ctx.func.num(t, 0)); + return zeroSymbol; }; -/** Return the abstract tangent `x` plus the abstract tangent `y`. */ -export const addLin = (x: Tan, y: Tan): Tan => { +/** Return the negative of the abstract number `x`. */ +export const neg: { + (x: Real): Real; + (x: Tan): Tan; +} = (x: Real | Tan): Var => { + const ctx = getCtx(); + return newVar(ctx.block.neg(ctx.func, numId(ctx, x))); +}; + +/** Return the abstract number `x` plus the abstract number `y`. */ +export const add: { + (x: Real, y: Real): Real; + (x: Tan, y: Tan): Tan; +} = (x: Real | Tan, y: Real | Tan): Var => { const ctx = getCtx(); - return newVar(ctx.block.add(ctx.func, tanId(ctx, x), tanId(ctx, y))); + return newVar(ctx.block.add(ctx.func, numId(ctx, x), numId(ctx, y))); }; -/** Return the abstract tangent `x` minus the abstract tangent `y`. */ -export const subLin = (x: Tan, y: Tan): Tan => { +/** Return the abstract number `x` minus the abstract number `y`. */ +export const sub: { + (x: Real, y: Real): Real; + (x: Tan, y: Tan): Tan; +} = (x: Real | Tan, y: Real | Tan): Var => { const ctx = getCtx(); - return newVar(ctx.block.sub(ctx.func, tanId(ctx, x), tanId(ctx, y))); + return newVar(ctx.block.sub(ctx.func, numId(ctx, x), numId(ctx, y))); }; -/** Return the abstract tangent `x` times the abstract number `y`. */ -export const mulLin = (x: Tan, y: Real): Tan => { +/** Return the abstract number `x` times the abstract number `y`. */ +export const mul: { + (x: Real, y: Real): Real; + (x: Tan, y: Real): Tan; +} = (x: Real | Tan, y: Real): Var => { const ctx = getCtx(); - return newVar(ctx.block.mul(ctx.func, tanId(ctx, x), realId(ctx, y))); + return newVar(ctx.block.mul(ctx.func, numId(ctx, x), numId(ctx, y))); }; -/** Return the abstract tangent `x` divided by the abstract number `y`. */ -export const divLin = (x: Tan, y: Real): Tan => { +/** Return the abstract number `x` divided by the abstract number `y`. */ +export const div: { + (x: Real, y: Real): Real; + (x: Tan, y: Real): Tan; +} = (x: Real | Tan, y: Real): Var => { const ctx = getCtx(); - return newVar(ctx.block.div(ctx.func, tanId(ctx, x), realId(ctx, y))); + return newVar(ctx.block.div(ctx.func, numId(ctx, x), numId(ctx, y))); }; diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 51c39b5..ae901ce 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -18,7 +18,6 @@ import { interp, jvp, mul, - mulLin, neg, not, opaque, @@ -31,6 +30,7 @@ import { vec, vjp, xor, + zero, } from "./index.js"; describe("invalid", () => { @@ -54,9 +54,7 @@ describe("invalid", () => { test("add argument type", () => { const two = true as any; - expect(() => fn([], Real, () => add(two, two))).toThrow( - "did not expect boolean", - ); + expect(() => fn([], Real, () => add(two, two))).toThrow("invalid value"); }); test("invalid index type", () => { @@ -427,7 +425,7 @@ describe("valid", () => { const epsilon = 1e-5; f.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { const y = f(x); - return { re: y, du: mulLin(dx, div(1 / 2, max(epsilon, y))) }; + return { re: y, du: mul(dx, div(1 / 2, max(epsilon, y))) }; }); const g = fn([Real, Real], Real, (x, y) => vjp(f)(x).grad(y)); expect(interp(g)(0, 1)).toBeCloseTo(50000); @@ -435,7 +433,7 @@ describe("valid", () => { test("custom JVP with zero tangent", () => { const signum = opaque([Real], Real, Math.sign); - signum.jvp = fn([Dual], Dual, ({ re: x }) => ({ re: sign(x), du: 0 })); + signum.jvp = fn([Dual], Dual, ({ re: x }) => ({ re: sign(x), du: zero() })); const f = interp(jvp(signum)); expect(f({ re: 2, du: 1 })).toEqual({ re: 1, du: 0 }); }); @@ -605,10 +603,10 @@ describe("valid", () => { const cos = opaque([Real], Real, Math.cos); sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: sin(x), du: mulLin(dx, cos(x)) }; + return { re: sin(x), du: mul(dx, cos(x)) }; }); cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: cos(x), du: mulLin(dx, neg(sin(x))) }; + return { re: cos(x), du: mul(dx, neg(sin(x))) }; }); let f = sin; @@ -787,7 +785,7 @@ describe("valid", () => { const exp = opaque([Real], Real, Math.exp); exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { const y = exp(x); - return { re: y, du: mulLin(dx, y) }; + return { re: y, du: mul(dx, y) }; }); const g = fn([Real], Real, (x) => exp(x)); const h = fn([Real], Real, (x) => vjp(g)(x).ret); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index bf15ac5..5506faf 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -8,12 +8,10 @@ export { Vec, abs, add, - addLin, and, ceil, compile, div, - divLin, eq, floor, fn, @@ -25,9 +23,7 @@ export { leq, lt, mul, - mulLin, neg, - negLin, neq, not, opaque, @@ -36,9 +32,9 @@ export { sign, sqrt, sub, - subLin, trunc, vec, vjp, xor, + zero, } from "./impl.js";