diff --git a/js/src/schemas.ts b/js/src/schemas.ts index 274a76bb6..7dc9562d8 100644 --- a/js/src/schemas.ts +++ b/js/src/schemas.ts @@ -502,3 +502,82 @@ export interface RunWithAnnotationQueueInfo extends BaseRun { /** The time this run was added to the queue. */ added_at?: string; } + +/** + * Breakdown of input token counts. + * + * Does not *need* to sum to full input token count. Does *not* need to have all keys. + */ +export type InputTokenDetails = { + /** + * Audio input tokens. + */ + audio?: number; + + /** + * Input tokens that were cached and there was a cache hit. + * + * Since there was a cache hit, the tokens were read from the cache. + * More precisely, the model state given these tokens was read from the cache. + */ + cache_read?: number; + + /** + * Input tokens that were cached and there was a cache miss. + * + * Since there was a cache miss, the cache was created from these tokens. + */ + cache_creation?: number; +}; + +/** + * Breakdown of output token counts. + * + * Does *not* need to sum to full output token count. Does *not* need to have all keys. + */ +export type OutputTokenDetails = { + /** + * Audio output tokens + */ + audio?: number; + + /** + * Reasoning output tokens. + * + * Tokens generated by the model in a chain of thought process (i.e. by + * OpenAI's o1 models) that are not returned as part of model output. + */ + reasoning?: number; +}; + +/** + * Usage metadata for a message, such as token counts. + */ +export type UsageMetadata = { + /** + * Count of input (or prompt) tokens. Sum of all input token types. + */ + input_tokens: number; + /** + * Count of output (or completion) tokens. Sum of all output token types. + */ + output_tokens: number; + /** + * Total token count. Sum of input_tokens + output_tokens. + */ + total_tokens: number; + + /** + * Breakdown of input token counts. + * + * Does *not* need to sum to full input token count. Does *not* need to have all keys. + */ + input_token_details?: InputTokenDetails; + + /** + * Breakdown of output token counts. + * + * Does *not* need to sum to full output token count. Does *not* need to have all keys. 
+ */ + output_token_details?: OutputTokenDetails; +}; diff --git a/js/src/tests/test_data/langsmith_js_wrap_openai_default.json b/js/src/tests/test_data/langsmith_js_wrap_openai_default.json new file mode 100644 index 000000000..722d74992 --- /dev/null +++ b/js/src/tests/test_data/langsmith_js_wrap_openai_default.json @@ -0,0 +1,98 @@ +{ + "post": [ + { + "session_name": "default", + "id": "dc34609e-3eeb-459d-bc2a-6fedb01d2e6e", + "name": "ChatOpenAI", + "start_time": 1728803137170, + "run_type": "llm", + "extra": { + "metadata": { + "ls_provider": "openai", + "ls_model_type": "chat", + "ls_model_name": "gpt-4o-mini" + }, + "runtime": { + "library": "langsmith", + "runtime": "node", + "sdk": "langsmith-js", + "sdk_version": "0.1.65" + } + }, + "serialized": {}, + "inputs": { + "model": "gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": "howdy" + } + ] + }, + "child_runs": [], + "trace_id": "dc34609e-3eeb-459d-bc2a-6fedb01d2e6e", + "dotted_order": "20241013T070537170001Zdc34609e-3eeb-459d-bc2a-6fedb01d2e6e", + "tags": [] + } + ], + "patch": [ + { + "end_time": 1728803138285, + "inputs": { + "model": "gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": "howdy" + } + ] + }, + "outputs": { + "id": "chatcmpl-AHmxWgRAkoZJCaH30D7gz5t1OCc30", + "object": "chat.completion", + "created": 1728803138, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Howdy! How can I assist you today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "system_fingerprint": "fp_e2bde53e6e", + "usage_metadata": { + "input_tokens": 9, + "output_tokens": 9, + "total_tokens": 18, + "input_token_details": { + "cache_read": 0 + }, + "output_token_details": { + "reasoning": 0 + } + } + }, + "extra": { + "metadata": { + "ls_provider": "openai", + "ls_model_type": "chat", + "ls_model_name": "gpt-4o-mini" + }, + "runtime": { + "library": "langsmith", + "runtime": "node", + "sdk": "langsmith-js", + "sdk_version": "0.1.65" + } + }, + "dotted_order": "20241013T070537170001Zdc34609e-3eeb-459d-bc2a-6fedb01d2e6e", + "trace_id": "dc34609e-3eeb-459d-bc2a-6fedb01d2e6e", + "tags": [] + } + ] +} \ No newline at end of file diff --git a/js/src/tests/test_data/langsmith_js_wrap_openai_reasoning.json b/js/src/tests/test_data/langsmith_js_wrap_openai_reasoning.json new file mode 100644 index 000000000..6435c469f --- /dev/null +++ b/js/src/tests/test_data/langsmith_js_wrap_openai_reasoning.json @@ -0,0 +1,97 @@ +{ + "post": [ + { + "session_name": "default", + "id": "e954b8e3-c337-4a05-bf0a-ca2baac3ba48", + "name": "ChatOpenAI", + "start_time": 1728803138291, + "run_type": "llm", + "extra": { + "metadata": { + "ls_provider": "openai", + "ls_model_type": "chat", + "ls_model_name": "o1-mini" + }, + "runtime": { + "library": "langsmith", + "runtime": "node", + "sdk": "langsmith-js", + "sdk_version": "0.1.65" + } + }, + "serialized": {}, + "inputs": { + "model": "o1-mini", + "messages": [ + { + "role": "user", + "content": "Write a bash script that takes a matrix represented as a string with format '[1,2],[3,4],[5,6]' and prints the transpose in the same format." 
+ } + ] + }, + "child_runs": [], + "trace_id": "e954b8e3-c337-4a05-bf0a-ca2baac3ba48", + "dotted_order": "20241013T070538291001Ze954b8e3-c337-4a05-bf0a-ca2baac3ba48", + "tags": [] + } + ], + "patch": [ + { + "end_time": 1728803148730, + "inputs": { + "model": "o1-mini", + "messages": [ + { + "role": "user", + "content": "Write a bash script that takes a matrix represented as a string with format '[1,2],[3,4],[5,6]' and prints the transpose in the same format." + } + ] + }, + "outputs": { + "id": "chatcmpl-AHmxWXSp7oUeT2kkt7kuhLOmfLKDJ", + "object": "chat.completion", + "created": 1728803138, + "model": "o1-mini-2024-09-12", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Certainly! Below is a Bash script that takes a matrix represented as a string in the format `\"[1,2],[3,4],[5,6]\"` and outputs its transpose in the same format.\n\n### Script: `transpose_matrix.sh`\n\n```bash\n#!/bin/bash\n\n# Check if an argument is provided\nif [ $# -ne 1 ]; then\n echo \"Usage: $0 \\\"[1,2],[3,4],[5,6]\\\"\"\n exit 1\nfi\n\ninput=\"$1\"\n\n# Function to trim leading and trailing brackets\ntrim_brackets() {\n local str=\"$1\"\n str=\"${str#[}\"\n str=\"${str%]}\"\n echo \"$str\"\n}\n\n# Remove outer brackets if present and split the matrix into rows\ntrimmed_input=$(trim_brackets \"$input\")\nIFS=\"],[\" read -r -a rows <<< \"$trimmed_input\"\n\n# Initialize an array of arrays to hold the matrix\ndeclare -a matrix\nnum_cols=0\n\n# Parse each row into the matrix array\nfor row in \"${rows[@]}\"; do\n IFS=',' read -r -a cols <<< \"$row\"\n matrix+=(\"${cols[@]}\")\n if [ \"${#cols[@]}\" -gt \"$num_cols\" ]; then\n num_cols=\"${#cols[@]}\"\n fi\ndone\n\n# Determine the number of rows\nnum_rows=\"${#rows[@]}\"\n\n# Initialize an array to hold the transposed matrix\ndeclare -a transpose\n\n# Build the transpose by iterating over columns and rows\nfor ((c=0; c { const { client, callSpy } = mockClient(); @@ -1129,3 +1130,336 @@ test("traceable continues execution when client throws error", async () => { expect(errorClient.createRun).toHaveBeenCalled(); expect(errorClient.updateRun).toHaveBeenCalled(); }); + +test("traceable with processInputs", async () => { + const { client, callSpy } = mockClient(); + + const processInputs = jest.fn((inputs: Readonly) => { + return { ...inputs, password: "****" }; + }); + + const func = traceable( + async function func(input: { username: string; password: string }) { + // The function should receive the original inputs + expect(input.password).toBe("secret"); + return `Welcome, ${input.username}`; + }, + { + client, + tracingEnabled: true, + processInputs, + } + ); + + await func({ username: "user1", password: "secret" }); + + expect(processInputs).toHaveBeenCalledWith({ + username: "user1", + password: "secret", + }); + // Verify that the logged inputs have the password masked + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + inputs: { + username: "user1", + password: "****", + }, + outputs: { outputs: "Welcome, user1" }, + }, + }, + }); +}); + +test("traceable with processOutputs", async () => { + const { client, callSpy } = mockClient(); + + const processOutputs = jest.fn((_outputs: Readonly) => { + return { outputs: "Modified Output" }; + }); + + const func = traceable( + async function func(input: string) { + return `Original Output for ${input}`; + }, + { + client, + tracingEnabled: true, + processOutputs, + } + ); + + const result = await func("test"); 
+ + expect(processOutputs).toHaveBeenCalledWith({ + outputs: "Original Output for test", + }); + expect(result).toBe("Original Output for test"); + // Verify that the tracing data shows the modified output + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + inputs: { input: "test" }, + outputs: { outputs: "Modified Output" }, + }, + }, + }); +}); + +test("traceable with processInputs throwing error does not affect invocation", async () => { + const { client, callSpy } = mockClient(); + + const processInputs = jest.fn((_inputs: Readonly) => { + throw new Error("processInputs error"); + }); + + const func = traceable( + async function func(input: { username: string }) { + // This should not be called + return `Hello, ${input.username}`; + }, + { + client, + tracingEnabled: true, + processInputs, + } + ); + + const result = await func({ username: "user1" }); + + expect(processInputs).toHaveBeenCalledWith({ username: "user1" }); + expect(result).toBe("Hello, user1"); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + inputs: { username: "user1" }, + outputs: { outputs: "Hello, user1" }, + }, + }, + }); +}); + +test("traceable with processOutputs throwing error does not affect invocation", async () => { + const { client, callSpy } = mockClient(); + + const processOutputs = jest.fn((_outputs: Readonly) => { + throw new Error("processOutputs error"); + }); + + const func = traceable( + async function func(input: string) { + return `Original Output for ${input}`; + }, + { + client, + tracingEnabled: true, + processOutputs, + } + ); + + const result = await func("test"); + + expect(processOutputs).toHaveBeenCalledWith({ + outputs: "Original Output for test", + }); + expect(result).toBe("Original Output for test"); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + inputs: { input: "test" }, + outputs: { outputs: "Original Output for test" }, + }, + }, + }); +}); + +test("traceable async generator with processOutputs", async () => { + const { client, callSpy } = mockClient(); + + const processOutputs = jest.fn((outputs: Readonly) => { + return { outputs: outputs.outputs.map((output: number) => output * 2) }; + }); + + const func = traceable( + async function* func() { + for (let i = 1; i <= 3; i++) { + yield i; + } + }, + { + client, + tracingEnabled: true, + processOutputs, + } + ); + + const results: number[] = []; + for await (const value of func()) { + results.push(value); + } + + expect(results).toEqual([1, 2, 3]); // Original values + expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] }); + + // Tracing data should reflect the processed outputs + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + outputs: { outputs: [2, 4, 6] }, // Processed outputs + }, + }, + }); +}); + +test("traceable function returning object with async iterable and processOutputs", async () => { + const { client, callSpy } = mockClient(); + + const processOutputs = jest.fn((outputs: Readonly) => { + return { outputs: outputs.outputs.map((output: number) => output * 2) }; + }); + + const func = traceable( + async function func() { + return { + data: "some data", + stream: (async function* () { + for (let i = 1; i <= 3; i++) { + yield i; + } + })(), + }; + }, + { + client, + tracingEnabled: true, + 
processOutputs, + __finalTracedIteratorKey: "stream", + } + ); + + const result = await func(); + expect(result.data).toBe("some data"); + + const results: number[] = []; + for await (const value of result.stream) { + results.push(value); + } + + expect(results).toEqual([1, 2, 3]); + expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] }); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + outputs: { outputs: [2, 4, 6] }, + }, + }, + }); +}); + +test("traceable generator function with processOutputs", async () => { + const { client, callSpy } = mockClient(); + + const processOutputs = jest.fn((outputs: Readonly) => { + return { outputs: outputs.outputs.map((output: number) => output * 2) }; + }); + + function* func() { + for (let i = 1; i <= 3; i++) { + yield i; + } + } + + const tracedFunc = traceable(func, { + client, + tracingEnabled: true, + processOutputs, + }); + + const results: number[] = []; + for (const value of await tracedFunc()) { + results.push(value); + } + + expect(results).toEqual([1, 2, 3]); + expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] }); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + outputs: { outputs: [2, 4, 6] }, + }, + }, + }); +}); + +test("traceable with complex outputs", async () => { + const { client, callSpy } = mockClient(); + + const processOutputs = jest.fn((outputs: Readonly) => { + return { data: "****", output: outputs.output, nested: outputs.nested }; + }); + + const func = traceable( + async function func(input: string) { + return { + data: "some sensitive data", + output: `Original Output for ${input}`, + nested: { + key: "value", + nestedOutput: `Nested Output for ${input}`, + }, + }; + }, + { + client, + tracingEnabled: true, + processOutputs, + } + ); + + const result = await func("test"); + + expect(result).toEqual({ + data: "some sensitive data", + output: "Original Output for test", + nested: { + key: "value", + nestedOutput: "Nested Output for test", + }, + }); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["func:0"], + edges: [], + data: { + "func:0": { + inputs: { input: "test" }, + outputs: { + data: "****", + output: "Original Output for test", + nested: { + key: "value", + nestedOutput: "Nested Output for test", + }, + }, + }, + }, + }); +}); diff --git a/js/src/tests/wrapped_openai.int.test.ts b/js/src/tests/wrapped_openai.int.test.ts index f4c2829bc..a8b83144f 100644 --- a/js/src/tests/wrapped_openai.int.test.ts +++ b/js/src/tests/wrapped_openai.int.test.ts @@ -8,6 +8,8 @@ import { mockClient } from "./utils/mock_client.js"; import { getAssumedTreeFromCalls } from "./utils/tree.js"; import { zodResponseFormat } from "openai/helpers/zod"; import { z } from "zod"; +import { UsageMetadata } from "../schemas.js"; +import fs from "fs"; test("wrapOpenAI should return type compatible with OpenAI", async () => { let originalClient = new OpenAI(); @@ -574,3 +576,135 @@ test.concurrent("beta.chat.completions.parse", async () => { } callSpy.mockClear(); }); + +const usageMetadataTestCases = [ + { + description: "stream", + params: { + model: "gpt-4o-mini", + messages: [{ role: "user", content: "howdy" }], + stream: true, + stream_options: { include_usage: true }, + }, + expectUsageMetadata: true, + }, + { + description: "stream no usage", + params: { + model: "gpt-4o-mini", + messages: [{ role: "user", content: "howdy" 
}], + stream: true, + }, + expectUsageMetadata: false, + }, + { + description: "default", + params: { + model: "gpt-4o-mini", + messages: [{ role: "user", content: "howdy" }], + }, + expectUsageMetadata: true, + }, + { + description: "reasoning", + params: { + model: "o1-mini", + messages: [ + { + role: "user", + content: + "Write a bash script that takes a matrix represented as a string with format '[1,2],[3,4],[5,6]' and prints the transpose in the same format.", + }, + ], + }, + expectUsageMetadata: true, + checkReasoningTokens: true, + }, +]; + +describe("Usage Metadata Tests", () => { + usageMetadataTestCases.forEach( + ({ description, params, expectUsageMetadata, checkReasoningTokens }) => { + it(`should handle ${description}`, async () => { + const { client, callSpy } = mockClient(); + const openai = wrapOpenAI(new OpenAI(), { + tracingEnabled: true, + client, + }); + + const requestParams = { ...params }; + + let oaiUsage: OpenAI.CompletionUsage | undefined; + if (requestParams.stream) { + const stream = await openai.chat.completions.create( + requestParams as OpenAI.ChatCompletionCreateParamsStreaming + ); + for await (const chunk of stream) { + if (expectUsageMetadata && chunk.usage) { + oaiUsage = chunk.usage; + } + } + } else { + const res = await openai.chat.completions.create( + requestParams as OpenAI.ChatCompletionCreateParams + ); + oaiUsage = (res as OpenAI.ChatCompletion).usage; + } + + let usageMetadata: UsageMetadata | undefined; + const requestBodies: any = {}; + for (const call of callSpy.mock.calls) { + const request = call[2] as any; + const requestBody = JSON.parse(request.body); + if (request.method === "POST") { + requestBodies["post"] = [requestBody]; + } + if (request.method === "PATCH") { + requestBodies["patch"] = [requestBody]; + } + if (requestBody.outputs && requestBody.outputs.usage_metadata) { + usageMetadata = requestBody.outputs.usage_metadata; + break; + } + } + + if (expectUsageMetadata) { + expect(usageMetadata).not.toBeUndefined(); + expect(usageMetadata).not.toBeNull(); + expect(oaiUsage).not.toBeUndefined(); + expect(oaiUsage).not.toBeNull(); + expect(usageMetadata!.input_tokens).toEqual(oaiUsage!.prompt_tokens); + expect(usageMetadata!.output_tokens).toEqual( + oaiUsage!.completion_tokens + ); + expect(usageMetadata!.total_tokens).toEqual(oaiUsage!.total_tokens); + + if (checkReasoningTokens) { + expect(usageMetadata!.output_token_details).not.toBeUndefined(); + expect( + usageMetadata!.output_token_details!.reasoning + ).not.toBeUndefined(); + expect(usageMetadata!.output_token_details!.reasoning).toEqual( + oaiUsage!.completion_tokens_details?.reasoning_tokens + ); + } + } else { + expect(usageMetadata).toBeUndefined(); + expect(oaiUsage).toBeUndefined(); + } + + if (process.env.WRITE_TOKEN_COUNTING_TEST_DATA === "1") { + fs.writeFileSync( + `${__dirname}/test_data/langsmith_js_wrap_openai_${description.replace( + " ", + "_" + )}.json`, + JSON.stringify(requestBodies, null, 2) + ); + } + + callSpy.mockClear(); + }); + } + ); +}); diff --git a/js/src/traceable.ts b/js/src/traceable.ts index 0934a55df..b8d48c663 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -29,29 +29,55 @@ AsyncLocalStorageProviderSingleton.initializeGlobalInstance( new AsyncLocalStorage() ); -const handleRunInputs = (rawInputs: unknown[]): KVMap => { +const handleRunInputs = ( + rawInputs: unknown[], + processInputs: (inputs: Readonly) => KVMap +): KVMap => { const firstInput = rawInputs[0]; + let inputs: KVMap; if (firstInput == null) { - return {}; + inputs 
= {}; + } else if (rawInputs.length > 1) { + inputs = { args: rawInputs }; + } else if (isKVMap(firstInput)) { + inputs = firstInput; + } else { + inputs = { input: firstInput }; } - if (rawInputs.length > 1) { - return { args: rawInputs }; + try { + return processInputs(inputs); + } catch (e) { + console.error( + "Error occurred during processInputs. Sending raw inputs:", + e + ); + return inputs; } - - if (isKVMap(firstInput)) { - return firstInput; - } - - return { input: firstInput }; }; -const handleRunOutputs = (rawOutputs: unknown): KVMap => { +const handleRunOutputs = ( + rawOutputs: unknown, + processOutputs: (outputs: Readonly) => KVMap +): KVMap => { + let outputs: KVMap; + if (isKVMap(rawOutputs)) { - return rawOutputs; + outputs = rawOutputs; + } else { + outputs = { outputs: rawOutputs }; + } + + try { + return processOutputs(outputs); + } catch (e) { + console.error( + "Error occurred during processOutputs. Sending raw outputs:", + e + ); + return outputs; } - return { outputs: rawOutputs }; }; const getTracingRunTree = ( @@ -59,13 +85,14 @@ const getTracingRunTree = ( inputs: Args, getInvocationParams: | ((...args: Args) => InvocationParamsSchema | undefined) - | undefined + | undefined, + processInputs: (inputs: Readonly) => KVMap ): RunTree | undefined => { if (!isTracingEnabled(runTree.tracingEnabled)) { return undefined; } - runTree.inputs = handleRunInputs(inputs); + runTree.inputs = handleRunInputs(inputs, processInputs); const invocationParams = getInvocationParams?.(...inputs); if (invocationParams != null) { @@ -293,6 +320,26 @@ export function traceable any>( getInvocationParams?: ( ...args: Parameters ) => InvocationParamsSchema | undefined; + + /** + * Apply transformations to the inputs before logging. + * This function should NOT mutate the inputs. + * `processInputs` is not inherited by nested traceable functions. + * + * @param inputs Key-value map of the function inputs. + * @returns Transformed key-value map + */ + processInputs?: (inputs: Readonly) => KVMap; + + /** + * Apply transformations to the outputs before logging. + * This function should NOT mutate the outputs. + * `processOutputs` is not inherited by nested traceable functions. + * + * @param outputs Key-value map of the function outputs + * @returns Transformed key-value map + */ + processOutputs?: (outputs: Readonly) => KVMap; } ) { type Inputs = Parameters; @@ -300,9 +347,14 @@ export function traceable any>( aggregator, argsConfigPath, __finalTracedIteratorKey, + processInputs, + processOutputs, ...runTreeConfig } = config ?? {}; + const processInputsFn = processInputs ?? ((x) => x); + const processOutputsFn = processOutputs ?? ((x) => x); + const traceableFunc = ( ...args: Inputs | [RunTree, ...Inputs] | [RunnableConfigLike, ...Inputs] ) => { @@ -374,7 +426,8 @@ export function traceable any>( getTracingRunTree( RunTree.fromRunnableConfig(firstArg, ensuredConfig), restArgs as Inputs, - config?.getInvocationParams + config?.getInvocationParams, + processInputsFn ), restArgs as Inputs, ]; @@ -398,7 +451,8 @@ export function traceable any>( ? 
new RunTree(ensuredConfig) : firstArg.createChild(ensuredConfig), restArgs as Inputs, - config?.getInvocationParams + config?.getInvocationParams, + processInputsFn ); return [currentRunTree, [currentRunTree, ...restArgs] as Inputs]; @@ -412,7 +466,8 @@ export function traceable any>( getTracingRunTree( prevRunFromStore.createChild(ensuredConfig), processedArgs, - config?.getInvocationParams + config?.getInvocationParams, + processInputsFn ), processedArgs as Inputs, ]; @@ -421,7 +476,8 @@ export function traceable any>( const currentRunTree = getTracingRunTree( new RunTree(ensuredConfig), processedArgs, - config?.getInvocationParams + config?.getInvocationParams, + processInputsFn ); // If a context var is set by LangChain outside of a traceable, // it will be an object with a single property and we should copy @@ -470,7 +526,7 @@ export function traceable any>( if (result.done) { finished = true; await currentRunTree?.end( - handleRunOutputs(await handleChunks(chunks)) + handleRunOutputs(await handleChunks(chunks), processOutputsFn) ); await handleEnd(); controller.close(); @@ -483,7 +539,7 @@ export function traceable any>( async cancel(reason) { if (!finished) await currentRunTree?.end(undefined, "Cancelled"); await currentRunTree?.end( - handleRunOutputs(await handleChunks(chunks)) + handleRunOutputs(await handleChunks(chunks), processOutputsFn) ); await handleEnd(); return reader.cancel(reason); @@ -517,7 +573,7 @@ export function traceable any>( } finally { if (!finished) await currentRunTree?.end(undefined, "Cancelled"); await currentRunTree?.end( - handleRunOutputs(await handleChunks(chunks)) + handleRunOutputs(await handleChunks(chunks), processOutputsFn) ); await handleEnd(); } @@ -640,7 +696,8 @@ export function traceable any>( return memo; }, []) - ) + ), + processOutputsFn ) ); await handleEnd(); @@ -657,7 +714,9 @@ export function traceable any>( } try { - await currentRunTree?.end(handleRunOutputs(rawOutput)); + await currentRunTree?.end( + handleRunOutputs(rawOutput, processOutputsFn) + ); await handleEnd(); } finally { // eslint-disable-next-line no-unsafe-finally diff --git a/js/src/wrappers/openai.ts b/js/src/wrappers/openai.ts index 6164ceca6..9987e0adb 100644 --- a/js/src/wrappers/openai.ts +++ b/js/src/wrappers/openai.ts @@ -2,6 +2,7 @@ import { OpenAI } from "openai"; import type { APIPromise } from "openai/core"; import type { RunTreeConfig } from "../index.js"; import { isTraceableFunction, traceable } from "../traceable.js"; +import { KVMap } from "../schemas.js"; // Extra leniency around types in case multiple OpenAI SDK versions get installed type OpenAIType = { @@ -187,6 +188,44 @@ const textAggregator = ( return aggregatedOutput; }; +function processChatCompletion(outputs: Readonly): KVMap { + const chatCompletion = outputs as OpenAI.ChatCompletion; + // copy the original object, minus usage + const result = { ...chatCompletion } as KVMap; + const usage = chatCompletion.usage; + if (usage) { + const inputTokenDetails = { + ...(usage.prompt_tokens_details?.audio_tokens !== null && { + audio: usage.prompt_tokens_details?.audio_tokens, + }), + ...(usage.prompt_tokens_details?.cached_tokens !== null && { + cache_read: usage.prompt_tokens_details?.cached_tokens, + }), + }; + const outputTokenDetails = { + ...(usage.completion_tokens_details?.audio_tokens !== null && { + audio: usage.completion_tokens_details?.audio_tokens, + }), + ...(usage.completion_tokens_details?.reasoning_tokens !== null && { + reasoning: usage.completion_tokens_details?.reasoning_tokens, + 
}), + }; + result.usage_metadata = { + input_tokens: usage.prompt_tokens ?? 0, + output_tokens: usage.completion_tokens ?? 0, + total_tokens: usage.total_tokens ?? 0, + ...(Object.keys(inputTokenDetails).length > 0 && { + input_token_details: inputTokenDetails, + }), + ...(Object.keys(outputTokenDetails).length > 0 && { + output_token_details: outputTokenDetails, + }), + }; + } + delete result.usage; + return result; +} + /** * Wraps an OpenAI client's completion methods, enabling automatic LangSmith * tracing. Method signatures are unchanged, with the exception that you can pass @@ -307,6 +346,7 @@ export const wrapOpenAI = ( ls_stop, }; }, + processOutputs: processChatCompletion, ...options, } ),
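The `processInputs` / `processOutputs` hooks added to `traceable` above transform only what is recorded on the run; the wrapped function still receives and returns the original values, and if a hook throws, the raw data is logged instead. A minimal usage sketch (the `langsmith/traceable` entry point and the run name are illustrative):

```typescript
import { traceable } from "langsmith/traceable";

// Redact a credential before it is written to the trace. The wrapped
// function still sees the real password, and its return value is unchanged.
const login = traceable(
  async (input: { username: string; password: string }) => {
    return `Welcome, ${input.username}`;
  },
  {
    name: "login", // illustrative run name
    processInputs: (inputs) => ({ ...inputs, password: "****" }),
    // Non key-value return values are wrapped as { outputs: value } before
    // being handed to processOutputs.
    processOutputs: (outputs) => ({ ...outputs, outputs: "[masked]" }),
  }
);

async function main() {
  const greeting = await login({ username: "user1", password: "secret" });
  console.log(greeting); // "Welcome, user1" (unaffected by processOutputs)
}

main().catch(console.error);
```

Note that inputs are normalized before `processInputs` runs: multiple positional arguments arrive as `{ args: [...] }`, a single plain-object argument is passed through as-is, and any other single value arrives as `{ input: value }`. Neither hook is inherited by nested traceable calls.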
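The `usage_metadata` field that `processChatCompletion` attaches to run outputs follows the `UsageMetadata` type added in `schemas.ts`. A sketch of the resulting shape, using the values from the `gpt-4o-mini` fixture above and assuming the schemas module is importable as `langsmith/schemas`:

```typescript
import type { UsageMetadata } from "langsmith/schemas";

// prompt_tokens -> input_tokens, completion_tokens -> output_tokens,
// total_tokens -> total_tokens, with optional per-type breakdowns taken from
// prompt_tokens_details / completion_tokens_details when OpenAI returns them.
const usage: UsageMetadata = {
  input_tokens: 9,
  output_tokens: 9,
  total_tokens: 18,
  input_token_details: { cache_read: 0 }, // cached_tokens (cache-hit reads)
  output_token_details: { reasoning: 0 }, // reasoning_tokens (o1-style models)
};

console.log(usage);
```

The detail breakdowns are optional and do not have to sum to the totals, so consumers should treat a missing key as "not reported" rather than zero.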
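For streaming completions, OpenAI only reports usage when `stream_options: { include_usage: true }` is set, which is why the integration tests cover both the "stream" and "stream no usage" cases. A sketch of the traced streaming call (assuming `wrapOpenAI` is importable from `langsmith/wrappers`; requires OpenAI and LangSmith credentials in the environment):

```typescript
import { OpenAI } from "openai";
import { wrapOpenAI } from "langsmith/wrappers";

const openai = wrapOpenAI(new OpenAI());

async function main() {
  const stream = await openai.chat.completions.create({
    model: "gpt-4o-mini",
    messages: [{ role: "user", content: "howdy" }],
    stream: true,
    // Without this flag no usage chunk is emitted and the traced run
    // will not carry usage_metadata.
    stream_options: { include_usage: true },
  });

  for await (const chunk of stream) {
    // The final chunk carries token usage when include_usage is set.
    if (chunk.usage) console.log("usage:", chunk.usage);
  }
}

main().catch(console.error);
```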