From 4a7d16e6cc512e4155912b5b28cfc40074d1e85d Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 13 Feb 2024 21:04:41 -0800 Subject: [PATCH] New form factor for traceable wrapper --- js/src/index.ts | 12 +- js/src/run_helpers.ts | 141 ++++++++++++--------- js/src/run_trees.ts | 8 ++ js/src/tests/traceable_wrapper.int.test.ts | 42 ++++++ 4 files changed, 143 insertions(+), 60 deletions(-) create mode 100644 js/src/tests/traceable_wrapper.int.test.ts diff --git a/js/src/index.ts b/js/src/index.ts index 832f2a8a7..d3c0c38b0 100644 --- a/js/src/index.ts +++ b/js/src/index.ts @@ -1,12 +1,18 @@ export { Client } from "./client.js"; -export { Dataset, Example, TracerSession, Run, Feedback } from "./schemas.js"; +export type { + Dataset, + Example, + TracerSession, + Run, + Feedback, +} from "./schemas.js"; -export { RunTree, RunTreeConfig } from "./run_trees.js"; +export { RunTree, type RunTreeConfig } from "./run_trees.js"; export { traceable, - TraceableFunction, + type TraceableFunction, isTraceableFunction, } from "./run_helpers.js"; diff --git a/js/src/run_helpers.ts b/js/src/run_helpers.ts index 4b9108e57..118b98526 100644 --- a/js/src/run_helpers.ts +++ b/js/src/run_helpers.ts @@ -1,4 +1,4 @@ -import { RunTree, RunTreeConfig } from "./run_trees.js"; +import { RunTree, RunTreeConfig, isRunTree } from "./run_trees.js"; import { KVMap } from "./schemas.js"; export type TraceableFunction = ( @@ -13,69 +13,96 @@ export function isTraceableFunction( return typeof x === "function" && "langsmith:traceable" in x; } -export const traceable = (params: RunTreeConfig) => { - return (func: (rawInput: I, parentRun: RunTree | null) => O) => { - async function wrappedFunc( - rawInput: I, - parentRun: RunTree | { root: RunTree } | null - ) { - if (parentRun == null) { - return await func(rawInput, parentRun); - } +export function traceable( + wrappedFunc: (...inputs: Inputs) => Output, + config?: RunTreeConfig +) { + let boundParentRunTree: RunTree | undefined; + const traceableFunc = async ( + ...rawInputs: Inputs | [RunTree, ...Inputs] + ): Promise => { + let parentRunTree: RunTree | undefined = boundParentRunTree; + let wrappedFunctionInputs: Inputs; - const inputs: KVMap = - typeof rawInput === "object" && rawInput != null - ? rawInput - : { input: rawInput }; - - const currentRun: RunTree = - "root" in parentRun - ? parentRun.root - : await parentRun.createChild({ - ...params, - inputs, - }); + if (isRunTree(rawInputs[0])) { + [parentRunTree, ...wrappedFunctionInputs] = rawInputs as [ + RunTree, + ...Inputs + ]; + } else { + wrappedFunctionInputs = rawInputs as Inputs; + } + if (parentRunTree == null) { + return wrappedFunc(...wrappedFunctionInputs); + } + let inputs: KVMap; + const firstWrappedFunctionInput = wrappedFunctionInputs[0]; + if (firstWrappedFunctionInput == null) { + inputs = {}; + } else if (wrappedFunctionInputs.length > 1) { + inputs = { args: wrappedFunctionInputs }; + } else if ( + typeof firstWrappedFunctionInput === "object" && + !Array.isArray(firstWrappedFunctionInput) && + // eslint-disable-next-line no-instanceof/no-instanceof + !(firstWrappedFunctionInput instanceof Date) + ) { + inputs = firstWrappedFunctionInput; + } else { + inputs = { input: firstWrappedFunctionInput }; + } - if ("root" in parentRun) { - Object.assign(currentRun, { ...params, inputs }); - } + const ensuredConfig = { name: "traced_function", config }; - const initialOutputs = currentRun.outputs; - const initialError = currentRun.error; + const currentRunTree: RunTree = + "root" in parentRunTree + ? (parentRunTree.root as RunTree) + : await parentRunTree.createChild({ + ...ensuredConfig, + inputs, + }); - try { - const rawOutput = await func(rawInput, currentRun); - const outputs: KVMap = - typeof rawOutput === "object" && - rawOutput != null && - !Array.isArray(rawOutput) - ? rawOutput - : { outputs: rawOutput }; + if ("root" in parentRunTree) { + Object.assign(currentRunTree, { ...ensuredConfig, inputs }); + } - if (initialOutputs === currentRun.outputs) { - await currentRun.end(outputs); - } else { - currentRun.end_time = Date.now(); - } + const initialOutputs = currentRunTree.outputs; + const initialError = currentRunTree.error; - return rawOutput; - } catch (error) { - if (initialError === currentRun.error) { - await currentRun.end(initialOutputs, String(error)); - } else { - currentRun.end_time = Date.now(); - } + try { + const rawOutput = await wrappedFunc(...wrappedFunctionInputs); + const outputs: KVMap = + typeof rawOutput === "object" && + rawOutput != null && + !Array.isArray(rawOutput) && + // eslint-disable-next-line no-instanceof/no-instanceof + !(rawOutput instanceof Date) + ? rawOutput + : { outputs: rawOutput }; - throw error; - } finally { - await currentRun.postRun(); + if (initialOutputs === currentRunTree.outputs) { + await currentRunTree.end(outputs); + } else { + currentRunTree.end_time = Date.now(); + } + return rawOutput; + } catch (error) { + if (initialError === currentRunTree.error) { + await currentRunTree.end(initialOutputs, String(error)); + } else { + currentRunTree.end_time = Date.now(); } - } - - Object.defineProperty(wrappedFunc, "langsmith:traceable", { - value: params, - }); - return wrappedFunc; + throw error; + } finally { + await currentRunTree.postRun(); + } + }; + traceableFunc.setParentRunTree = (parent: RunTree) => { + boundParentRunTree = parent; }; -}; + Object.defineProperty(wrappedFunc, "langsmith:traceable", { + value: config, + }); + return traceableFunc; +} diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 6b98ae661..f081593f1 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -205,3 +205,11 @@ export class RunTree implements BaseRun { await this.client.updateRun(this.id, runUpdate); } } + +export function isRunTree(x?: unknown): x is RunTree { + return ( + x !== undefined && + typeof (x as RunTree).createChild === "function" && + typeof (x as RunTree).postRun === "function" + ); +} diff --git a/js/src/tests/traceable_wrapper.int.test.ts b/js/src/tests/traceable_wrapper.int.test.ts new file mode 100644 index 000000000..d7549b71d --- /dev/null +++ b/js/src/tests/traceable_wrapper.int.test.ts @@ -0,0 +1,42 @@ +import { Client } from "../client.js"; +import { traceable } from "../run_helpers.js"; +import { RunTree } from "../run_trees.js"; +import { v4 as uuidv4 } from "uuid"; + +// async function deleteProject(langchainClient: Client, projectName: string) { +// try { +// await langchainClient.readProject({ projectName }); +// await langchainClient.deleteProject({ projectName }); +// } catch (e) { +// // Pass +// } +// } + +test.concurrent( + "Test traceable wrapper", + async () => { + const langchainClient = new Client({ + callerOptions: { maxRetries: 0 }, + }); + const testFunction = (a: string, b: number) => { + return a + b; + }; + const projectName = "__test_traceable_wrapper"; + const runId = uuidv4(); + const runTree = new RunTree({ + name: "Test Run Tree", + id: runId, + client: langchainClient, + project_name: projectName, + }); + + const traceableFunction = traceable(testFunction); + + traceableFunction.setParentRunTree(runTree); + + console.log(await traceableFunction("testing", 9)); + + // await deleteProject(langchainClient, projectName); + }, + 180_000 +);