Skip to content

Commit

Permalink
Adds support for .each
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Dec 21, 2024
1 parent 70e94f5 commit 8a7947c
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 69 deletions.
63 changes: 46 additions & 17 deletions js/src/jest/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { v4 } from "uuid";

import { traceable } from "../traceable.js";
import { RunTree, RunTreeConfig } from "../run_trees.js";
import { TracerSession } from "../schemas.js";
import { KVMap, TracerSession } from "../schemas.js";
import { randomName } from "../evaluation/_random_name.js";
import { Client } from "../client.js";
import { LangSmithConflictError } from "../utils/error.js";
Expand Down Expand Up @@ -46,6 +46,28 @@ expect.extend({
toBeSemanticCloseTo,
});

const objectHash = (obj: KVMap, depth = 0): any => {
// Prevent infinite recursion
if (depth > 50) {
return "[Max Depth Exceeded]";
}

if (Array.isArray(obj)) {
return obj.map((item) => objectHash(item, depth + 1));
}

if (obj && typeof obj === "object") {
return Object.keys(obj)
.sort()
.reduce((result: KVMap, key) => {
result[key] = objectHash(obj[key], depth + 1);
return result;
}, {});
}

return crypto.createHash("sha256").update(JSON.stringify(obj)).digest("hex");
};

async function _createProject(client: Client, datasetId: string) {
// Create the project, updating the experimentName until we find a unique one.
let project: TracerSession;
Expand Down Expand Up @@ -108,14 +130,8 @@ async function runDatasetSetup(testClient: Client, datasetName: string) {
});
const examples = [];
for await (const example of examplesList) {
const inputHash = crypto
.createHash("sha256")
.update(JSON.stringify(example.inputs))
.digest("hex");
const outputHash = crypto
.createHash("sha256")
.update(JSON.stringify(example.inputs))
.digest("hex");
const inputHash = objectHash(example.inputs);
const outputHash = objectHash(example.outputs ?? {});
examples.push({ ...example, inputHash, outputHash });
}
const project = await _createProject(testClient, dataset.id);
Expand Down Expand Up @@ -180,6 +196,9 @@ function wrapTestMethod(method: (...args: any[]) => void) {
params: { inputs: I; outputs: O } | string,
config?: Partial<RunTreeConfig>
): LangSmithJestTestWrapper<I, O> {
// Due to https://github.com/jestjs/jest/issues/13653,
// we must access the local store value here before
// entering an async context
const context = jestAsyncLocalStorageInstance.getStore();
// This typing is wrong, but necessary to avoid lint errors
// eslint-disable-next-line @typescript-eslint/no-misused-promises
Expand All @@ -206,14 +225,8 @@ function wrapTestMethod(method: (...args: any[]) => void) {
typeof params === "string" ? ({} as I) : params.inputs;
const testOutput: O =
typeof params === "string" ? ({} as O) : params.outputs;
const inputHash = crypto
.createHash("sha256")
.update(JSON.stringify(testInput))
.digest("hex");
const outputHash = crypto
.createHash("sha256")
.update(JSON.stringify(testOutput))
.digest("hex");
const inputHash = objectHash(testInput);
const outputHash = objectHash(testOutput ?? {});
if (trackingEnabled()) {
const missingFields = [];
if (examples === undefined) {
Expand Down Expand Up @@ -305,9 +318,25 @@ function wrapTestMethod(method: (...args: any[]) => void) {
};
}

function eachMethod<I extends KVMap, O extends KVMap>(
table: { inputs: I; outputs: O }[]
) {
return function (
name: string,
fn: (params: { inputs: I; outputs: O }) => unknown | Promise<unknown>,
timeout?: number
) {
for (let i = 0; i < table.length; i += 1) {
const example = table[i];
wrapTestMethod(test)<I, O>(example)(`${name} ${i}`, fn, timeout);
}
};
}

const lsTest = Object.assign(wrapTestMethod(test), {
only: wrapTestMethod(test.only),
skip: wrapTestMethod(test.skip),
each: eachMethod,
});

export default {
Expand Down
52 changes: 0 additions & 52 deletions js/src/jest/matchers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,55 +145,3 @@ export async function toBeSemanticCloseTo(
: `Expected "${received}" to be semantically close to "${expected}" (threshold: ${threshold}, similarity: ${similarity})`,
};
}

// export async function toPassEvaluator(
// this: MatcherContext,
// actual: KVMap,
// evaluator: SimpleEvaluator,
// _expected?: KVMap
// ) {
// const runTree = getCurrentRunTree();
// const context = localStorage.getStore();
// if (context === undefined || context.currentExample === undefined) {
// throw new Error(
// `Could not identify example context from current context.\nPlease ensure you are calling this matcher within "ls.test()"`
// );
// }

// const wrappedEvaluator = traceable(evaluator, {
// reference_example_id: context.currentExample.id,
// metadata: {
// example_version: context.currentExample.modified_at
// ? new Date(context.currentExample.modified_at).toISOString()
// : new Date(context.currentExample.created_at).toISOString(),
// },
// client: context.client,
// tracingEnabled: true,
// });

// const evalResult = await wrappedEvaluator({
// input: runTree.inputs,
// expected: context.currentExample.outputs ?? {},
// actual,
// });

// await context.client.logEvaluationFeedback(evalResult, runTree);
// if (!("results" in evalResult) && !evalResult.score) {
// return {
// pass: false,
// message: () =>
// `expected ${this.utils.printReceived(
// actual
// )} to pass evaluator. Failed with ${JSON.stringify(
// evalResult,
// null,
// 2
// )}`,
// };
// }
// return {
// pass: true,
// message: () =>
// `evaluator passed with score ${JSON.stringify(evalResult, null, 2)}`,
// };
// }
26 changes: 26 additions & 0 deletions js/src/tests/jest.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,30 @@ ls.describe("js unit testing test demo", () => {
},
180_000
);

ls.test.each([
{
inputs: {
one: "uno",
},
outputs: {
ein: "un",
},
},
{
inputs: {
two: "dos",
},
outputs: {
zwei: "deux",
},
},
])("Does the thing", async ({ inputs: _inputs, outputs: _outputs }) => {
const myApp = () => {
return { bar: "bad" };
};
const res = myApp();
await expect(res).gradedBy(myEvaluator).not.toBeGreaterThanOrEqual(0.5);
return res;
});
});

0 comments on commit 8a7947c

Please sign in to comment.