Skip to content

Commit

Permalink
JS AsyncLocalStorage tracer (#442)
Browse files Browse the repository at this point in the history
@dqbd @nfcampos @hinthornw

---------

Co-authored-by: Tat Dat Duong <[email protected]>
  • Loading branch information
jacoblee93 and dqbd authored Feb 15, 2024
1 parent a9a12e9 commit ebb61b6
Show file tree
Hide file tree
Showing 9 changed files with 405 additions and 6 deletions.
3 changes: 3 additions & 0 deletions js/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Chinook_Sqlite.sql
/run_trees.cjs
/run_trees.js
/run_trees.d.ts
/traceable.cjs
/traceable.js
/traceable.d.ts
/evaluation.cjs
/evaluation.js
/evaluation.d.ts
Expand Down
10 changes: 9 additions & 1 deletion js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
"run_trees.cjs",
"run_trees.js",
"run_trees.d.ts",
"traceable.cjs",
"traceable.js",
"traceable.d.ts",
"evaluation.cjs",
"evaluation.js",
"evaluation.d.ts",
Expand Down Expand Up @@ -111,6 +114,11 @@
"import": "./run_trees.js",
"require": "./run_trees.cjs"
},
"./traceable": {
"types": "./traceable.d.ts",
"import": "./traceable.js",
"require": "./traceable.cjs"
},
"./evaluation": {
"types": "./evaluation.d.ts",
"import": "./evaluation.js",
Expand All @@ -123,4 +131,4 @@
},
"./package.json": "./package.json"
}
}
}
1 change: 1 addition & 0 deletions js/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import * as path from "path";
const entrypoints = {
client: "client",
run_trees: "run_trees",
traceable: "traceable",
evaluation: "evaluation/index",
schemas: "schemas",
};
Expand Down
10 changes: 8 additions & 2 deletions js/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
export { Client } from "./client.js";

export { Dataset, Example, TracerSession, Run, Feedback } from "./schemas.js";
export type {
Dataset,
Example,
TracerSession,
Run,
Feedback,
} from "./schemas.js";

export { RunTree, RunTreeConfig } from "./run_trees.js";
export { RunTree, type RunTreeConfig } from "./run_trees.js";

// Update using yarn bump-version
export const __version__ = "0.0.70";
8 changes: 8 additions & 0 deletions js/src/run_trees.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,11 @@ export class RunTree implements BaseRun {
await this.client.updateRun(this.id, runUpdate);
}
}

export function isRunTree(x?: unknown): x is RunTree {
return (
x !== undefined &&
typeof (x as RunTree).createChild === "function" &&
typeof (x as RunTree).postRun === "function"
);
}
168 changes: 168 additions & 0 deletions js/src/tests/traceable.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import { v4 as uuidv4 } from "uuid";
// eslint-disable-next-line import/no-extraneous-dependencies
import { FakeStreamingLLM } from "@langchain/core/utils/testing";
import { Client } from "../client.js";
import { traceable } from "../traceable.js";
import { RunTree } from "../run_trees.js";

async function deleteProject(langchainClient: Client, projectName: string) {
try {
await langchainClient.readProject({ projectName });
await langchainClient.deleteProject({ projectName });
} catch (e) {
// Pass
}
}

async function waitUntil(
condition: () => Promise<boolean>,
timeout: number,
interval: number
): Promise<void> {
const start = Date.now();
while (Date.now() - start < timeout) {
if (await condition()) {
return;
}
await new Promise((resolve) => setTimeout(resolve, interval));
}
throw new Error("Timeout");
}

async function waitUntilRunFound(
client: Client,
runId: string,
checkOutputs = false
) {
return waitUntil(
async () => {
try {
const run = await client.readRun(runId);
if (checkOutputs) {
return (
run.outputs !== null &&
run.outputs !== undefined &&
Object.keys(run.outputs).length !== 0
);
}
return true;
} catch (e) {
return false;
}
},
30_000,
5_000
);
}

test.concurrent(
"Test traceable wrapper",
async () => {
const langchainClient = new Client({
callerOptions: { maxRetries: 0 },
});
const runId = uuidv4();
const projectName = "__test_traceable_wrapper";
const addValueTraceable = traceable(
(a: string, b: number) => {
return a + b;
},
{
name: "add_value",
project_name: projectName,
client: langchainClient,
id: runId,
}
);

expect(await addValueTraceable("testing", 9)).toBe("testing9");

await waitUntilRunFound(langchainClient, runId, true);
const storedRun = await langchainClient.readRun(runId);
expect(storedRun.id).toEqual(runId);

const runId2 = uuidv4();
const nestedAddValueTraceable = traceable(
(a: string, b: number) => {
return a + b;
},
{
name: "nested_add_value",
project_name: projectName,
client: langchainClient,
}
);
const entryTraceable = traceable(
async (complex: { value: string }) => {
const result = await nestedAddValueTraceable(complex.value, 1);
const result2 = await nestedAddValueTraceable(result, 2);
await nestedAddValueTraceable(
new RunTree({
name: "root_nested_add_value",
project_name: projectName,
client: langchainClient,
}),
result,
2
);
return nestedAddValueTraceable(result2, 3);
},
{
name: "run_with_nesting",
project_name: projectName,
client: langchainClient,
id: runId2,
}
);

expect(await entryTraceable({ value: "testing" })).toBe("testing123");

await waitUntilRunFound(langchainClient, runId2, true);
const storedRun2 = await langchainClient.readRun(runId2);
expect(storedRun2.id).toEqual(runId2);

const runId3 = uuidv4();

const llm = new FakeStreamingLLM({ sleep: 0 });

const iterableTraceable = traceable(llm.stream.bind(llm), {
name: "iterable_traceable",
project_name: projectName,
client: langchainClient,
id: runId3,
});

const chunks = [];

for await (const chunk of await iterableTraceable("Hello there")) {
chunks.push(chunk);
}
expect(chunks.join("")).toBe("Hello there");
await waitUntilRunFound(langchainClient, runId3, true);
const storedRun3 = await langchainClient.readRun(runId3);
expect(storedRun3.id).toEqual(runId3);

await deleteProject(langchainClient, projectName);

async function overload(a: string, b: number): Promise<string>;
async function overload(config: { a: string; b: number }): Promise<string>;
async function overload(
...args: [a: string, b: number] | [config: { a: string; b: number }]
): Promise<string> {
if (args.length === 1) {
return args[0].a + args[0].b;
}
return args[0] + args[1];
}

const wrappedOverload = traceable(overload, {
name: "wrapped_overload",
project_name: projectName,
client: langchainClient,
});

expect(await wrappedOverload("testing", 123)).toBe("testing123");
expect(await wrappedOverload({ a: "testing", b: 456 })).toBe("testing456");
},
180_000
);
Loading

0 comments on commit ebb61b6

Please sign in to comment.