Skip to content

Commit

Permalink
Parse math expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Nov 17, 2023
1 parent 737fe75 commit 0d1dcd6
Show file tree
Hide file tree
Showing 9 changed files with 462 additions and 66 deletions.
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ prettier: npm
packages: core site wasm

# run JavaScript tests
test-js: test-core
test-js: test-core test-site

## `packages/core`

Expand All @@ -68,9 +68,14 @@ test-core: npm wasm

site-deps: npm core

# build
site: site-deps
npm run --workspace=@rose-lang/site build

# test
test-site: site-deps
npm run --workspace=@rose-lang/site test -- run --no-threads

## `packages/wasm`

# build
Expand Down
2 changes: 1 addition & 1 deletion packages/site/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
></a>
</div>
<div class="example">
<input class="textbox" value="x^y" readonly />
<input class="textbox" value="x^y" id="textbox" />
<canvas width="300" height="300" id="canvas"></canvas>
</div>
<div></div>
Expand Down
3 changes: 2 additions & 1 deletion packages/site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"scripts": {
"build": "vite build",
"dev": "vite",
"preview": "vite preview"
"preview": "vite preview",
"test": "vitest"
}
}
58 changes: 0 additions & 58 deletions packages/site/src/func.ts

This file was deleted.

82 changes: 77 additions & 5 deletions packages/site/src/main.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,57 @@
import all, { Info } from "./func.js";
import { Real, Vec, compile, fn, jvp, vec, vjp } from "rose";
import { Expr, parse } from "./parse.js";

type Vec2 = [number, number];

interface Info {
val: number;
grad: Vec2;
hess: [Vec2, Vec2];
}

type Func = (x: number, y: number) => Info;

const autodiff = async (root: Expr): Promise<Func> => {
const Vec2 = Vec(2, Real);
const f = fn([Vec2], Real, (v) => {
const emit = (e: Expr): Real => {
switch (e.kind) {
case "const":
return e.val;
case "var":
return v[e.idx];
case "unary":
return e.f(emit(e.arg));
case "binary":
return e.f(emit(e.lhs), emit(e.rhs));
}
};
return emit(root);
});

const Mat2 = Vec(2, Vec2);
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
const h = fn([Vec2], Mat2, ([x, y]) => {
const d = jvp(g);
const a = d([
{ re: x, du: 1 },
{ re: y, du: 0 },
]);
const b = d([
{ re: x, du: 0 },
{ re: y, du: 1 },
]);
return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)];
});

return (await compile(
fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => {
const v = [x, y];
return { val: f(v), grad: g(v), hess: h(v) };
}),
)) as unknown as Func;
};

interface Parabola {
/** coefficient of square term */
a: number;
Expand Down Expand Up @@ -75,7 +125,13 @@ const bezier = (
): [Vec2, Vec2, Vec2] => {
const l1 = pointSlope(parabola, x1);
const l2 = pointSlope(parabola, x2);
const [x3, y3] = intersectPointSlope(l1, l2);
let [x3, y3] = intersectPointSlope(l1, l2);
if (!(Number.isFinite(x3) && Number.isFinite(y3))) {
const [x1, y1] = l1.point;
const [x2, y2] = l2.point;
x3 = (x1 + x2) / 2;
y3 = (y1 + y2) / 2;
}
return [l1.point, [x3, y3], l2.point];
};

Expand Down Expand Up @@ -168,15 +224,31 @@ const toWorld = ([x, y]: Vec2): Vec3 => {
return matVecMul(world, [x, y, z]);
};

let point: Vec2;
let func: Func;
let point: Vec2 = [0.5, 0.5];
let info: Info;

const setPoint = (newPoint: Vec2) => {
point = newPoint;
info = all(...point);
info = func(...point);
};

setPoint([0.5, 0.5]);
const textbox = document.getElementById("textbox") as HTMLInputElement;
const setFunc = async () => {
let root: Expr = { kind: "const", val: NaN };
try {
root = parse(textbox.value);
textbox.classList.remove("error");
} catch (e) {
textbox.classList.add("error");
}
func = await autodiff(root);
setPoint(point);
};
await setFunc();
textbox.addEventListener("input", async () => {
await setFunc();
});

const roseColor = "#C33358";

Expand Down
142 changes: 142 additions & 0 deletions packages/site/src/math.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import { Dual, Real, add, div, fn, mul, neg, opaque, sqrt, sub } from "rose";

export const acos = opaque([Real], Real, Math.acos);
export const acosh = opaque([Real], Real, Math.acosh);
export const asin = opaque([Real], Real, Math.asin);
export const asinh = opaque([Real], Real, Math.asinh);
export const atan = opaque([Real], Real, Math.atan);
export const atanh = opaque([Real], Real, Math.atanh);
export const cbrt = opaque([Real], Real, Math.cbrt);
export const cos = opaque([Real], Real, Math.cos);
export const cosh = opaque([Real], Real, Math.cosh);
export const exp = opaque([Real], Real, Math.exp);
export const expm1 = opaque([Real], Real, Math.expm1);
export const log = opaque([Real], Real, Math.log);
export const log10 = opaque([Real], Real, Math.log10);
export const log1p = opaque([Real], Real, Math.log1p);
export const log2 = opaque([Real], Real, Math.log2);
export const pow = opaque([Real, Real], Real, Math.pow);
export const sin = opaque([Real], Real, Math.sin);
export const sinh = opaque([Real], Real, Math.sinh);
export const tan = opaque([Real], Real, Math.tan);
export const tanh = opaque([Real], Real, Math.tanh);

acos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = acos(x);
const dy = div(dx, neg(sqrt(sub(1, mul(x, x)))));
return { re: y, du: dy };
});

acosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = acosh(x);
const dy = div(dx, mul(sqrt(sub(x, 1)), sqrt(add(x, 1))));
return { re: y, du: dy };
});

asin.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }) => {
const y = asin(x);
const dy = div(dx, sqrt(sub(1, mul(x, x))));
return { re: y, du: dy };
});

asinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = asinh(x);
const dy = div(dx, sqrt(add(1, mul(x, x))));
return { re: y, du: dy };
});

atan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = atan(x);
const dy = div(dx, add(1, mul(x, x)));
return { re: y, du: dy };
});

atanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = atanh(x);
const dy = div(dx, sub(1, mul(x, x)));
return { re: y, du: dy };
});

cbrt.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = cbrt(x);
const dy = mul(dx, div(1 / 3, mul(y, y)));
return { re: y, du: dy };
});

cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = cos(x);
const dy = mul(dx, neg(sin(x)));
return { re: y, du: dy };
});

cosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = cosh(x);
const dy = mul(dx, sinh(x));
return { re: y, du: dy };
});

exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = exp(x);
const dy = mul(dx, y);
return { re: y, du: dy };
});

expm1.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = expm1(x);
const dy = mul(dx, add(y, 1));
return { re: y, du: dy };
});

log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = log(x);
const dy = div(dx, x);
return { re: y, du: dy };
});

log10.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = log10(x);
const dy = mul(dx, div(Math.LOG10E, x));
return { re: y, du: dy };
});

log1p.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = log1p(x);
const dy = div(dx, add(1, x));
return { re: y, du: dy };
});

log2.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = log2(x);
const dy = mul(dx, div(Math.LOG2E, x));
return { re: y, du: dy };
});

pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
const z = pow(x, y);
const dz = mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z);
return { re: z, du: dz };
});

sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = sin(x);
const dy = mul(dx, cos(x));
return { re: y, du: dy };
});

sinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = sinh(x);
const dy = mul(dx, cosh(x));
return { re: y, du: dy };
});

tan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = tan(x);
const dy = mul(dx, add(1, mul(y, y)));
return { re: y, du: dy };
});

tanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = tanh(x);
const dy = mul(dx, sub(1, mul(y, y)));
return { re: y, du: dy };
});
13 changes: 13 additions & 0 deletions packages/site/src/parse.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { add } from "rose";
import { expect, test } from "vitest";
import { Expr, parse } from "./parse.js";

test("add", () => {
const expected: Expr = {
kind: "binary",
f: add,
lhs: { kind: "var", idx: 0 },
rhs: { kind: "var", idx: 1 },
};
expect(parse("x+y")).toEqual(expected);
});
Loading

0 comments on commit 0d1dcd6

Please sign in to comment.