diff --git a/js/src/tests/traceable.int.test.ts b/js/src/tests/traceable.int.test.ts index de1277967..83ef020a4 100644 --- a/js/src/tests/traceable.int.test.ts +++ b/js/src/tests/traceable.int.test.ts @@ -2,7 +2,11 @@ import { v4 as uuidv4 } from "uuid"; // eslint-disable-next-line import/no-extraneous-dependencies import { FakeStreamingLLM } from "@langchain/core/utils/testing"; import { Client } from "../client.js"; -import { isTraceableFunction, traceable } from "../traceable.js"; +import { + getCurrentRunTree, + isTraceableFunction, + traceable, +} from "../traceable.js"; import { RunTree } from "../run_trees.js"; async function deleteProject(langchainClient: Client, projectName: string) { @@ -170,3 +174,41 @@ test.concurrent( }, 180_000 ); + +test.concurrent("Test get run tree method", async () => { + const langchainClient = new Client({ + callerOptions: { maxRetries: 0 }, + }); + // Called outside a traceable function + expect(() => getCurrentRunTree()).toThrowError(); + const runId = uuidv4(); + const projectName = "__test_traceable_wrapper"; + const nestedAddValueTraceable = traceable( + (a: string, b: number) => { + const runTree = getCurrentRunTree(); + expect(runTree.id).toBeDefined(); + expect(runTree.id).not.toEqual(runId); + expect(runTree.dotted_order.includes(`${runId}.`)).toBe(true); + return a + b; + }, + { + name: "nested_add_value", + project_name: projectName, + client: langchainClient, + } + ); + const addValueTraceable = traceable( + (a: string, b: number) => { + const runTree = getCurrentRunTree(); + expect(runTree.id).toBe(runId); + return nestedAddValueTraceable(a, b); + }, + { + name: "add_value", + project_name: projectName, + client: langchainClient, + id: runId, + } + ); + expect(await addValueTraceable("testing", 9)).toBe("testing9"); +}); diff --git a/js/src/traceable.ts b/js/src/traceable.ts index b33dd621e..3f476d5bb 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -196,6 +196,26 @@ export function traceable any>( return traceableFunc as TraceableFunction; } +/** + * Return the current run tree from within a traceable-wrapped function. + * Will throw an error if called outside of a traceable function. + * + * @returns The run tree for the given context. + */ +export function getCurrentRunTree(): RunTree { + const runTree = asyncLocalStorage.getStore(); + if (runTree === undefined) { + throw new Error( + [ + "Could not get the current run tree.", + "", + "Please make sure you are calling this method within a traceable function.", + ].join("\n") + ); + } + return runTree; +} + export function isTraceableFunction( x: unknown // eslint-disable-next-line @typescript-eslint/no-explicit-any