Skip to content

Commit

Permalink
Final fix for async local storage issue (hopefully)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Dec 20, 2024
1 parent c644b4d commit 70e94f5
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 107 deletions.
4 changes: 3 additions & 1 deletion js/src/jest/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ export const jestAsyncLocalStorageInstance = new AsyncLocalStorage<{
createdAt: string;
project?: TracerSession;
currentExample?: Partial<Example>;
client?: Client;
client: Client;
suiteUuid: string;
suiteName: string;
}>();

export function trackingEnabled() {
Expand Down
194 changes: 88 additions & 106 deletions js/src/jest/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable import/no-extraneous-dependencies */
/* eslint-disable @typescript-eslint/no-namespace */

import { expect, test, describe, beforeAll } from "@jest/globals";
import { expect, test, describe } from "@jest/globals";
import crypto from "crypto";
import { v4 } from "uuid";

Expand Down Expand Up @@ -80,6 +80,55 @@ export type LangSmithJestDescribeWrapper = (
config?: Partial<RunTreeConfig>
) => void;

const setupPromises = new Map();

async function runDatasetSetup(testClient: Client, datasetName: string) {
let storageValue;
if (!trackingEnabled()) {
storageValue = {
createdAt: new Date().toISOString(),
};
} else {
let dataset;
try {
dataset = await testClient.readDataset({
datasetName,
});
} catch (e: any) {
if (e.message.includes("not found")) {
dataset = await testClient.createDataset(datasetName, {
description: `Dataset for unit tests created on ${new Date().toISOString()}`,
});
} else {
throw e;
}
}
const examplesList = testClient.listExamples({
datasetName,
});
const examples = [];
for await (const example of examplesList) {
const inputHash = crypto
.createHash("sha256")
.update(JSON.stringify(example.inputs))
.digest("hex");
const outputHash = crypto
.createHash("sha256")
.update(JSON.stringify(example.inputs))
.digest("hex");
examples.push({ ...example, inputHash, outputHash });
}
const project = await _createProject(testClient, dataset.id);
storageValue = {
dataset,
examples,
project,
client: testClient,
};
}
return storageValue;
}

function wrapDescribeMethod(
method: (name: string, fn: () => void | Promise<void>) => void
): LangSmithJestDescribeWrapper {
Expand All @@ -88,102 +137,26 @@ function wrapDescribeMethod(
fn: () => void | Promise<void>,
config?: Partial<RunTreeConfig>
) {
const testClient = config?.client ?? RunTree.getSharedClient();
let storageValue;
return method(datasetName, () => {
// beforeAll(async () => {
// if (!trackingEnabled()) {
// storageValue = {
// createdAt: new Date().toISOString(),
// };
// } else {
// let dataset;
// try {
// dataset = await testClient.readDataset({
// datasetName,
// });
// } catch (e: any) {
// if (e.message.includes("not found")) {
// dataset = await testClient.createDataset(datasetName, {
// description: `Dataset for unit tests created on ${new Date().toISOString()}`,
// });
// } else {
// throw e;
// }
// }
// const examplesList = testClient.listExamples({
// datasetName,
// });
// const examples = [];
// for await (const example of examplesList) {
// const inputHash = crypto
// .createHash("sha256")
// .update(JSON.stringify(example.inputs))
// .digest("hex");
// const outputHash = crypto
// .createHash("sha256")
// .update(JSON.stringify(example.inputs))
// .digest("hex");
// examples.push({ ...example, inputHash, outputHash });
// }
// const project = await _createProject(testClient, dataset.id);
// storageValue = {
// dataset,
// examples,
// createdAt: new Date().toISOString(),
// project,
// client: testClient,
// };
// }
// jestAsyncLocalStorageInstance.enterWith(storageValue!);
// });

(async function init() {
if (!trackingEnabled()) {
storageValue = {
createdAt: new Date().toISOString(),
};
} else {
let dataset;
try {
dataset = await testClient.readDataset({
datasetName,
});
} catch (e: any) {
if (e.message.includes("not found")) {
dataset = await testClient.createDataset(datasetName, {
description: `Dataset for unit tests created on ${new Date().toISOString()}`,
});
} else {
throw e;
}
}
const examplesList = testClient.listExamples({
datasetName,
});
const examples = [];
for await (const example of examplesList) {
const inputHash = crypto
.createHash("sha256")
.update(JSON.stringify(example.inputs))
.digest("hex");
const outputHash = crypto
.createHash("sha256")
.update(JSON.stringify(example.inputs))
.digest("hex");
examples.push({ ...example, inputHash, outputHash });
}
const project = await _createProject(testClient, dataset.id);
storageValue = {
dataset,
examples,
createdAt: new Date().toISOString(),
project,
client: testClient,
};
}
jestAsyncLocalStorageInstance.enterWith(storageValue!);
})().then(fn);
const suiteUuid = v4();
/**
* We cannot rely on setting AsyncLocalStorage in beforeAll or beforeEach,
* due to https://github.com/jestjs/jest/issues/13653 and needing to use
* the janky .enterWith.
*
* We also cannot do async setup in describe due to Jest restrictions.
* However, .run works and since the below function does not contain synchronously,
* it works.
*/
void jestAsyncLocalStorageInstance.run(
{
suiteUuid,
suiteName: datasetName,
client: config?.client ?? RunTree.getSharedClient(),
createdAt: new Date().toISOString(),
},
fn
);
});
};
}
Expand All @@ -207,14 +180,28 @@ function wrapTestMethod(method: (...args: any[]) => void) {
params: { inputs: I; outputs: O } | string,
config?: Partial<RunTreeConfig>
): LangSmithJestTestWrapper<I, O> {
const context = jestAsyncLocalStorageInstance.getStore();
// This typing is wrong, but necessary to avoid lint errors
// eslint-disable-next-line @typescript-eslint/no-misused-promises
return async function (...args: any[]) {
const context = jestAsyncLocalStorageInstance.getStore();
console.log(context);
return method(
args[0],
async () => {
if (context === undefined) {
throw new Error(
`Could not retrieve test context.\nPlease make sure you have tracing enabled and you are wrapping all of your test cases in an "ls.describe()" function.`
);
}
// Because of https://github.com/jestjs/jest/issues/13653, we have to do asynchronous setup
// within the test itself
if (!setupPromises.get(context.suiteUuid)) {
setupPromises.set(
context.suiteUuid,
runDatasetSetup(context.client, context.suiteName)
);
}
const { examples, dataset, createdAt, project, client } =
await setupPromises.get(context.suiteUuid);
const testInput: I =
typeof params === "string" ? ({} as I) : params.inputs;
const testOutput: O =
Expand All @@ -227,13 +214,6 @@ function wrapTestMethod(method: (...args: any[]) => void) {
.createHash("sha256")
.update(JSON.stringify(testOutput))
.digest("hex");
const context = jestAsyncLocalStorageInstance.getStore();
if (context === undefined) {
throw new Error(
`Could not retrieve test context.\nPlease make sure you have tracing enabled and you are wrapping all of your test cases in an "ls.describe()" function.`
);
}
const { examples, dataset, createdAt, project, client } = context;
if (trackingEnabled()) {
const missingFields = [];
if (examples === undefined) {
Expand All @@ -258,7 +238,7 @@ function wrapTestMethod(method: (...args: any[]) => void) {
);
}
const testClient = config?.client ?? client!;
let example = (examples ?? []).find((example) => {
let example = (examples ?? []).find((example: any) => {
return (
example.inputHash === inputHash &&
example.outputHash === outputHash
Expand All @@ -275,6 +255,7 @@ function wrapTestMethod(method: (...args: any[]) => void) {
);
example = { ...newExample, inputHash, outputHash };
}
// .enterWith is OK here
jestAsyncLocalStorageInstance.enterWith({
...context,
currentExample: example,
Expand Down Expand Up @@ -307,6 +288,7 @@ function wrapTestMethod(method: (...args: any[]) => void) {
await (tracedFunction as any)(testInput);
await testClient.awaitPendingTraceBatches();
} else {
// .enterWith is OK here
jestAsyncLocalStorageInstance.enterWith({
...context,
currentExample: { inputs: testInput, outputs: testOutput },
Expand Down

0 comments on commit 70e94f5

Please sign in to comment.