Skip to content

Commit

Permalink
Update return type
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Feb 14, 2024
1 parent 4a7d16e commit db62e94
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 34 deletions.
1 change: 1 addition & 0 deletions js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ export class Client {
runCreate,
]);

console.log(mergedRunCreateParams[0]);
const response = await this.caller.call(fetch, `${this.apiUrl}/runs`, {
method: "POST",
headers,
Expand Down
53 changes: 25 additions & 28 deletions js/src/run_helpers.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { v4 as uuidv4 } from "uuid";
import { RunTree, RunTreeConfig, isRunTree } from "./run_trees.js";
import { KVMap } from "./schemas.js";

export type TraceableFunction<I, O> = (
rawInput: I,
parentRun: RunTree | { root: RunTree } | null
) => Promise<O>;
export type TraceableFunction<Inputs extends any[], Output> = (
...rawInputs: Inputs | [RunTree, ...Inputs]
) => Promise<[Output, RunTree]>;

export function isTraceableFunction(
x: unknown
Expand All @@ -17,23 +17,31 @@ export function traceable<Inputs extends any[], Output>(
wrappedFunc: (...inputs: Inputs) => Output,
config?: RunTreeConfig
) {
let boundParentRunTree: RunTree | undefined;
const traceableFunc = async (
const traceableFunc: TraceableFunction<Inputs, Output> = async (
...rawInputs: Inputs | [RunTree, ...Inputs]
): Promise<Output> => {
let parentRunTree: RunTree | undefined = boundParentRunTree;
): Promise<[Output, RunTree]> => {
let inputRunTree: RunTree | undefined;
let currentRunTree: RunTree;
let wrappedFunctionInputs: Inputs;
const ensuredConfig = { name: "traced_function", ...config };

if (isRunTree(rawInputs[0])) {
[parentRunTree, ...wrappedFunctionInputs] = rawInputs as [
[inputRunTree, ...wrappedFunctionInputs] = rawInputs as [
RunTree,
...Inputs
];
if ("root" in inputRunTree) {
currentRunTree = inputRunTree.root as RunTree;
} else {
currentRunTree = await inputRunTree.createChild(ensuredConfig);
}
} else {
wrappedFunctionInputs = rawInputs as Inputs;
}
if (parentRunTree == null) {
return wrappedFunc(...wrappedFunctionInputs);
currentRunTree = new RunTree({
id: uuidv4(),
run_type: "chain",
...ensuredConfig,
});
}
let inputs: KVMap;
const firstWrappedFunctionInput = wrappedFunctionInputs[0];
Expand All @@ -52,20 +60,12 @@ export function traceable<Inputs extends any[], Output>(
inputs = { input: firstWrappedFunctionInput };
}

const ensuredConfig = { name: "traced_function", config };

const currentRunTree: RunTree =
"root" in parentRunTree
? (parentRunTree.root as RunTree)
: await parentRunTree.createChild({
...ensuredConfig,
inputs,
});

if ("root" in parentRunTree) {
Object.assign(currentRunTree, { ...ensuredConfig, inputs });
if ("root" in currentRunTree) {
Object.assign(currentRunTree, { ...ensuredConfig });
}

currentRunTree.inputs = inputs;

const initialOutputs = currentRunTree.outputs;
const initialError = currentRunTree.error;

Expand All @@ -85,7 +85,7 @@ export function traceable<Inputs extends any[], Output>(
} else {
currentRunTree.end_time = Date.now();
}
return rawOutput;
return [rawOutput, currentRunTree];
} catch (error) {
if (initialError === currentRunTree.error) {
await currentRunTree.end(initialOutputs, String(error));
Expand All @@ -98,9 +98,6 @@ export function traceable<Inputs extends any[], Output>(
await currentRunTree.postRun();
}
};
traceableFunc.setParentRunTree = (parent: RunTree) => {
boundParentRunTree = parent;
};
Object.defineProperty(wrappedFunc, "langsmith:traceable", {
value: config,
});
Expand Down
13 changes: 7 additions & 6 deletions js/src/tests/traceable_wrapper.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@ test.concurrent(
};
const projectName = "__test_traceable_wrapper";
const runId = uuidv4();
const runTree = new RunTree({
name: "Test Run Tree",
const rootRunTree = new RunTree({
name: "test_run_tree",
run_type: "chain",
id: runId,
client: langchainClient,
project_name: projectName,
});

const traceableFunction = traceable(testFunction);
const traceableFunction = traceable(testFunction, { name: "testinger" });
// const openaiCompletions = traceable(openai.chat.completions.create);

traceableFunction.setParentRunTree(runTree);

console.log(await traceableFunction("testing", 9));
const [response, runTree] = await traceableFunction(rootRunTree, "testing", 9);
const [response2, runTree2] = await traceableFunction(runTree, "testing2", 10);

// await deleteProject(langchainClient, projectName);
},
Expand Down

0 comments on commit db62e94

Please sign in to comment.