diff --git a/js/.gitignore b/js/.gitignore index 9bdc273aa..a87badf7c 100644 --- a/js/.gitignore +++ b/js/.gitignore @@ -37,6 +37,9 @@ Chinook_Sqlite.sql /run_trees.cjs /run_trees.js /run_trees.d.ts +/traceable.cjs +/traceable.js +/traceable.d.ts /evaluation.cjs /evaluation.js /evaluation.d.ts diff --git a/js/package.json b/js/package.json index a2a9ce837..4eaeea10f 100644 --- a/js/package.json +++ b/js/package.json @@ -11,6 +11,9 @@ "run_trees.cjs", "run_trees.js", "run_trees.d.ts", + "traceable.cjs", + "traceable.js", + "traceable.d.ts", "evaluation.cjs", "evaluation.js", "evaluation.d.ts", @@ -111,6 +114,11 @@ "import": "./run_trees.js", "require": "./run_trees.cjs" }, + "./traceable": { + "types": "./traceable.d.ts", + "import": "./traceable.js", + "require": "./traceable.cjs" + }, "./evaluation": { "types": "./evaluation.d.ts", "import": "./evaluation.js", @@ -123,4 +131,4 @@ }, "./package.json": "./package.json" } -} \ No newline at end of file +} diff --git a/js/scripts/create-entrypoints.js b/js/scripts/create-entrypoints.js index 21290f5cd..ef8ade94f 100644 --- a/js/scripts/create-entrypoints.js +++ b/js/scripts/create-entrypoints.js @@ -9,6 +9,7 @@ import * as path from "path"; const entrypoints = { client: "client", run_trees: "run_trees", + traceable: "traceable", evaluation: "evaluation/index", schemas: "schemas", }; diff --git a/js/src/index.ts b/js/src/index.ts index 013cbac9c..a34d26520 100644 --- a/js/src/index.ts +++ b/js/src/index.ts @@ -1,8 +1,14 @@ export { Client } from "./client.js"; -export { Dataset, Example, TracerSession, Run, Feedback } from "./schemas.js"; +export type { + Dataset, + Example, + TracerSession, + Run, + Feedback, +} from "./schemas.js"; -export { RunTree, RunTreeConfig } from "./run_trees.js"; +export { RunTree, type RunTreeConfig } from "./run_trees.js"; // Update using yarn bump-version export const __version__ = "0.0.70"; diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 6b98ae661..f081593f1 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -205,3 +205,11 @@ export class RunTree implements BaseRun { await this.client.updateRun(this.id, runUpdate); } } + +export function isRunTree(x?: unknown): x is RunTree { + return ( + x !== undefined && + typeof (x as RunTree).createChild === "function" && + typeof (x as RunTree).postRun === "function" + ); +} diff --git a/js/src/tests/traceable.int.test.ts b/js/src/tests/traceable.int.test.ts new file mode 100644 index 000000000..eaf3c05b8 --- /dev/null +++ b/js/src/tests/traceable.int.test.ts @@ -0,0 +1,168 @@ +import { v4 as uuidv4 } from "uuid"; +// eslint-disable-next-line import/no-extraneous-dependencies +import { FakeStreamingLLM } from "@langchain/core/utils/testing"; +import { Client } from "../client.js"; +import { traceable } from "../traceable.js"; +import { RunTree } from "../run_trees.js"; + +async function deleteProject(langchainClient: Client, projectName: string) { + try { + await langchainClient.readProject({ projectName }); + await langchainClient.deleteProject({ projectName }); + } catch (e) { + // Pass + } +} + +async function waitUntil( + condition: () => Promise, + timeout: number, + interval: number +): Promise { + const start = Date.now(); + while (Date.now() - start < timeout) { + if (await condition()) { + return; + } + await new Promise((resolve) => setTimeout(resolve, interval)); + } + throw new Error("Timeout"); +} + +async function waitUntilRunFound( + client: Client, + runId: string, + checkOutputs = false +) { + return waitUntil( + async () => { + try { + const run = await client.readRun(runId); + if (checkOutputs) { + return ( + run.outputs !== null && + run.outputs !== undefined && + Object.keys(run.outputs).length !== 0 + ); + } + return true; + } catch (e) { + return false; + } + }, + 30_000, + 5_000 + ); +} + +test.concurrent( + "Test traceable wrapper", + async () => { + const langchainClient = new Client({ + callerOptions: { maxRetries: 0 }, + }); + const runId = uuidv4(); + const projectName = "__test_traceable_wrapper"; + const addValueTraceable = traceable( + (a: string, b: number) => { + return a + b; + }, + { + name: "add_value", + project_name: projectName, + client: langchainClient, + id: runId, + } + ); + + expect(await addValueTraceable("testing", 9)).toBe("testing9"); + + await waitUntilRunFound(langchainClient, runId, true); + const storedRun = await langchainClient.readRun(runId); + expect(storedRun.id).toEqual(runId); + + const runId2 = uuidv4(); + const nestedAddValueTraceable = traceable( + (a: string, b: number) => { + return a + b; + }, + { + name: "nested_add_value", + project_name: projectName, + client: langchainClient, + } + ); + const entryTraceable = traceable( + async (complex: { value: string }) => { + const result = await nestedAddValueTraceable(complex.value, 1); + const result2 = await nestedAddValueTraceable(result, 2); + await nestedAddValueTraceable( + new RunTree({ + name: "root_nested_add_value", + project_name: projectName, + client: langchainClient, + }), + result, + 2 + ); + return nestedAddValueTraceable(result2, 3); + }, + { + name: "run_with_nesting", + project_name: projectName, + client: langchainClient, + id: runId2, + } + ); + + expect(await entryTraceable({ value: "testing" })).toBe("testing123"); + + await waitUntilRunFound(langchainClient, runId2, true); + const storedRun2 = await langchainClient.readRun(runId2); + expect(storedRun2.id).toEqual(runId2); + + const runId3 = uuidv4(); + + const llm = new FakeStreamingLLM({ sleep: 0 }); + + const iterableTraceable = traceable(llm.stream.bind(llm), { + name: "iterable_traceable", + project_name: projectName, + client: langchainClient, + id: runId3, + }); + + const chunks = []; + + for await (const chunk of await iterableTraceable("Hello there")) { + chunks.push(chunk); + } + expect(chunks.join("")).toBe("Hello there"); + await waitUntilRunFound(langchainClient, runId3, true); + const storedRun3 = await langchainClient.readRun(runId3); + expect(storedRun3.id).toEqual(runId3); + + await deleteProject(langchainClient, projectName); + + async function overload(a: string, b: number): Promise; + async function overload(config: { a: string; b: number }): Promise; + async function overload( + ...args: [a: string, b: number] | [config: { a: string; b: number }] + ): Promise { + if (args.length === 1) { + return args[0].a + args[0].b; + } + return args[0] + args[1]; + } + + const wrappedOverload = traceable(overload, { + name: "wrapped_overload", + project_name: projectName, + client: langchainClient, + }); + + expect(await wrappedOverload("testing", 123)).toBe("testing123"); + expect(await wrappedOverload({ a: "testing", b: 456 })).toBe("testing456"); + }, + 180_000 +); diff --git a/js/src/traceable.ts b/js/src/traceable.ts new file mode 100644 index 000000000..e31fded3a --- /dev/null +++ b/js/src/traceable.ts @@ -0,0 +1,204 @@ +import { AsyncLocalStorage } from "async_hooks"; + +import { RunTree, RunTreeConfig, isRunTree } from "./run_trees.js"; +import { KVMap } from "./schemas.js"; + +const asyncLocalStorage = new AsyncLocalStorage(); + +export type RunTreeLike = RunTree; + +type WrapArgReturnPair = Pair extends [ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + infer Args extends any[], + infer Return +] + ? { + (...args: Args): Promise; + (...args: [runTree: RunTreeLike, ...rest: Args]): Promise; + } + : never; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type UnionToIntersection = (U extends any ? (x: U) => void : never) extends ( + x: infer I +) => void + ? I + : never; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type TraceableFunction any> = + // function overloads are represented as intersections rather than unions + // matches the behavior introduced in https://github.com/microsoft/TypeScript/pull/54448 + Func extends { + (...args: infer A1): infer R1; + (...args: infer A2): infer R2; + (...args: infer A3): infer R3; + (...args: infer A4): infer R4; + (...args: infer A5): infer R5; + } + ? UnionToIntersection< + WrapArgReturnPair<[A1, R1] | [A2, R2] | [A3, R3] | [A4, R4] | [A5, R5]> + > + : Func extends { + (...args: infer A1): infer R1; + (...args: infer A2): infer R2; + (...args: infer A3): infer R3; + (...args: infer A4): infer R4; + } + ? UnionToIntersection< + WrapArgReturnPair<[A1, R1] | [A2, R2] | [A3, R3] | [A4, R4]> + > + : Func extends { + (...args: infer A1): infer R1; + (...args: infer A2): infer R2; + (...args: infer A3): infer R3; + } + ? UnionToIntersection> + : Func extends { + (...args: infer A1): infer R1; + (...args: infer A2): infer R2; + } + ? UnionToIntersection> + : Func extends { + (...args: infer A1): infer R1; + } + ? UnionToIntersection> + : never; + +const isAsyncIterable = (x: unknown): x is AsyncIterable => + x != null && + typeof x === "object" && + // eslint-disable-next-line @typescript-eslint/no-explicit-any + typeof (x as any)[Symbol.asyncIterator] === "function"; + +/** + * Higher-order function that takes function as input and returns a + * "TraceableFunction" - a wrapped version of the input that + * automatically handles tracing. If the returned traceable function calls any + * traceable functions, those are automatically traced as well. + * + * The returned TraceableFunction can accept a run tree or run tree config as + * its first argument. If omitted, it will default to the caller's run tree, + * or will be treated as a root run. + * + * @param wrappedFunc Targeted function to be traced + * @param config Additional metadata such as name, tags or providing + * a custom LangSmith client instance + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function traceable any>( + wrappedFunc: Func, + config?: Partial +) { + type Inputs = Parameters; + type Output = ReturnType; + + const traceableFunc = async ( + ...args: Inputs | [RunTreeLike, ...Inputs] + ): Promise => { + let currentRunTree: RunTree; + let rawInputs: Inputs; + + const ensuredConfig: RunTreeConfig = { + name: wrappedFunc.name || "", + ...config, + }; + + const previousRunTree = asyncLocalStorage.getStore(); + if (isRunTree(args[0])) { + currentRunTree = args[0]; + rawInputs = args.slice(1) as Inputs; + } else if (previousRunTree !== undefined) { + currentRunTree = await previousRunTree.createChild(ensuredConfig); + rawInputs = args as Inputs; + } else { + currentRunTree = new RunTree(ensuredConfig); + rawInputs = args as Inputs; + } + + let inputs: KVMap; + const firstInput = rawInputs[0]; + if (firstInput == null) { + inputs = {}; + } else if (rawInputs.length > 1) { + inputs = { args: rawInputs }; + } else if (isKVMap(firstInput)) { + inputs = firstInput; + } else { + inputs = { input: firstInput }; + } + + currentRunTree.inputs = inputs; + + const initialOutputs = currentRunTree.outputs; + const initialError = currentRunTree.error; + await currentRunTree.postRun(); + + return new Promise((resolve, reject) => { + void asyncLocalStorage.run(currentRunTree, async () => { + try { + const rawOutput = await wrappedFunc(...rawInputs); + if (isAsyncIterable(rawOutput)) { + // eslint-disable-next-line no-inner-declarations + async function* wrapOutputForTracing() { + const chunks: unknown[] = []; + // TypeScript thinks this is unsafe + for await (const chunk of rawOutput as AsyncIterable) { + chunks.push(chunk); + yield chunk; + } + await currentRunTree.end({ outputs: chunks }); + await currentRunTree.patchRun(); + } + return resolve(wrapOutputForTracing() as Output); + } else { + const outputs: KVMap = isKVMap(rawOutput) + ? rawOutput + : { outputs: rawOutput }; + + if (initialOutputs === currentRunTree.outputs) { + await currentRunTree.end(outputs); + } else { + currentRunTree.end_time = Date.now(); + } + + await currentRunTree.patchRun(); + return resolve(rawOutput); + } + } catch (error) { + if (initialError === currentRunTree.error) { + await currentRunTree.end(initialOutputs, String(error)); + } else { + currentRunTree.end_time = Date.now(); + } + + await currentRunTree.patchRun(); + reject(error); + } + }); + }); + }; + + Object.defineProperty(wrappedFunc, "langsmith:traceable", { + value: config, + }); + + return traceableFunc as TraceableFunction; +} + +export function isTraceableFunction( + x: unknown + // eslint-disable-next-line @typescript-eslint/no-explicit-any +): x is TraceableFunction { + return typeof x === "function" && "langsmith:traceable" in x; +} + +function isKVMap(x: unknown): x is Record { + return ( + typeof x === "object" && + x != null && + !Array.isArray(x) && + // eslint-disable-next-line no-instanceof/no-instanceof + !(x instanceof Date) + ); +} diff --git a/js/tsconfig.json b/js/tsconfig.json index 5a466ec5a..5edd93e56 100644 --- a/js/tsconfig.json +++ b/js/tsconfig.json @@ -34,6 +34,7 @@ "entryPoints": [ "src/client.ts", "src/run_trees.ts", + "src/traceable.ts", "src/evaluation/index.ts", "src/schemas.ts" ] diff --git a/js/yarn.lock b/js/yarn.lock index 4456625d0..a153e3709 100644 --- a/js/yarn.lock +++ b/js/yarn.lock @@ -3452,9 +3452,9 @@ kleur@^3.0.3: integrity sha512-eTIzlVOSUR+JxdDFepEYcBMtZ9Qqdef+rnzWdRZuMbOywu5tO2w2N7rqjoANZ5k9vywhL6Br1VRjUIgTQx4E8w== langsmith@~0.0.48: - version "0.0.68" - resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.0.68.tgz#8748d3203d348cc19e5ee4ddeef908964a62e21a" - integrity sha512-bxaJndEhUFDfv5soWKxONrLMZaVZfS+G4smJl3WYQlsEph8ierG3QbJfx1PEwl40TD0aFBjzq62usUX1UOCjuA== + version "0.0.70" + resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.0.70.tgz#797be2b26da18843a94a802b6a73c91b72e8042b" + integrity sha512-QFHrzo/efBowGPCxtObv7G40/OdwqQfGshavMbSJtHBgX+OMqnn4lCMqVeEwTdyue4lEcpwAsGNg5Vty91YIyw== dependencies: "@types/uuid" "^9.0.1" commander "^10.0.1"