Skip to content

Commit

Permalink
Draw some arrows
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Nov 16, 2023
1 parent d0f7211 commit c95eefa
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 53 deletions.
50 changes: 50 additions & 0 deletions packages/site/src/func.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import {
Dual,
Real,
Vec,
add,
compile,
div,
fn,
jvp,
mul,
opaque,
vec,
vjp,
} from "rose";

const log = opaque([Real], Real, Math.log);
log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
return { re: log(x), du: div(dx, x) };
});

const pow = opaque([Real, Real], Real, Math.pow);
pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
const z = pow(x, y);
return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) };
});

const Vec2 = Vec(2, Real);
const Mat2 = Vec(2, Vec2);

const f = fn([Vec2], Real, ([x, y]) => pow(x, y));
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
const h = fn([Vec2], Mat2, ([x, y]) => {
const d = jvp(g);
const a = d([
{ re: x, du: 1 },
{ re: y, du: 0 },
]);
const b = d([
{ re: x, du: 0 },
{ re: y, du: 1 },
]);
return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)];
});

export default await compile(
fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => {
const v = [x, y];
return { val: f(v), grad: g(v), hess: h(v) };
}),
);
150 changes: 97 additions & 53 deletions packages/site/src/main.ts
Original file line number Diff line number Diff line change
@@ -1,70 +1,114 @@
import {
Dual,
Real,
Vec,
add,
compile,
div,
fn,
jvp,
mul,
opaque,
vec,
vjp,
} from "rose";

const log = opaque([Real], Real, Math.log);
log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
return { re: log(x), du: div(dx, x) };
});

const pow = opaque([Real, Real], Real, Math.pow);
pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
const z = pow(x, y);
return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) };
});

const Vec2 = Vec(2, Real);
const Mat2 = Vec(2, Vec2);

const f = fn([Vec2], Real, ([x, y]) => pow(x, y));
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
const h = fn([Vec2], Mat2, ([x, y]) => {
const d = jvp(g);
const a = d([
{ re: x, du: 1 },
{ re: y, du: 0 },
]);
const b = d([
{ re: x, du: 0 },
{ re: y, du: 1 },
]);
return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)];
});

const all = await compile(
fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => {
const v = [x, y];
return { val: f(v), grad: g(v), hess: h(v) };
}),
);
import all from "./func.js";

console.log(all(2, 3));

type Vec2 = [number, number];
type Vec3 = number[];

type Mat3x3 = Vec3[];

const matVecMul = (a: Mat3x3, b: Vec3): Vec3 => {
const c: Vec3 = [0, 0, 0];
for (let i = 0; i < 3; ++i) {
for (let j = 0; j < 3; ++j) {
c[i] += a[i][j] * b[j];
}
}
return c;
};

const matMul = (a: Mat3x3, b: Mat3x3): Mat3x3 => {
const c: Mat3x3 = [
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
];
for (let i = 0; i < 3; ++i) {
for (let j = 0; j < 3; ++j) {
for (let k = 0; k < 3; ++k) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
return c;
};

const identity: Mat3x3 = [
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
];

const scale = (s: number): Mat3x3 => [
[s, 0, 0],
[0, s, 0],
[0, 0, s],
];

const rotateX = (theta: number): Mat3x3 => [
[1, 0, 0],
[0, Math.cos(theta), -Math.sin(theta)],
[0, Math.sin(theta), Math.cos(theta)],
];

const rotateZ = (theta: number): Mat3x3 => [
[Math.cos(theta), -Math.sin(theta), 0],
[Math.sin(theta), Math.cos(theta), 0],
[0, 0, 1],
];

const screen = [scale(150), rotateX(-1), rotateZ((-3 * Math.PI) / 4)].reduce(
matMul,
);

const toScreen = (v: Vec3): Vec2 => {
const [x, y] = matVecMul(screen, v);
return [x, y];
};

const canvas = document.getElementById("canvas") as HTMLCanvasElement;
const { width, height } = canvas;
const ctx = canvas.getContext("2d")!;

const poly = (points: Vec3[]) => {
ctx.moveTo(...toScreen(points[0]));
for (let i = 1; i < points.length; ++i) {
ctx.lineTo(...toScreen(points[i]));
}
};

const lineHalf = 0.015;

const arrowLen = 0.1;
const arrowHalf = 0.05;

const draw: FrameRequestCallback = (milliseconds) => {
ctx.resetTransform();
ctx.clearRect(0, 0, width, height);

ctx.translate(width / 2, height / 2);
ctx.scale(1, -1);
ctx.rotate(milliseconds / 1000);

ctx.fillStyle = "#c7254e";
ctx.fillRect(-100, -100, 200, 200);
const pulse = Math.sin(milliseconds / 1000) / 10;

ctx.fillStyle = "white";
ctx.beginPath();
poly([
[1 + pulse + arrowLen, 0, 0],
[1 + pulse, arrowHalf, 0],
[1 + pulse, lineHalf, 0],
[lineHalf, lineHalf, 0],
[lineHalf, 1 + pulse, 0],
[arrowHalf, 1 + pulse, 0],
[0, 1 + pulse + arrowLen, 0],
[-arrowHalf, 1 + pulse, 0],
[-lineHalf, 1 + pulse, 0],
[-lineHalf, -lineHalf, 0],
[1 + pulse, -lineHalf, 0],
[1 + pulse, -arrowHalf, 0],
]);
ctx.closePath();
ctx.fill();

window.requestAnimationFrame(draw);
};
Expand Down

0 comments on commit c95eefa

Please sign in to comment.