Skip to content

Commit

Permalink
Use the nice function names for linear arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Oct 6, 2023
1 parent 3ae4bc1 commit 90ec816
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 140 deletions.
201 changes: 129 additions & 72 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,25 @@ impl FuncBuilder {
sig.push(types[def.vars[def.ret.var()].ty()].ty());
sig
}

/// Return the type IDs for `left` and `right`, checking that they are defined and in scope.
fn get_lr(&self, left: usize, right: usize) -> Result<(id::Ty, id::Ty), JsError> {
let x = self
.vars
.get(left)
.ok_or_else(|| JsError::new("left is undefined"))?;
let y = self
.vars
.get(right)
.ok_or_else(|| JsError::new("right is undefined"))?;
if let Extra::Expired = x.extra {
return Err(JsError::new("left is out of scope"));
}
if let Extra::Expired = y.extra {
return Err(JsError::new("right is out of scope"));
}
Ok((x.t, y.t))
}
}

/// A block under construction.
Expand Down Expand Up @@ -1162,8 +1181,6 @@ impl Block {
self.instr(f, id::ty(t), rose::Expr::Member { tuple, member })
}

// unary

/// Return the variable ID for a new boolean negation instruction on `arg`.
///
/// Assumes `arg` is defined, in scope, and has boolean type.
Expand All @@ -1176,18 +1193,6 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new floating-point negation instruction on `arg`.
///
/// Assumes `arg` is defined, in scope, and has 64-bit floating point type.
pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> usize {
let t = f.vars[arg].t;
let expr = rose::Expr::Unary {
op: rose::Unop::Neg,
arg: id::var(arg),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new absolute value instruction on `arg`.
///
/// Assumes `arg` is defined, in scope, and has 64-bit floating point type.
Expand Down Expand Up @@ -1260,10 +1265,6 @@ impl Block {
self.instr(f, t, expr)
}

// end of unary

// binary

/// Return the variable ID for a new logical conjunction instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have boolean type.
Expand Down Expand Up @@ -1394,60 +1395,6 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new addition instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn add(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Add,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new subtraction instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn sub(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Sub,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new multiplication instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn mul(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Mul,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new division instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn div(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Div,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

// end of binary

/// Return the variable ID for a new instruction using `cond` to choose `then` or `els`.
///
/// Assumes `cond`, `then`, and `els` are defined and in scope, that `cond` has boolean type,
Expand Down Expand Up @@ -1539,4 +1486,114 @@ impl Block {
let expr = rose::Expr::Resolve { var: id::var(var) };
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new floating-point negation instruction on `arg`.
///
/// `Err` if `arg` is undefined or out of scope, or if its type is not `F64` or `T64`.
pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> Result<usize, JsError> {
let x = f
.vars
.get(arg)
.ok_or_else(|| JsError::new("arg is undefined"))?;
if let Extra::Expired = x.extra {
return Err(JsError::new("arg is out of scope"));
}
let t = x.t;
if !(t.ty() == f.ty_f64() || t.ty() == f.ty_t64()) {
return Err(JsError::new("arg has invalid type"));
}
let expr = rose::Expr::Unary {
op: rose::Unop::Neg,
arg: id::var(arg),
};
Ok(self.instr(f, t, expr))
}

/// Return the variable ID for a new addition instruction on `left` and `right`.
///
/// `Err` if `left` or `right` is undefined or out of scope, or if their types are not either
/// both `F64` or both `T64`.
pub fn add(
&mut self,
f: &mut FuncBuilder,
left: usize,
right: usize,
) -> Result<usize, JsError> {
let (t1, t2) = f.get_lr(left, right)?;
if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2 == t1) {
return Err(JsError::new("left and right have invalid types"));
}
let expr = rose::Expr::Binary {
op: rose::Binop::Add,
left: id::var(left),
right: id::var(right),
};
Ok(self.instr(f, t1, expr))
}

/// Return the variable ID for a new subtraction instruction on `left` and `right`.
///
/// `Err` if `left` or `right` is undefined or out of scope, or if their types are not either
/// both `F64` or both `T64`.
pub fn sub(
&mut self,
f: &mut FuncBuilder,
left: usize,
right: usize,
) -> Result<usize, JsError> {
let (t1, t2) = f.get_lr(left, right)?;
if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2 == t1) {
return Err(JsError::new("left and right have invalid types"));
}
let expr = rose::Expr::Binary {
op: rose::Binop::Sub,
left: id::var(left),
right: id::var(right),
};
Ok(self.instr(f, t1, expr))
}

/// Return the variable ID for a new multiplication instruction on `left` and `right`.
///
/// `Err` if `left` or `right` is undefined or out of scope, or if `left`'s type is not `F64` or
/// `T64`, or if `right`'s type is not `F64`.
pub fn mul(
&mut self,
f: &mut FuncBuilder,
left: usize,
right: usize,
) -> Result<usize, JsError> {
let (t1, t2) = f.get_lr(left, right)?;
if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2.ty() == f.ty_f64()) {
return Err(JsError::new("left and right have invalid types"));
}
let expr = rose::Expr::Binary {
op: rose::Binop::Mul,
left: id::var(left),
right: id::var(right),
};
Ok(self.instr(f, t1, expr))
}

/// Return the variable ID for a new division instruction on `left` and `right`.
///
/// `Err` if `left` or `right` is undefined or out of scope, or if `left`'s type is not `F64` or
/// `T64`, or if `right`'s type is not `F64`.
pub fn div(
&mut self,
f: &mut FuncBuilder,
left: usize,
right: usize,
) -> Result<usize, JsError> {
let (t1, t2) = f.get_lr(left, right)?;
if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2.ty() == f.ty_f64()) {
return Err(JsError::new("left and right have invalid types"));
}
let expr = rose::Expr::Binary {
op: rose::Binop::Div,
left: id::var(left),
right: id::var(right),
};
Ok(self.instr(f, t1, expr))
}
}
7 changes: 3 additions & 4 deletions packages/core/src/impl.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
fn,
inner,
mul,
mulLin,
neg,
opaque,
vjp,
Expand Down Expand Up @@ -48,13 +47,13 @@ fn f0 = <>{

exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = exp(x);
return { re: y, du: mulLin(dx, y) };
return { re: y, du: mul(dx, y) };
});
sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
return { re: sin(x), du: mulLin(dx, cos(x)) };
return { re: sin(x), du: mul(dx, cos(x)) };
});
cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
return { re: cos(x), du: mulLin(dx, neg(sin(x))) };
return { re: cos(x), du: mul(dx, neg(sin(x))) };
});

const Complex = { re: Real, im: Real } as const;
Expand Down
Loading

0 comments on commit 90ec816

Please sign in to comment.