From 0d1dcd6ee8ab6dd44f8f50ffd8d5121591707cef Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 17 Nov 2023 14:38:38 -0500 Subject: [PATCH] Parse math expressions --- Makefile | 7 +- packages/site/index.html | 2 +- packages/site/package.json | 3 +- packages/site/src/func.ts | 58 --------- packages/site/src/main.ts | 82 +++++++++++- packages/site/src/math.ts | 142 +++++++++++++++++++++ packages/site/src/parse.test.ts | 13 ++ packages/site/src/parse.ts | 217 ++++++++++++++++++++++++++++++++ packages/site/style.css | 4 + 9 files changed, 462 insertions(+), 66 deletions(-) delete mode 100644 packages/site/src/func.ts create mode 100644 packages/site/src/math.ts create mode 100644 packages/site/src/parse.test.ts create mode 100644 packages/site/src/parse.ts diff --git a/Makefile b/Makefile index f1448b0..9abc29f 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ prettier: npm packages: core site wasm # run JavaScript tests -test-js: test-core +test-js: test-core test-site ## `packages/core` @@ -68,9 +68,14 @@ test-core: npm wasm site-deps: npm core +# build site: site-deps npm run --workspace=@rose-lang/site build +# test +test-site: site-deps + npm run --workspace=@rose-lang/site test -- run --no-threads + ## `packages/wasm` # build diff --git a/packages/site/index.html b/packages/site/index.html index b7643c3..23d81b2 100644 --- a/packages/site/index.html +++ b/packages/site/index.html @@ -30,7 +30,7 @@ >
- +
diff --git a/packages/site/package.json b/packages/site/package.json index 2f0558e..c8cf33e 100644 --- a/packages/site/package.json +++ b/packages/site/package.json @@ -10,6 +10,7 @@ "scripts": { "build": "vite build", "dev": "vite", - "preview": "vite preview" + "preview": "vite preview", + "test": "vitest" } } diff --git a/packages/site/src/func.ts b/packages/site/src/func.ts deleted file mode 100644 index 469aceb..0000000 --- a/packages/site/src/func.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { - Dual, - Real, - Vec, - add, - compile, - div, - fn, - jvp, - mul, - opaque, - vec, - vjp, -} from "rose"; - -const log = opaque([Real], Real, Math.log); -log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: log(x), du: div(dx, x) }; -}); - -const pow = opaque([Real, Real], Real, Math.pow); -pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => { - const z = pow(x, y); - return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) }; -}); - -const Vec2 = Vec(2, Real); -const Mat2 = Vec(2, Vec2); - -const f = fn([Vec2], Real, ([x, y]) => pow(x, y)); -const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1)); -const h = fn([Vec2], Mat2, ([x, y]) => { - const d = jvp(g); - const a = d([ - { re: x, du: 1 }, - { re: y, du: 0 }, - ]); - const b = d([ - { re: x, du: 0 }, - { re: y, du: 1 }, - ]); - return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)]; -}); - -type Vec2 = [number, number]; - -export interface Info { - val: number; - grad: Vec2; - hess: [Vec2, Vec2]; -} - -export default (await compile( - fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => { - const v = [x, y]; - return { val: f(v), grad: g(v), hess: h(v) }; - }), -)) as unknown as (x: number, y: number) => Info; diff --git a/packages/site/src/main.ts b/packages/site/src/main.ts index e4aaf21..efb151d 100644 --- a/packages/site/src/main.ts +++ b/packages/site/src/main.ts @@ -1,7 +1,57 @@ -import all, { Info } from "./func.js"; +import { Real, Vec, compile, fn, jvp, vec, vjp } from "rose"; +import { Expr, parse } from "./parse.js"; type Vec2 = [number, number]; +interface Info { + val: number; + grad: Vec2; + hess: [Vec2, Vec2]; +} + +type Func = (x: number, y: number) => Info; + +const autodiff = async (root: Expr): Promise => { + const Vec2 = Vec(2, Real); + const f = fn([Vec2], Real, (v) => { + const emit = (e: Expr): Real => { + switch (e.kind) { + case "const": + return e.val; + case "var": + return v[e.idx]; + case "unary": + return e.f(emit(e.arg)); + case "binary": + return e.f(emit(e.lhs), emit(e.rhs)); + } + }; + return emit(root); + }); + + const Mat2 = Vec(2, Vec2); + const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1)); + const h = fn([Vec2], Mat2, ([x, y]) => { + const d = jvp(g); + const a = d([ + { re: x, du: 1 }, + { re: y, du: 0 }, + ]); + const b = d([ + { re: x, du: 0 }, + { re: y, du: 1 }, + ]); + return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)]; + }); + + return (await compile( + fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => { + const v = [x, y]; + return { val: f(v), grad: g(v), hess: h(v) }; + }), + )) as unknown as Func; +}; + interface Parabola { /** coefficient of square term */ a: number; @@ -75,7 +125,13 @@ const bezier = ( ): [Vec2, Vec2, Vec2] => { const l1 = pointSlope(parabola, x1); const l2 = pointSlope(parabola, x2); - const [x3, y3] = intersectPointSlope(l1, l2); + let [x3, y3] = intersectPointSlope(l1, l2); + if (!(Number.isFinite(x3) && Number.isFinite(y3))) { + const [x1, y1] = l1.point; + const [x2, y2] = l2.point; + x3 = (x1 + x2) / 2; + y3 = (y1 + y2) / 2; + } return [l1.point, [x3, y3], l2.point]; }; @@ -168,15 +224,31 @@ const toWorld = ([x, y]: Vec2): Vec3 => { return matVecMul(world, [x, y, z]); }; -let point: Vec2; +let func: Func; +let point: Vec2 = [0.5, 0.5]; let info: Info; const setPoint = (newPoint: Vec2) => { point = newPoint; - info = all(...point); + info = func(...point); }; -setPoint([0.5, 0.5]); +const textbox = document.getElementById("textbox") as HTMLInputElement; +const setFunc = async () => { + let root: Expr = { kind: "const", val: NaN }; + try { + root = parse(textbox.value); + textbox.classList.remove("error"); + } catch (e) { + textbox.classList.add("error"); + } + func = await autodiff(root); + setPoint(point); +}; +await setFunc(); +textbox.addEventListener("input", async () => { + await setFunc(); +}); const roseColor = "#C33358"; diff --git a/packages/site/src/math.ts b/packages/site/src/math.ts new file mode 100644 index 0000000..b16f904 --- /dev/null +++ b/packages/site/src/math.ts @@ -0,0 +1,142 @@ +import { Dual, Real, add, div, fn, mul, neg, opaque, sqrt, sub } from "rose"; + +export const acos = opaque([Real], Real, Math.acos); +export const acosh = opaque([Real], Real, Math.acosh); +export const asin = opaque([Real], Real, Math.asin); +export const asinh = opaque([Real], Real, Math.asinh); +export const atan = opaque([Real], Real, Math.atan); +export const atanh = opaque([Real], Real, Math.atanh); +export const cbrt = opaque([Real], Real, Math.cbrt); +export const cos = opaque([Real], Real, Math.cos); +export const cosh = opaque([Real], Real, Math.cosh); +export const exp = opaque([Real], Real, Math.exp); +export const expm1 = opaque([Real], Real, Math.expm1); +export const log = opaque([Real], Real, Math.log); +export const log10 = opaque([Real], Real, Math.log10); +export const log1p = opaque([Real], Real, Math.log1p); +export const log2 = opaque([Real], Real, Math.log2); +export const pow = opaque([Real, Real], Real, Math.pow); +export const sin = opaque([Real], Real, Math.sin); +export const sinh = opaque([Real], Real, Math.sinh); +export const tan = opaque([Real], Real, Math.tan); +export const tanh = opaque([Real], Real, Math.tanh); + +acos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = acos(x); + const dy = div(dx, neg(sqrt(sub(1, mul(x, x))))); + return { re: y, du: dy }; +}); + +acosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = acosh(x); + const dy = div(dx, mul(sqrt(sub(x, 1)), sqrt(add(x, 1)))); + return { re: y, du: dy }; +}); + +asin.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }) => { + const y = asin(x); + const dy = div(dx, sqrt(sub(1, mul(x, x)))); + return { re: y, du: dy }; +}); + +asinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = asinh(x); + const dy = div(dx, sqrt(add(1, mul(x, x)))); + return { re: y, du: dy }; +}); + +atan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = atan(x); + const dy = div(dx, add(1, mul(x, x))); + return { re: y, du: dy }; +}); + +atanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = atanh(x); + const dy = div(dx, sub(1, mul(x, x))); + return { re: y, du: dy }; +}); + +cbrt.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = cbrt(x); + const dy = mul(dx, div(1 / 3, mul(y, y))); + return { re: y, du: dy }; +}); + +cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = cos(x); + const dy = mul(dx, neg(sin(x))); + return { re: y, du: dy }; +}); + +cosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = cosh(x); + const dy = mul(dx, sinh(x)); + return { re: y, du: dy }; +}); + +exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = exp(x); + const dy = mul(dx, y); + return { re: y, du: dy }; +}); + +expm1.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = expm1(x); + const dy = mul(dx, add(y, 1)); + return { re: y, du: dy }; +}); + +log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = log(x); + const dy = div(dx, x); + return { re: y, du: dy }; +}); + +log10.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = log10(x); + const dy = mul(dx, div(Math.LOG10E, x)); + return { re: y, du: dy }; +}); + +log1p.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = log1p(x); + const dy = div(dx, add(1, x)); + return { re: y, du: dy }; +}); + +log2.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = log2(x); + const dy = mul(dx, div(Math.LOG2E, x)); + return { re: y, du: dy }; +}); + +pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => { + const z = pow(x, y); + const dz = mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z); + return { re: z, du: dz }; +}); + +sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = sin(x); + const dy = mul(dx, cos(x)); + return { re: y, du: dy }; +}); + +sinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = sinh(x); + const dy = mul(dx, cosh(x)); + return { re: y, du: dy }; +}); + +tan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = tan(x); + const dy = mul(dx, add(1, mul(y, y))); + return { re: y, du: dy }; +}); + +tanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = tanh(x); + const dy = mul(dx, sub(1, mul(y, y))); + return { re: y, du: dy }; +}); diff --git a/packages/site/src/parse.test.ts b/packages/site/src/parse.test.ts new file mode 100644 index 0000000..83923a2 --- /dev/null +++ b/packages/site/src/parse.test.ts @@ -0,0 +1,13 @@ +import { add } from "rose"; +import { expect, test } from "vitest"; +import { Expr, parse } from "./parse.js"; + +test("add", () => { + const expected: Expr = { + kind: "binary", + f: add, + lhs: { kind: "var", idx: 0 }, + rhs: { kind: "var", idx: 1 }, + }; + expect(parse("x+y")).toEqual(expected); +}); diff --git a/packages/site/src/parse.ts b/packages/site/src/parse.ts new file mode 100644 index 0000000..f43dc43 --- /dev/null +++ b/packages/site/src/parse.ts @@ -0,0 +1,217 @@ +import { + Real, + abs, + add, + ceil, + div, + floor, + mul, + neg, + sign, + sqrt, + sub, + trunc, +} from "rose"; +import { + acos, + acosh, + asin, + asinh, + atan, + atanh, + cbrt, + cos, + cosh, + exp, + expm1, + log, + log10, + log1p, + log2, + pow, + sin, + sinh, + tan, + tanh, +} from "./math.js"; + +const unaries = { + abs, + acos, + acosh, + asin, + asinh, + atan, + atanh, + cbrt, + ceil, + cos, + cosh, + exp, + expm1, + floor, + log, + log10, + log1p, + log2, + sign, + sin, + sinh, + sqrt, + tan, + tanh, + trunc, +}; + +const unary = (name: string): ((x: Real) => Real) => { + if (name in unaries) return unaries[name as keyof typeof unaries]; + throw Error(`unknown unary function: ${name}`); +}; + +type Token = + | { kind: "end" } + | { kind: "const"; val: number } + | { kind: "name"; val: string } + | { kind: "+" } + | { kind: "-" } + | { kind: "*" } + | { kind: "/" } + | { kind: "^" } + | { kind: "(" } + | { kind: ")" }; + +class Lexer { + s: string; + + constructor(s: string) { + this.s = s; + } + + token(): Token { + this.s = this.s.trimStart(); + if (this.s.length === 0) return { kind: "end" }; + { + const m = this.s.match(/^[0-9]+(?:\.[0-9]+)?\b/); + if (m) { + this.s = this.s.slice(m[0].length); + return { kind: "const", val: Number(m[0]) }; + } + } + { + const m = this.s.match(/^[A-Z_a-z][0-9A-Z_a-z]*/); + if (m) { + this.s = this.s.slice(m[0].length); + return { kind: "name", val: m[0] }; + } + } + { + const m = this.s.match(/^[+\-*/^()]/); + if (m) { + this.s = this.s.slice(m[0].length); + return { kind: m[0] as any }; + } + } + throw Error(`can't tokenize: ${this.s}`); + } +} + +function* lex(s: string) { + const lexer = new Lexer(s); + while (true) { + const tok = lexer.token(); + yield tok; + if (tok.kind === "end") break; + } +} + +export type Expr = + | { kind: "const"; val: number } + | { kind: "var"; idx: number } + | { kind: "unary"; f: (x: Real) => Real; arg: Expr } + | { kind: "binary"; f: (x: Real, y: Real) => Real; lhs: Expr; rhs: Expr }; + +class Parser { + tokens: Token[]; + + constructor(tokens: Token[]) { + this.tokens = tokens; + } + + peek(): Token { + return this.tokens[this.tokens.length - 1]; + } + + pop(): Token { + const tok = this.tokens.pop(); + if (!tok) throw Error("unexpected end of input"); + return tok; + } + + parseAtom(): Expr { + const tok = this.pop(); + switch (tok.kind) { + case "const": + return { kind: "const", val: tok.val }; + case "name": { + if (tok.val === "x") return { kind: "var", idx: 0 }; + if (tok.val === "y") return { kind: "var", idx: 1 }; + const f = unary(tok.val); + const arg = this.parseAtom(); + return { kind: "unary", f, arg }; + } + case "(": { + const x = this.parseExpr(); + const tok = this.pop(); + if (tok.kind !== ")") throw Error("expected )"); + return x; + } + default: + throw Error(`unexpected token: ${tok.kind}`); + } + } + + parseFactor(): Expr { + if (this.peek().kind === "-") { + this.pop(); + return { kind: "unary", f: neg, arg: this.parseFactor() }; + } + const x = this.parseAtom(); + if (this.peek().kind === "^") { + this.pop(); + return { kind: "binary", f: pow, lhs: x, rhs: this.parseFactor() }; + } + return x; + } + + parseTerm(): Expr { + let x = this.parseFactor(); + let tok = this.peek(); + while (tok.kind === "*" || tok.kind === "/") { + this.pop(); + const f = { "*": mul, "/": div }[tok.kind]; + x = { kind: "binary", f, lhs: x, rhs: this.parseFactor() }; + tok = this.peek(); + } + return x; + } + + parseExpr(): Expr { + let x = this.parseTerm(); + let tok = this.peek(); + while (tok.kind === "+" || tok.kind === "-") { + this.pop(); + const f = { "+": add, "-": sub }[tok.kind]; + x = { kind: "binary", f, lhs: x, rhs: this.parseTerm() }; + tok = this.peek(); + } + return x; + } +} + +export const parse = (s: string): Expr => { + const parser = new Parser([...lex(s)].reverse()); + const expr = parser.parseExpr(); + if (parser.pop().kind !== "end") throw Error("expected end of input"); + if (parser.tokens.length !== 0) throw Error("unexpected tokens after end"); + return expr; +}; diff --git a/packages/site/style.css b/packages/site/style.css index b66099e..844054f 100644 --- a/packages/site/style.css +++ b/packages/site/style.css @@ -71,6 +71,10 @@ body { outline: 2px solid var(--color-link-dark-hover); } +.error { + background-color: hsl(345, 50%, 20%); +} + .bottom { position: fixed; bottom: 20px;