diff --git a/Makefile b/Makefile
index f1448b0..9abc29f 100644
--- a/Makefile
+++ b/Makefile
@@ -51,7 +51,7 @@ prettier: npm
packages: core site wasm
# run JavaScript tests
-test-js: test-core
+test-js: test-core test-site
## `packages/core`
@@ -68,9 +68,14 @@ test-core: npm wasm
site-deps: npm core
+# build
site: site-deps
npm run --workspace=@rose-lang/site build
+# test
+test-site: site-deps
+ npm run --workspace=@rose-lang/site test -- run --no-threads
+
## `packages/wasm`
# build
diff --git a/packages/site/index.html b/packages/site/index.html
index b7643c3..23d81b2 100644
--- a/packages/site/index.html
+++ b/packages/site/index.html
@@ -30,7 +30,7 @@
>
-
+
diff --git a/packages/site/package.json b/packages/site/package.json
index 2f0558e..c8cf33e 100644
--- a/packages/site/package.json
+++ b/packages/site/package.json
@@ -10,6 +10,7 @@
"scripts": {
"build": "vite build",
"dev": "vite",
- "preview": "vite preview"
+ "preview": "vite preview",
+ "test": "vitest"
}
}
diff --git a/packages/site/src/func.ts b/packages/site/src/func.ts
deleted file mode 100644
index 469aceb..0000000
--- a/packages/site/src/func.ts
+++ /dev/null
@@ -1,58 +0,0 @@
-import {
- Dual,
- Real,
- Vec,
- add,
- compile,
- div,
- fn,
- jvp,
- mul,
- opaque,
- vec,
- vjp,
-} from "rose";
-
-const log = opaque([Real], Real, Math.log);
-log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
- return { re: log(x), du: div(dx, x) };
-});
-
-const pow = opaque([Real, Real], Real, Math.pow);
-pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
- const z = pow(x, y);
- return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) };
-});
-
-const Vec2 = Vec(2, Real);
-const Mat2 = Vec(2, Vec2);
-
-const f = fn([Vec2], Real, ([x, y]) => pow(x, y));
-const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
-const h = fn([Vec2], Mat2, ([x, y]) => {
- const d = jvp(g);
- const a = d([
- { re: x, du: 1 },
- { re: y, du: 0 },
- ]);
- const b = d([
- { re: x, du: 0 },
- { re: y, du: 1 },
- ]);
- return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)];
-});
-
-type Vec2 = [number, number];
-
-export interface Info {
- val: number;
- grad: Vec2;
- hess: [Vec2, Vec2];
-}
-
-export default (await compile(
- fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => {
- const v = [x, y];
- return { val: f(v), grad: g(v), hess: h(v) };
- }),
-)) as unknown as (x: number, y: number) => Info;
diff --git a/packages/site/src/main.ts b/packages/site/src/main.ts
index e4aaf21..efb151d 100644
--- a/packages/site/src/main.ts
+++ b/packages/site/src/main.ts
@@ -1,7 +1,57 @@
-import all, { Info } from "./func.js";
+import { Real, Vec, compile, fn, jvp, vec, vjp } from "rose";
+import { Expr, parse } from "./parse.js";
type Vec2 = [number, number];
+interface Info {
+ val: number;
+ grad: Vec2;
+ hess: [Vec2, Vec2];
+}
+
+type Func = (x: number, y: number) => Info;
+
+const autodiff = async (root: Expr): Promise => {
+ const Vec2 = Vec(2, Real);
+ const f = fn([Vec2], Real, (v) => {
+ const emit = (e: Expr): Real => {
+ switch (e.kind) {
+ case "const":
+ return e.val;
+ case "var":
+ return v[e.idx];
+ case "unary":
+ return e.f(emit(e.arg));
+ case "binary":
+ return e.f(emit(e.lhs), emit(e.rhs));
+ }
+ };
+ return emit(root);
+ });
+
+ const Mat2 = Vec(2, Vec2);
+ const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
+ const h = fn([Vec2], Mat2, ([x, y]) => {
+ const d = jvp(g);
+ const a = d([
+ { re: x, du: 1 },
+ { re: y, du: 0 },
+ ]);
+ const b = d([
+ { re: x, du: 0 },
+ { re: y, du: 1 },
+ ]);
+ return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)];
+ });
+
+ return (await compile(
+ fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => {
+ const v = [x, y];
+ return { val: f(v), grad: g(v), hess: h(v) };
+ }),
+ )) as unknown as Func;
+};
+
interface Parabola {
/** coefficient of square term */
a: number;
@@ -75,7 +125,13 @@ const bezier = (
): [Vec2, Vec2, Vec2] => {
const l1 = pointSlope(parabola, x1);
const l2 = pointSlope(parabola, x2);
- const [x3, y3] = intersectPointSlope(l1, l2);
+ let [x3, y3] = intersectPointSlope(l1, l2);
+ if (!(Number.isFinite(x3) && Number.isFinite(y3))) {
+ const [x1, y1] = l1.point;
+ const [x2, y2] = l2.point;
+ x3 = (x1 + x2) / 2;
+ y3 = (y1 + y2) / 2;
+ }
return [l1.point, [x3, y3], l2.point];
};
@@ -168,15 +224,31 @@ const toWorld = ([x, y]: Vec2): Vec3 => {
return matVecMul(world, [x, y, z]);
};
-let point: Vec2;
+let func: Func;
+let point: Vec2 = [0.5, 0.5];
let info: Info;
const setPoint = (newPoint: Vec2) => {
point = newPoint;
- info = all(...point);
+ info = func(...point);
};
-setPoint([0.5, 0.5]);
+const textbox = document.getElementById("textbox") as HTMLInputElement;
+const setFunc = async () => {
+ let root: Expr = { kind: "const", val: NaN };
+ try {
+ root = parse(textbox.value);
+ textbox.classList.remove("error");
+ } catch (e) {
+ textbox.classList.add("error");
+ }
+ func = await autodiff(root);
+ setPoint(point);
+};
+await setFunc();
+textbox.addEventListener("input", async () => {
+ await setFunc();
+});
const roseColor = "#C33358";
diff --git a/packages/site/src/math.ts b/packages/site/src/math.ts
new file mode 100644
index 0000000..b16f904
--- /dev/null
+++ b/packages/site/src/math.ts
@@ -0,0 +1,142 @@
+import { Dual, Real, add, div, fn, mul, neg, opaque, sqrt, sub } from "rose";
+
+export const acos = opaque([Real], Real, Math.acos);
+export const acosh = opaque([Real], Real, Math.acosh);
+export const asin = opaque([Real], Real, Math.asin);
+export const asinh = opaque([Real], Real, Math.asinh);
+export const atan = opaque([Real], Real, Math.atan);
+export const atanh = opaque([Real], Real, Math.atanh);
+export const cbrt = opaque([Real], Real, Math.cbrt);
+export const cos = opaque([Real], Real, Math.cos);
+export const cosh = opaque([Real], Real, Math.cosh);
+export const exp = opaque([Real], Real, Math.exp);
+export const expm1 = opaque([Real], Real, Math.expm1);
+export const log = opaque([Real], Real, Math.log);
+export const log10 = opaque([Real], Real, Math.log10);
+export const log1p = opaque([Real], Real, Math.log1p);
+export const log2 = opaque([Real], Real, Math.log2);
+export const pow = opaque([Real, Real], Real, Math.pow);
+export const sin = opaque([Real], Real, Math.sin);
+export const sinh = opaque([Real], Real, Math.sinh);
+export const tan = opaque([Real], Real, Math.tan);
+export const tanh = opaque([Real], Real, Math.tanh);
+
+acos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = acos(x);
+ const dy = div(dx, neg(sqrt(sub(1, mul(x, x)))));
+ return { re: y, du: dy };
+});
+
+acosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = acosh(x);
+ const dy = div(dx, mul(sqrt(sub(x, 1)), sqrt(add(x, 1))));
+ return { re: y, du: dy };
+});
+
+asin.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }) => {
+ const y = asin(x);
+ const dy = div(dx, sqrt(sub(1, mul(x, x))));
+ return { re: y, du: dy };
+});
+
+asinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = asinh(x);
+ const dy = div(dx, sqrt(add(1, mul(x, x))));
+ return { re: y, du: dy };
+});
+
+atan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = atan(x);
+ const dy = div(dx, add(1, mul(x, x)));
+ return { re: y, du: dy };
+});
+
+atanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = atanh(x);
+ const dy = div(dx, sub(1, mul(x, x)));
+ return { re: y, du: dy };
+});
+
+cbrt.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = cbrt(x);
+ const dy = mul(dx, div(1 / 3, mul(y, y)));
+ return { re: y, du: dy };
+});
+
+cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = cos(x);
+ const dy = mul(dx, neg(sin(x)));
+ return { re: y, du: dy };
+});
+
+cosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = cosh(x);
+ const dy = mul(dx, sinh(x));
+ return { re: y, du: dy };
+});
+
+exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = exp(x);
+ const dy = mul(dx, y);
+ return { re: y, du: dy };
+});
+
+expm1.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = expm1(x);
+ const dy = mul(dx, add(y, 1));
+ return { re: y, du: dy };
+});
+
+log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = log(x);
+ const dy = div(dx, x);
+ return { re: y, du: dy };
+});
+
+log10.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = log10(x);
+ const dy = mul(dx, div(Math.LOG10E, x));
+ return { re: y, du: dy };
+});
+
+log1p.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = log1p(x);
+ const dy = div(dx, add(1, x));
+ return { re: y, du: dy };
+});
+
+log2.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = log2(x);
+ const dy = mul(dx, div(Math.LOG2E, x));
+ return { re: y, du: dy };
+});
+
+pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
+ const z = pow(x, y);
+ const dz = mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z);
+ return { re: z, du: dz };
+});
+
+sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = sin(x);
+ const dy = mul(dx, cos(x));
+ return { re: y, du: dy };
+});
+
+sinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = sinh(x);
+ const dy = mul(dx, cosh(x));
+ return { re: y, du: dy };
+});
+
+tan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = tan(x);
+ const dy = mul(dx, add(1, mul(y, y)));
+ return { re: y, du: dy };
+});
+
+tanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
+ const y = tanh(x);
+ const dy = mul(dx, sub(1, mul(y, y)));
+ return { re: y, du: dy };
+});
diff --git a/packages/site/src/parse.test.ts b/packages/site/src/parse.test.ts
new file mode 100644
index 0000000..83923a2
--- /dev/null
+++ b/packages/site/src/parse.test.ts
@@ -0,0 +1,13 @@
+import { add } from "rose";
+import { expect, test } from "vitest";
+import { Expr, parse } from "./parse.js";
+
+test("add", () => {
+ const expected: Expr = {
+ kind: "binary",
+ f: add,
+ lhs: { kind: "var", idx: 0 },
+ rhs: { kind: "var", idx: 1 },
+ };
+ expect(parse("x+y")).toEqual(expected);
+});
diff --git a/packages/site/src/parse.ts b/packages/site/src/parse.ts
new file mode 100644
index 0000000..f43dc43
--- /dev/null
+++ b/packages/site/src/parse.ts
@@ -0,0 +1,217 @@
+import {
+ Real,
+ abs,
+ add,
+ ceil,
+ div,
+ floor,
+ mul,
+ neg,
+ sign,
+ sqrt,
+ sub,
+ trunc,
+} from "rose";
+import {
+ acos,
+ acosh,
+ asin,
+ asinh,
+ atan,
+ atanh,
+ cbrt,
+ cos,
+ cosh,
+ exp,
+ expm1,
+ log,
+ log10,
+ log1p,
+ log2,
+ pow,
+ sin,
+ sinh,
+ tan,
+ tanh,
+} from "./math.js";
+
+const unaries = {
+ abs,
+ acos,
+ acosh,
+ asin,
+ asinh,
+ atan,
+ atanh,
+ cbrt,
+ ceil,
+ cos,
+ cosh,
+ exp,
+ expm1,
+ floor,
+ log,
+ log10,
+ log1p,
+ log2,
+ sign,
+ sin,
+ sinh,
+ sqrt,
+ tan,
+ tanh,
+ trunc,
+};
+
+const unary = (name: string): ((x: Real) => Real) => {
+ if (name in unaries) return unaries[name as keyof typeof unaries];
+ throw Error(`unknown unary function: ${name}`);
+};
+
+type Token =
+ | { kind: "end" }
+ | { kind: "const"; val: number }
+ | { kind: "name"; val: string }
+ | { kind: "+" }
+ | { kind: "-" }
+ | { kind: "*" }
+ | { kind: "/" }
+ | { kind: "^" }
+ | { kind: "(" }
+ | { kind: ")" };
+
+class Lexer {
+ s: string;
+
+ constructor(s: string) {
+ this.s = s;
+ }
+
+ token(): Token {
+ this.s = this.s.trimStart();
+ if (this.s.length === 0) return { kind: "end" };
+ {
+ const m = this.s.match(/^[0-9]+(?:\.[0-9]+)?\b/);
+ if (m) {
+ this.s = this.s.slice(m[0].length);
+ return { kind: "const", val: Number(m[0]) };
+ }
+ }
+ {
+ const m = this.s.match(/^[A-Z_a-z][0-9A-Z_a-z]*/);
+ if (m) {
+ this.s = this.s.slice(m[0].length);
+ return { kind: "name", val: m[0] };
+ }
+ }
+ {
+ const m = this.s.match(/^[+\-*/^()]/);
+ if (m) {
+ this.s = this.s.slice(m[0].length);
+ return { kind: m[0] as any };
+ }
+ }
+ throw Error(`can't tokenize: ${this.s}`);
+ }
+}
+
+function* lex(s: string) {
+ const lexer = new Lexer(s);
+ while (true) {
+ const tok = lexer.token();
+ yield tok;
+ if (tok.kind === "end") break;
+ }
+}
+
+export type Expr =
+ | { kind: "const"; val: number }
+ | { kind: "var"; idx: number }
+ | { kind: "unary"; f: (x: Real) => Real; arg: Expr }
+ | { kind: "binary"; f: (x: Real, y: Real) => Real; lhs: Expr; rhs: Expr };
+
+class Parser {
+ tokens: Token[];
+
+ constructor(tokens: Token[]) {
+ this.tokens = tokens;
+ }
+
+ peek(): Token {
+ return this.tokens[this.tokens.length - 1];
+ }
+
+ pop(): Token {
+ const tok = this.tokens.pop();
+ if (!tok) throw Error("unexpected end of input");
+ return tok;
+ }
+
+ parseAtom(): Expr {
+ const tok = this.pop();
+ switch (tok.kind) {
+ case "const":
+ return { kind: "const", val: tok.val };
+ case "name": {
+ if (tok.val === "x") return { kind: "var", idx: 0 };
+ if (tok.val === "y") return { kind: "var", idx: 1 };
+ const f = unary(tok.val);
+ const arg = this.parseAtom();
+ return { kind: "unary", f, arg };
+ }
+ case "(": {
+ const x = this.parseExpr();
+ const tok = this.pop();
+ if (tok.kind !== ")") throw Error("expected )");
+ return x;
+ }
+ default:
+ throw Error(`unexpected token: ${tok.kind}`);
+ }
+ }
+
+ parseFactor(): Expr {
+ if (this.peek().kind === "-") {
+ this.pop();
+ return { kind: "unary", f: neg, arg: this.parseFactor() };
+ }
+ const x = this.parseAtom();
+ if (this.peek().kind === "^") {
+ this.pop();
+ return { kind: "binary", f: pow, lhs: x, rhs: this.parseFactor() };
+ }
+ return x;
+ }
+
+ parseTerm(): Expr {
+ let x = this.parseFactor();
+ let tok = this.peek();
+ while (tok.kind === "*" || tok.kind === "/") {
+ this.pop();
+ const f = { "*": mul, "/": div }[tok.kind];
+ x = { kind: "binary", f, lhs: x, rhs: this.parseFactor() };
+ tok = this.peek();
+ }
+ return x;
+ }
+
+ parseExpr(): Expr {
+ let x = this.parseTerm();
+ let tok = this.peek();
+ while (tok.kind === "+" || tok.kind === "-") {
+ this.pop();
+ const f = { "+": add, "-": sub }[tok.kind];
+ x = { kind: "binary", f, lhs: x, rhs: this.parseTerm() };
+ tok = this.peek();
+ }
+ return x;
+ }
+}
+
+export const parse = (s: string): Expr => {
+ const parser = new Parser([...lex(s)].reverse());
+ const expr = parser.parseExpr();
+ if (parser.pop().kind !== "end") throw Error("expected end of input");
+ if (parser.tokens.length !== 0) throw Error("unexpected tokens after end");
+ return expr;
+};
diff --git a/packages/site/style.css b/packages/site/style.css
index b66099e..844054f 100644
--- a/packages/site/style.css
+++ b/packages/site/style.css
@@ -71,6 +71,10 @@ body {
outline: 2px solid var(--color-link-dark-hover);
}
+.error {
+ background-color: hsl(345, 50%, 20%);
+}
+
.bottom {
position: fixed;
bottom: 20px;