From e1521a5f9464f921853b5aa0654e4f21e9704326 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Fri, 11 Oct 2024 13:32:18 -0700 Subject: [PATCH] feat(js): Adds further required support for context vars (#1093) --- js/package.json | 2 +- js/src/index.ts | 2 +- js/src/run_trees.ts | 15 ++++++- js/src/singletons/constants.ts | 1 + js/src/tests/traceable.test.ts | 76 ++++++++++++++++++++++++++++++++++ js/src/traceable.ts | 12 ++++++ 6 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 js/src/singletons/constants.ts diff --git a/js/package.json b/js/package.json index e2cbcc447..7a58f2e2d 100644 --- a/js/package.json +++ b/js/package.json @@ -1,6 +1,6 @@ { "name": "langsmith", - "version": "0.1.64", + "version": "0.1.65", "description": "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform.", "packageManager": "yarn@1.22.19", "files": [ diff --git a/js/src/index.ts b/js/src/index.ts index a23fc1732..3d8bba44a 100644 --- a/js/src/index.ts +++ b/js/src/index.ts @@ -14,4 +14,4 @@ export { RunTree, type RunTreeConfig } from "./run_trees.js"; export { overrideFetchImplementation } from "./singletons/fetch.js"; // Update using yarn bump-version -export const __version__ = "0.1.64"; +export const __version__ = "0.1.65"; diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 0aa3bc6a0..c8e8091ab 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -8,6 +8,7 @@ import { import { Client } from "./client.js"; import { isTracingEnabled } from "./env.js"; import { warnOnce } from "./utils/warn.js"; +import { _LC_CONTEXT_VARIABLES_KEY } from "./singletons/constants.js"; function stripNonAlphanumeric(input: string) { return input.replace(/[-:.]/g, ""); @@ -172,7 +173,12 @@ export class RunTree implements BaseRun { execution_order: number; child_execution_order: number; - constructor(originalConfig: RunTreeConfig) { + constructor(originalConfig: RunTreeConfig | RunTree) { + // If you pass in a run tree directly, return a shallow clone + if (isRunTree(originalConfig)) { + Object.assign(this, { ...originalConfig }); + return; + } const defaultConfig = RunTree.getDefaultConfig(); const { metadata, ...config } = originalConfig; const client = config.client ?? RunTree.getSharedClient(); @@ -248,6 +254,13 @@ export class RunTree implements BaseRun { child_execution_order: child_execution_order, }); + // Copy context vars over into the new run tree. + if (_LC_CONTEXT_VARIABLES_KEY in this) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (child as any)[_LC_CONTEXT_VARIABLES_KEY] = + this[_LC_CONTEXT_VARIABLES_KEY]; + } + type ExtraWithSymbol = Record; const LC_CHILD = Symbol.for("lc:child_config"); diff --git a/js/src/singletons/constants.ts b/js/src/singletons/constants.ts new file mode 100644 index 000000000..b841334af --- /dev/null +++ b/js/src/singletons/constants.ts @@ -0,0 +1 @@ +export const _LC_CONTEXT_VARIABLES_KEY = Symbol.for("lc:context_variables"); diff --git a/js/src/tests/traceable.test.ts b/js/src/tests/traceable.test.ts index 7842260fd..9bda5890f 100644 --- a/js/src/tests/traceable.test.ts +++ b/js/src/tests/traceable.test.ts @@ -1,9 +1,11 @@ import { jest } from "@jest/globals"; import { RunTree, RunTreeConfig } from "../run_trees.js"; +import { _LC_CONTEXT_VARIABLES_KEY } from "../singletons/constants.js"; import { ROOT, traceable, withRunTree } from "../traceable.js"; import { getAssumedTreeFromCalls } from "./utils/tree.js"; import { mockClient } from "./utils/mock_client.js"; import { Client, overrideFetchImplementation } from "../index.js"; +import { AsyncLocalStorageProviderSingleton } from "../singletons/traceable.js"; test("basic traceable implementation", async () => { const { client, callSpy } = mockClient(); @@ -103,6 +105,80 @@ test("nested traceable implementation", async () => { }); }); +test("nested traceable passes through LangChain context vars", (done) => { + const alsInstance = AsyncLocalStorageProviderSingleton.getInstance(); + + alsInstance.run( + { + [_LC_CONTEXT_VARIABLES_KEY]: { foo: "bar" }, + } as any, + // eslint-disable-next-line @typescript-eslint/no-misused-promises + async () => { + try { + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + const { client, callSpy } = mockClient(); + + const llm = traceable(async function llm(input: string) { + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + return input.repeat(2); + }); + + const str = traceable(async function* str(input: string) { + const response = input.split("").reverse(); + for (const char of response) { + yield char; + } + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + }); + + const chain = traceable( + async function chain(input: string) { + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + const question = await llm(input); + + let answer = ""; + for await (const char of str(question)) { + answer += char; + } + + return { question, answer }; + }, + { client, tracingEnabled: true } + ); + + const result = await chain("Hello world"); + + expect(result).toEqual({ + question: "Hello worldHello world", + answer: "dlrow olleHdlrow olleH", + }); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["chain:0", "llm:1", "str:2"], + edges: [ + ["chain:0", "llm:1"], + ["chain:0", "str:2"], + ], + }); + expect( + (alsInstance.getStore() as any)?.[_LC_CONTEXT_VARIABLES_KEY]?.foo + ).toEqual("bar"); + done(); + } catch (e) { + done(e); + } + } + ); +}); + test("trace circular input and output objects", async () => { const { client, callSpy } = mockClient(); const a: Record = {}; diff --git a/js/src/traceable.ts b/js/src/traceable.ts index f4ceebda2..0934a55df 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -13,6 +13,7 @@ import { ROOT, AsyncLocalStorageProviderSingleton, } from "./singletons/traceable.js"; +import { _LC_CONTEXT_VARIABLES_KEY } from "./singletons/constants.js"; import { TraceableFunction } from "./singletons/types.js"; import { isKVMap, @@ -422,6 +423,17 @@ export function traceable any>( processedArgs, config?.getInvocationParams ); + // If a context var is set by LangChain outside of a traceable, + // it will be an object with a single property and we should copy + // context vars over into the new run tree. + if ( + prevRunFromStore !== undefined && + _LC_CONTEXT_VARIABLES_KEY in prevRunFromStore + ) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (currentRunTree as any)[_LC_CONTEXT_VARIABLES_KEY] = + prevRunFromStore[_LC_CONTEXT_VARIABLES_KEY]; + } return [currentRunTree, processedArgs as Inputs]; })();