From db62e94677e5232829ce3371e72e66e24d036291 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Wed, 14 Feb 2024 10:38:24 -0800 Subject: [PATCH] Update return type --- js/src/client.ts | 1 + js/src/run_helpers.ts | 53 ++++++++++------------ js/src/tests/traceable_wrapper.int.test.ts | 13 +++--- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/js/src/client.ts b/js/src/client.ts index 9770dc607..5a0f6401b 100644 --- a/js/src/client.ts +++ b/js/src/client.ts @@ -582,6 +582,7 @@ export class Client { runCreate, ]); + console.log(mergedRunCreateParams[0]); const response = await this.caller.call(fetch, `${this.apiUrl}/runs`, { method: "POST", headers, diff --git a/js/src/run_helpers.ts b/js/src/run_helpers.ts index 118b98526..7ce162df5 100644 --- a/js/src/run_helpers.ts +++ b/js/src/run_helpers.ts @@ -1,10 +1,10 @@ +import { v4 as uuidv4 } from "uuid"; import { RunTree, RunTreeConfig, isRunTree } from "./run_trees.js"; import { KVMap } from "./schemas.js"; -export type TraceableFunction = ( - rawInput: I, - parentRun: RunTree | { root: RunTree } | null -) => Promise; +export type TraceableFunction = ( + ...rawInputs: Inputs | [RunTree, ...Inputs] +) => Promise<[Output, RunTree]>; export function isTraceableFunction( x: unknown @@ -17,23 +17,31 @@ export function traceable( wrappedFunc: (...inputs: Inputs) => Output, config?: RunTreeConfig ) { - let boundParentRunTree: RunTree | undefined; - const traceableFunc = async ( + const traceableFunc: TraceableFunction = async ( ...rawInputs: Inputs | [RunTree, ...Inputs] - ): Promise => { - let parentRunTree: RunTree | undefined = boundParentRunTree; + ): Promise<[Output, RunTree]> => { + let inputRunTree: RunTree | undefined; + let currentRunTree: RunTree; let wrappedFunctionInputs: Inputs; + const ensuredConfig = { name: "traced_function", ...config }; if (isRunTree(rawInputs[0])) { - [parentRunTree, ...wrappedFunctionInputs] = rawInputs as [ + [inputRunTree, ...wrappedFunctionInputs] = rawInputs as [ RunTree, ...Inputs ]; + if ("root" in inputRunTree) { + currentRunTree = inputRunTree.root as RunTree; + } else { + currentRunTree = await inputRunTree.createChild(ensuredConfig); + } } else { wrappedFunctionInputs = rawInputs as Inputs; - } - if (parentRunTree == null) { - return wrappedFunc(...wrappedFunctionInputs); + currentRunTree = new RunTree({ + id: uuidv4(), + run_type: "chain", + ...ensuredConfig, + }); } let inputs: KVMap; const firstWrappedFunctionInput = wrappedFunctionInputs[0]; @@ -52,20 +60,12 @@ export function traceable( inputs = { input: firstWrappedFunctionInput }; } - const ensuredConfig = { name: "traced_function", config }; - - const currentRunTree: RunTree = - "root" in parentRunTree - ? (parentRunTree.root as RunTree) - : await parentRunTree.createChild({ - ...ensuredConfig, - inputs, - }); - - if ("root" in parentRunTree) { - Object.assign(currentRunTree, { ...ensuredConfig, inputs }); + if ("root" in currentRunTree) { + Object.assign(currentRunTree, { ...ensuredConfig }); } + currentRunTree.inputs = inputs; + const initialOutputs = currentRunTree.outputs; const initialError = currentRunTree.error; @@ -85,7 +85,7 @@ export function traceable( } else { currentRunTree.end_time = Date.now(); } - return rawOutput; + return [rawOutput, currentRunTree]; } catch (error) { if (initialError === currentRunTree.error) { await currentRunTree.end(initialOutputs, String(error)); @@ -98,9 +98,6 @@ export function traceable( await currentRunTree.postRun(); } }; - traceableFunc.setParentRunTree = (parent: RunTree) => { - boundParentRunTree = parent; - }; Object.defineProperty(wrappedFunc, "langsmith:traceable", { value: config, }); diff --git a/js/src/tests/traceable_wrapper.int.test.ts b/js/src/tests/traceable_wrapper.int.test.ts index d7549b71d..204aa58d1 100644 --- a/js/src/tests/traceable_wrapper.int.test.ts +++ b/js/src/tests/traceable_wrapper.int.test.ts @@ -23,18 +23,19 @@ test.concurrent( }; const projectName = "__test_traceable_wrapper"; const runId = uuidv4(); - const runTree = new RunTree({ - name: "Test Run Tree", + const rootRunTree = new RunTree({ + name: "test_run_tree", + run_type: "chain", id: runId, client: langchainClient, project_name: projectName, }); - const traceableFunction = traceable(testFunction); + const traceableFunction = traceable(testFunction, { name: "testinger" }); + // const openaiCompletions = traceable(openai.chat.completions.create); - traceableFunction.setParentRunTree(runTree); - - console.log(await traceableFunction("testing", 9)); + const [response, runTree] = await traceableFunction(rootRunTree, "testing", 9); + const [response2, runTree2] = await traceableFunction(runTree, "testing2", 10); // await deleteProject(langchainClient, projectName); },