Skip to content

Commit

Permalink
New form factor for traceable wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Feb 14, 2024
1 parent 21dcbc7 commit 4a7d16e
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 60 deletions.
12 changes: 9 additions & 3 deletions js/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
export { Client } from "./client.js";

export { Dataset, Example, TracerSession, Run, Feedback } from "./schemas.js";
export type {
Dataset,
Example,
TracerSession,
Run,
Feedback,
} from "./schemas.js";

export { RunTree, RunTreeConfig } from "./run_trees.js";
export { RunTree, type RunTreeConfig } from "./run_trees.js";

export {
traceable,
TraceableFunction,
type TraceableFunction,
isTraceableFunction,
} from "./run_helpers.js";

Expand Down
141 changes: 84 additions & 57 deletions js/src/run_helpers.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { RunTree, RunTreeConfig } from "./run_trees.js";
import { RunTree, RunTreeConfig, isRunTree } from "./run_trees.js";
import { KVMap } from "./schemas.js";

export type TraceableFunction<I, O> = (
Expand All @@ -13,69 +13,96 @@ export function isTraceableFunction(
return typeof x === "function" && "langsmith:traceable" in x;
}

export const traceable = (params: RunTreeConfig) => {
return <I, O>(func: (rawInput: I, parentRun: RunTree | null) => O) => {
async function wrappedFunc(
rawInput: I,
parentRun: RunTree | { root: RunTree } | null
) {
if (parentRun == null) {
return await func(rawInput, parentRun);
}
export function traceable<Inputs extends any[], Output>(
wrappedFunc: (...inputs: Inputs) => Output,
config?: RunTreeConfig
) {
let boundParentRunTree: RunTree | undefined;
const traceableFunc = async (
...rawInputs: Inputs | [RunTree, ...Inputs]
): Promise<Output> => {
let parentRunTree: RunTree | undefined = boundParentRunTree;
let wrappedFunctionInputs: Inputs;

const inputs: KVMap =
typeof rawInput === "object" && rawInput != null
? rawInput
: { input: rawInput };

const currentRun: RunTree =
"root" in parentRun
? parentRun.root
: await parentRun.createChild({
...params,
inputs,
});
if (isRunTree(rawInputs[0])) {
[parentRunTree, ...wrappedFunctionInputs] = rawInputs as [
RunTree,
...Inputs
];
} else {
wrappedFunctionInputs = rawInputs as Inputs;
}
if (parentRunTree == null) {
return wrappedFunc(...wrappedFunctionInputs);
}
let inputs: KVMap;
const firstWrappedFunctionInput = wrappedFunctionInputs[0];
if (firstWrappedFunctionInput == null) {
inputs = {};
} else if (wrappedFunctionInputs.length > 1) {
inputs = { args: wrappedFunctionInputs };
} else if (
typeof firstWrappedFunctionInput === "object" &&
!Array.isArray(firstWrappedFunctionInput) &&
// eslint-disable-next-line no-instanceof/no-instanceof
!(firstWrappedFunctionInput instanceof Date)
) {
inputs = firstWrappedFunctionInput;
} else {
inputs = { input: firstWrappedFunctionInput };
}

if ("root" in parentRun) {
Object.assign(currentRun, { ...params, inputs });
}
const ensuredConfig = { name: "traced_function", config };

const initialOutputs = currentRun.outputs;
const initialError = currentRun.error;
const currentRunTree: RunTree =
"root" in parentRunTree
? (parentRunTree.root as RunTree)
: await parentRunTree.createChild({
...ensuredConfig,
inputs,
});

try {
const rawOutput = await func(rawInput, currentRun);
const outputs: KVMap =
typeof rawOutput === "object" &&
rawOutput != null &&
!Array.isArray(rawOutput)
? rawOutput
: { outputs: rawOutput };
if ("root" in parentRunTree) {
Object.assign(currentRunTree, { ...ensuredConfig, inputs });
}

if (initialOutputs === currentRun.outputs) {
await currentRun.end(outputs);
} else {
currentRun.end_time = Date.now();
}
const initialOutputs = currentRunTree.outputs;
const initialError = currentRunTree.error;

return rawOutput;
} catch (error) {
if (initialError === currentRun.error) {
await currentRun.end(initialOutputs, String(error));
} else {
currentRun.end_time = Date.now();
}
try {
const rawOutput = await wrappedFunc(...wrappedFunctionInputs);
const outputs: KVMap =
typeof rawOutput === "object" &&
rawOutput != null &&
!Array.isArray(rawOutput) &&
// eslint-disable-next-line no-instanceof/no-instanceof
!(rawOutput instanceof Date)
? rawOutput
: { outputs: rawOutput };

throw error;
} finally {
await currentRun.postRun();
if (initialOutputs === currentRunTree.outputs) {
await currentRunTree.end(outputs);
} else {
currentRunTree.end_time = Date.now();
}
return rawOutput;
} catch (error) {
if (initialError === currentRunTree.error) {
await currentRunTree.end(initialOutputs, String(error));
} else {
currentRunTree.end_time = Date.now();
}
}

Object.defineProperty(wrappedFunc, "langsmith:traceable", {
value: params,
});

return wrappedFunc;
throw error;
} finally {
await currentRunTree.postRun();
}
};
traceableFunc.setParentRunTree = (parent: RunTree) => {
boundParentRunTree = parent;
};
};
Object.defineProperty(wrappedFunc, "langsmith:traceable", {
value: config,
});
return traceableFunc;
}
8 changes: 8 additions & 0 deletions js/src/run_trees.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,11 @@ export class RunTree implements BaseRun {
await this.client.updateRun(this.id, runUpdate);
}
}

export function isRunTree(x?: unknown): x is RunTree {
return (
x !== undefined &&
typeof (x as RunTree).createChild === "function" &&
typeof (x as RunTree).postRun === "function"
);
}
42 changes: 42 additions & 0 deletions js/src/tests/traceable_wrapper.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { Client } from "../client.js";
import { traceable } from "../run_helpers.js";
import { RunTree } from "../run_trees.js";
import { v4 as uuidv4 } from "uuid";

// async function deleteProject(langchainClient: Client, projectName: string) {
// try {
// await langchainClient.readProject({ projectName });
// await langchainClient.deleteProject({ projectName });
// } catch (e) {
// // Pass
// }
// }

test.concurrent(
"Test traceable wrapper",
async () => {
const langchainClient = new Client({
callerOptions: { maxRetries: 0 },
});
const testFunction = (a: string, b: number) => {
return a + b;
};
const projectName = "__test_traceable_wrapper";
const runId = uuidv4();
const runTree = new RunTree({
name: "Test Run Tree",
id: runId,
client: langchainClient,
project_name: projectName,
});

const traceableFunction = traceable(testFunction);

traceableFunction.setParentRunTree(runTree);

console.log(await traceableFunction("testing", 9));

// await deleteProject(langchainClient, projectName);
},
180_000
);

0 comments on commit 4a7d16e

Please sign in to comment.