Skip to content

Commit

Permalink
Make vector and struct types easier to use
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Oct 10, 2023
1 parent 18a2092 commit 659a0b1
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
3 changes: 2 additions & 1 deletion packages/core/src/impl.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
mul,
neg,
opaque,
struct,
vjp,
} from "./impl.js";

Expand Down Expand Up @@ -56,7 +57,7 @@ fn f0 = <>{
return { re: cos(x), du: mul(dx, neg(sin(x))) };
});

const Complex = { re: Real, im: Real } as const;
const Complex = struct({ re: Real, im: Real });

const complexp = fn([Complex], Complex, (z) => {
const c = exp(z.re);
Expand Down
7 changes: 5 additions & 2 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ const ind = Symbol("index");
const elm = Symbol("elem");

/** Representation of a vector type. */
interface Vecs<K, V> {
export interface Vecs<K, V> {
[ind]: K;
[elm]: V;
}
Expand All @@ -274,8 +274,11 @@ export const Vec = <K, V>(index: K, elem: V): Vecs<K, V> => {
return { [ind]: index, [elm]: elem };
};

/** Create a struct type. */
export const struct = <const T>(t: T): T => t;

/** The 128-bit floating-point dual number type. */
export const Dual = { re: Real, du: Tan } as const;
export const Dual = struct({ re: Real, du: Tan });

// TODO: make this locale-independent
const compare = (a: string, b: string): number => a.localeCompare(b);
Expand Down
25 changes: 13 additions & 12 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
select,
sign,
sqrt,
struct,
sub,
trunc,
vec,
Expand Down Expand Up @@ -341,7 +342,7 @@ describe("valid", () => {
});

test("struct", () => {
const Pair = { x: Real, y: Real } as const;
const Pair = struct({ x: Real, y: Real });
const f = fn([Pair], Real, (p) => sub(p.y, p.x));
const g = fn([Real, Real], Pair, (x, y) => ({ y, x }));
const h = interp(fn([Real, Real], Real, (x, y) => f(g(x, y))));
Expand All @@ -364,7 +365,7 @@ describe("valid", () => {

test("array of structs", () => {
const n = 2;
const Indexed = { i: n, x: Real } as const;
const Indexed = struct({ i: n, x: Real });
const f = fn([Vec(n, Real)], Vec(n, Indexed), (v) =>
vec(n, Indexed, (i) => ({ i, x: v[i] })),
);
Expand Down Expand Up @@ -480,7 +481,7 @@ describe("valid", () => {
});

test("VJP with struct and select", () => {
const Stuff = { a: Null, b: Bool, c: Real } as const;
const Stuff = struct({ a: Null, b: Bool, c: Real });
const f = fn([Stuff], Real, ({ b, c }) =>
select(or(false, not(b)), Real, c, 2),
);
Expand Down Expand Up @@ -574,15 +575,15 @@ describe("valid", () => {
});

test("VJP twice with struct", () => {
const Pair = { x: Real, y: Real } as const;
const Pair = struct({ x: Real, y: Real });
const f = fn([Pair], Real, ({ x, y }) => mul(x, y));
const g = fn([Pair], Pair, (p) => vjp(f)(p).grad(1));
const h = fn([Pair, Pair], Pair, (p, q) => vjp(g)(p).grad(q));
expect(interp(h)({ x: 2, y: 3 }, { x: 5, y: 7 })).toEqual({ x: 7, y: 5 });
});

test("VJP twice with select", () => {
const Stuff = { p: Bool, x: Real, y: Real, z: Real } as const;
const Stuff = struct({ p: Bool, x: Real, y: Real, z: Real });
const f = fn([Stuff], Real, ({ p, x, y, z }) =>
mul(z, select(p, Real, x, y)),
);
Expand Down Expand Up @@ -744,7 +745,7 @@ describe("valid", () => {

test("compile VJP", async () => {
const f = fn(
[Vec(2, { p: Bool, x: Real } as const)],
[Vec(2, struct({ p: Bool, x: Real }))],
{ p: Vec(2, Bool), x: Vec(2, Real) },
(v) => ({
p: vec(2, Bool, (i) => not(v[i].p)),
Expand Down Expand Up @@ -821,14 +822,14 @@ describe("valid", () => {
});

test("compile structs in signature", async () => {
const Pair = { x: Real, y: Real } as const;
const Pair = struct({ x: Real, y: Real });
const f = fn([Pair], Pair, ({ x, y }) => ({ x: y, y: x }));
const g = await compile(f);
expect(g({ x: 2, y: 3 })).toEqual({ x: 3, y: 2 });
});

test("compile zero-sized struct members in signature", async () => {
const Stuff = { a: Null, b: 0, c: 0, d: Null } as const;
const Stuff = struct({ a: Null, b: 0, c: 0, d: Null });
const f = fn([Stuff], Stuff, ({ a, b, c, d }) => {
return { a: d, b: c, c: b, d: a };
});
Expand All @@ -838,8 +839,8 @@ describe("valid", () => {
});

test("compile nested structs in signature", async () => {
const Pair = { x: Real, y: Real } as const;
const Stuff = { p: Bool, q: Pair } as const;
const Pair = struct({ x: Real, y: Real });
const Stuff = struct({ p: Bool, q: Pair });
const f = fn([Stuff], Stuff, ({ p, q }) => ({
p: not(p),
q: { x: q.y, y: q.x },
Expand All @@ -854,7 +855,7 @@ describe("valid", () => {
test("compile big structs in signature", async () => {
const M = 300;
const N = 70000;
const Stuff = {
const Stuff = struct({
a: Real,
b: N,
c: Real,
Expand All @@ -876,7 +877,7 @@ describe("valid", () => {
s: Real,
t: Real,
u: Real,
} as const;
});
const f = fn(
[Stuff],
Stuff,
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export {
Real,
Tan,
Vec,
Vecs,
abs,
add,
and,
Expand All @@ -31,6 +32,7 @@ export {
select,
sign,
sqrt,
struct,
sub,
trunc,
vec,
Expand Down

0 comments on commit 659a0b1

Please sign in to comment.