diff --git a/js/src/client.ts b/js/src/client.ts index fd4caec5d..ed74f62f2 100644 --- a/js/src/client.ts +++ b/js/src/client.ts @@ -34,6 +34,7 @@ import { ValueType, AnnotationQueue, RunWithAnnotationQueueInfo, + Attachments, } from "./schemas.js"; import { convertLangChainMessageToExample, @@ -240,7 +241,7 @@ interface CreateRunParams { revision_id?: string; trace_id?: string; dotted_order?: string; - attachments?: Record; + attachments?: Attachments; } interface UpdateRunParams extends RunUpdate { @@ -1032,10 +1033,7 @@ export class Client { return; } // transform and convert to dicts - const allAttachments: Record< - string, - Record - > = {}; + const allAttachments: Record = {}; let preparedCreateParams = []; for (const create of runCreates ?? []) { const preparedCreate = this.prepareRunCreateOrUpdateInputs(create); @@ -1048,7 +1046,6 @@ export class Client { delete preparedCreate.attachments; preparedCreateParams.push(preparedCreate); } - let preparedUpdateParams = []; for (const update of runUpdates ?? []) { preparedUpdateParams.push(this.prepareRunCreateOrUpdateInputs(update)); @@ -1116,7 +1113,8 @@ export class Client { ]) { for (const originalPayload of payloads) { // collect fields to be sent as separate parts - const { inputs, outputs, events, ...payload } = originalPayload; + const { inputs, outputs, events, attachments, ...payload } = + originalPayload; const fields = { inputs, outputs, events }; // encode the main run payload const stringifiedPayload = stringifyForTracing(payload); @@ -1147,10 +1145,18 @@ export class Client { for (const [name, [contentType, content]] of Object.entries( attachments )) { + // Validate that the attachment name doesn't contain a '.' + if (name.includes(".")) { + console.warn( + `Skipping attachment '${name}' for run ${payload.id}: Invalid attachment name. ` + + `Attachment names must not contain periods ('.'). 
Please rename the attachment and try again.` + ); + continue; + } accumulatedParts.push({ name: `attachment.${payload.id}.${name}`, payload: new Blob([content], { - type: `${contentType}; length=${content.length}`, + type: `${contentType}; length=${content.byteLength}`, }), }); } @@ -1172,6 +1178,7 @@ export class Client { for (const part of parts) { formData.append(part.name, part.payload); } + // Send the multipart request to the /runs/multipart endpoint await this.batchIngestCaller.call( _getFetchImplementation(), `${this.apiUrl}/runs/multipart`, diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 97a33a19c..dc92cbac1 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -1,5 +1,11 @@ import * as uuid from "uuid"; -import { BaseRun, KVMap, RunCreate, RunUpdate } from "./schemas.js"; +import { + Attachments, + BaseRun, + KVMap, + RunCreate, + RunUpdate, +} from "./schemas.js"; import { RuntimeEnvironment, getEnvironmentVariable, @@ -55,6 +61,7 @@ export interface RunTreeConfig { trace_id?: string; dotted_order?: string; + attachments?: Attachments; } export interface RunnableConfigLike { @@ -172,6 +179,11 @@ export class RunTree implements BaseRun { tracingEnabled?: boolean; execution_order: number; child_execution_order: number; + /** + * Attachments associated with the run. 
+ * Each entry is a tuple of [mime_type, bytes] + */ + attachments?: Attachments; constructor(originalConfig: RunTreeConfig | RunTree) { // If you pass in a run tree directly, return a shallow clone @@ -370,6 +382,7 @@ export class RunTree implements BaseRun { trace_id: run.trace_id, dotted_order: run.dotted_order, tags: run.tags, + attachments: run.attachments, }; return persistedRun; } @@ -407,6 +420,7 @@ export class RunTree implements BaseRun { dotted_order: this.dotted_order, trace_id: this.trace_id, tags: this.tags, + attachments: this.attachments, }; await this.client.updateRun(this.id, runUpdate); diff --git a/js/src/schemas.ts b/js/src/schemas.ts index 1e899bc2c..26afd7fc0 100644 --- a/js/src/schemas.ts +++ b/js/src/schemas.ts @@ -63,6 +63,9 @@ export interface BaseExample { source_run_id?: string; } +export type AttachmentData = Uint8Array | ArrayBuffer; +export type Attachments = Record<string, [string, AttachmentData]>; /* attachment name -> [mime_type, bytes] */ + /** * A run can represent either a trace (root run) * or a child run (~span). @@ -131,7 +134,7 @@ export interface BaseRun { * Attachments associated with the run. * Each entry is a tuple of [mime_type, bytes] */ - attachments?: Record; + attachments?: Attachments; } type S3URL = { @@ -231,6 +234,12 @@ export interface RunUpdate { * - 20230915T223155647Z1b64098b-4ab7-43f6-afee-992304f198d8.20230914T223155650Zc8d9f4c5-6c5a-4b2d-9b1c-3d9d7a7c5c7c */ dotted_order?: string; + + /** + * Attachments associated with the run. 
+ * Each entry is a tuple of [mime_type, bytes] + */ + attachments?: Attachments; } export interface ExampleCreate extends BaseExample { diff --git a/js/src/tests/traceable.int.test.ts b/js/src/tests/traceable.int.test.ts index 4458e7222..bb3cfd7a5 100644 --- a/js/src/tests/traceable.int.test.ts +++ b/js/src/tests/traceable.int.test.ts @@ -12,11 +12,12 @@ import { import { RunTree } from "../run_trees.js"; import { BaseRun } from "../schemas.js"; import { expect } from "@jest/globals"; +import { jest } from "@jest/globals"; -async function deleteProject(langchainClient: Client, projectName: string) { +async function deleteProject(langsmithClient: Client, projectName: string) { try { - await langchainClient.readProject({ projectName }); - await langchainClient.deleteProject({ projectName }); + await langsmithClient.readProject({ projectName }); + await langsmithClient.deleteProject({ projectName }); } catch (e) { // Pass } @@ -64,7 +65,7 @@ async function waitUntilRunFound( } test.concurrent("Test traceable wrapper with error thrown", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -80,7 +81,7 @@ test.concurrent("Test traceable wrapper with error thrown", async () => { { name: "add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, on_end: _getRun, tracingEnabled: true, @@ -95,15 +96,15 @@ test.concurrent("Test traceable wrapper with error thrown", async () => { } expect(collectedRun).not.toBeNull(); expect(collectedRun!.error).toEqual("Error: I am bad"); - await waitUntilRunFound(langchainClient, runId); - const storedRun = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun = await langsmithClient.readRun(runId); expect(storedRun.id).toEqual(runId); expect(storedRun.status).toEqual("error"); expect(storedRun.error).toEqual("Error: I am bad"); }); 
test.concurrent("Test traceable wrapper with async error thrown", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -119,7 +120,7 @@ test.concurrent("Test traceable wrapper with async error thrown", async () => { { name: "add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, on_end: _getRun, tracingEnabled: true, @@ -136,8 +137,8 @@ test.concurrent("Test traceable wrapper with async error thrown", async () => { expect(collectedRun).not.toBeNull(); expect(collectedRun!.error).toEqual("Error: I am bad"); expect(collectedRun!.inputs).toEqual({ args: ["testing", 9] }); - await waitUntilRunFound(langchainClient, runId); - const storedRun = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun = await langsmithClient.readRun(runId); expect(storedRun.id).toEqual(runId); expect(storedRun.status).toEqual("error"); expect(storedRun.error).toEqual("Error: I am bad"); @@ -146,7 +147,7 @@ test.concurrent("Test traceable wrapper with async error thrown", async () => { test.concurrent( "Test traceable wrapper", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -162,7 +163,7 @@ test.concurrent( { name: "add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, on_end: _getRun, tracingEnabled: true, @@ -174,8 +175,8 @@ test.concurrent( expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "testing9" }); - await waitUntilRunFound(langchainClient, runId, true); - const storedRun = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId, true); + const storedRun = await langsmithClient.readRun(runId); expect(storedRun.id).toEqual(runId); const runId2 = uuidv4(); @@ 
-186,7 +187,7 @@ test.concurrent( { name: "nested_add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, } ); const entryTraceable = traceable( @@ -197,7 +198,7 @@ test.concurrent( new RunTree({ name: "root_nested_add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, }), result, 2 @@ -207,7 +208,7 @@ test.concurrent( { name: "run_with_nesting", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId2, } ); @@ -215,8 +216,8 @@ test.concurrent( expect(await entryTraceable({ value: "testing" })).toBe("testing123"); expect(isTraceableFunction(entryTraceable)).toBe(true); - await waitUntilRunFound(langchainClient, runId2, true); - const storedRun2 = await langchainClient.readRun(runId2); + await waitUntilRunFound(langsmithClient, runId2, true); + const storedRun2 = await langsmithClient.readRun(runId2); expect(storedRun2.id).toEqual(runId2); const runId3 = uuidv4(); @@ -227,7 +228,7 @@ test.concurrent( const iterableTraceable = traceable(llm.stream.bind(llm), { name: "iterable_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId3, on_end: (r: RunTree): void => { collectedRun = r; @@ -244,11 +245,11 @@ test.concurrent( expect(chunks.join("")).toBe("Hello there"); expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).not.toBeNull(); - await waitUntilRunFound(langchainClient, runId3, true); - const storedRun3 = await langchainClient.readRun(runId3); + await waitUntilRunFound(langsmithClient, runId3, true); + const storedRun3 = await langsmithClient.readRun(runId3); expect(storedRun3.id).toEqual(runId3); - await deleteProject(langchainClient, projectName); + await deleteProject(langsmithClient, projectName); async function overload(a: string, b: number): Promise; async function overload(config: { a: string; b: number }): Promise; @@ -264,7 +265,7 @@ test.concurrent( const wrappedOverload = 
traceable(overload, { name: "wrapped_overload", project_name: projectName, - client: langchainClient, + client: langsmithClient, }); expect(await wrappedOverload("testing", 123)).toBe("testing123"); @@ -275,7 +276,7 @@ test.concurrent( ); test.concurrent("Test get run tree method", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); // Called outside a traceable function @@ -293,7 +294,7 @@ test.concurrent("Test get run tree method", async () => { { name: "nested_add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, } ); const addValueTraceable = traceable( @@ -305,7 +306,7 @@ test.concurrent("Test get run tree method", async () => { { name: "add_value", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, } ); @@ -313,7 +314,7 @@ test.concurrent("Test get run tree method", async () => { }); test.concurrent("Test traceable wrapper with aggregator", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const openai = new OpenAI(); @@ -330,7 +331,7 @@ test.concurrent("Test traceable wrapper with aggregator", async () => { { name: "openai_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, aggregator: (chunks) => { tracedOutput = chunks @@ -356,17 +357,16 @@ test.concurrent("Test traceable wrapper with aggregator", async () => { // eslint-disable-next-line @typescript-eslint/no-unused-vars const _test = chunk.invalidProp; } - console.log(tracedOutput); expect(typeof tracedOutput).toEqual("string"); expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: tracedOutput }); - await waitUntilRunFound(langchainClient, runId, true); - const storedRun3 = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId, true); + const storedRun3 = 
await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); }); test.concurrent("Test async generator success", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -385,7 +385,7 @@ test.concurrent("Test async generator success", async () => { const iterableTraceable = traceable(giveMeNumbers, { name: "i_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, aggregator: (chunks) => { return chunks.join(" "); @@ -401,8 +401,8 @@ test.concurrent("Test async generator success", async () => { } expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "0 1 2 3 4" }); - await waitUntilRunFound(langchainClient, runId); - const storedRun3 = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun3 = await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); expect(storedRun3.status).toEqual("success"); expect(storedRun3.outputs).toEqual({ outputs: "0 1 2 3 4" }); @@ -410,7 +410,7 @@ test.concurrent("Test async generator success", async () => { }); test.concurrent("Test async generator throws error", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -433,7 +433,7 @@ test.concurrent("Test async generator throws error", async () => { const iterableTraceable = traceable(giveMeNumbers, { name: "i_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, aggregator: (chunks) => { return chunks.join(" "); @@ -453,8 +453,8 @@ test.concurrent("Test async generator throws error", async () => { } expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "0 1 2" }); - await waitUntilRunFound(langchainClient, runId); - const storedRun3 = await 
langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun3 = await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); expect(storedRun3.status).toEqual("error"); expect(storedRun3.outputs).toEqual({ outputs: "0 1 2" }); @@ -462,7 +462,7 @@ test.concurrent("Test async generator throws error", async () => { }); test.concurrent("Test async generator break finishes run", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -481,7 +481,7 @@ test.concurrent("Test async generator break finishes run", async () => { const iterableTraceable = traceable(giveMeNumbers, { name: "i_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, aggregator: (chunks) => { return chunks.join(" "); @@ -498,8 +498,8 @@ test.concurrent("Test async generator break finishes run", async () => { expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "0" }); expect(collectedRun!.id).toEqual(runId); - await waitUntilRunFound(langchainClient, runId); - const storedRun3 = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun3 = await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); expect(storedRun3.status).toEqual("error"); expect(storedRun3.outputs).toEqual({ outputs: "0" }); @@ -507,7 +507,7 @@ test.concurrent("Test async generator break finishes run", async () => { }); test.concurrent("Test async generator success", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -529,7 +529,7 @@ test.concurrent("Test async generator success", async () => { const iterableTraceable = traceable(giveMeGiveMeNumbers, { name: "i_traceable", project_name: projectName, - client: langchainClient, + 
client: langsmithClient, id: runId, aggregator: (chunks) => { return chunks.join(" "); @@ -546,8 +546,8 @@ test.concurrent("Test async generator success", async () => { expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "0 1 2 3 4" }); expect(collectedRun!.id).toEqual(runId); - await waitUntilRunFound(langchainClient, runId); - const storedRun3 = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun3 = await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); expect(storedRun3.status).toEqual("success"); expect(storedRun3.outputs).toEqual({ outputs: "0 1 2 3 4" }); @@ -555,7 +555,7 @@ test.concurrent("Test async generator success", async () => { }); test.concurrent("Test promise for async generator success", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -577,7 +577,7 @@ test.concurrent("Test promise for async generator success", async () => { const iterableTraceable = traceable(giveMeGiveMeNumbers, { name: "i_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, aggregator: (chunks) => { return chunks.join(" "); @@ -600,8 +600,8 @@ test.concurrent("Test promise for async generator success", async () => { expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "0 1 2" }); expect(collectedRun!.id).toEqual(runId); - await waitUntilRunFound(langchainClient, runId); - const storedRun3 = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun3 = await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); expect(storedRun3.status).toEqual("error"); expect(storedRun3.outputs).toEqual({ outputs: "0 1 2" }); @@ -611,7 +611,7 @@ test.concurrent("Test promise for async generator success", async () => { test.concurrent( 
"Test promise for async generator break finishes run", async () => { - const langchainClient = new Client({ + const langsmithClient = new Client({ callerOptions: { maxRetries: 0 }, }); const runId = uuidv4(); @@ -634,7 +634,7 @@ test.concurrent( const iterableTraceable = traceable(giveMeGiveMeNumbers, { name: "i_traceable", project_name: projectName, - client: langchainClient, + client: langsmithClient, id: runId, aggregator: (chunks) => { return chunks.join(" "); @@ -651,9 +651,128 @@ test.concurrent( expect(collectedRun).not.toBeNull(); expect(collectedRun!.outputs).toEqual({ outputs: "0" }); expect(collectedRun!.id).toEqual(runId); - await waitUntilRunFound(langchainClient, runId); - const storedRun3 = await langchainClient.readRun(runId); + await waitUntilRunFound(langsmithClient, runId); + const storedRun3 = await langsmithClient.readRun(runId); expect(storedRun3.id).toEqual(runId); expect(storedRun3.status).toEqual("error"); } ); + +test.concurrent( + "Test upload attachments and process inputs.", + async () => { + const langsmithClient = new Client({ + callerOptions: { maxRetries: 0 }, + }); + const runId = uuidv4(); + const projectName = "__test_traceable_wrapper_attachments_and_inputs"; + + const testAttachment1 = new Uint8Array([1, 2, 3, 4]); + const testAttachment2 = new Uint8Array([5, 6, 7, 8]); + const testAttachment3 = new ArrayBuffer(4); + new Uint8Array(testAttachment3).set([13, 14, 15, 16]); + + const traceableWithAttachmentsAndInputs = traceable( + ( + val: number, + text: string, + extra: string, + attachment: Uint8Array, + attachment2: ArrayBuffer + ) => + `Processed: ${val}, ${text}, ${extra}, ${attachment.length}, ${attachment2.byteLength}`, + { + name: "attachment_and_input_test", + project_name: projectName, + client: langsmithClient, + id: runId, + extractAttachments: ( + val: number, + text: string, + extra: string, + attachment: Uint8Array, + attachment2: ArrayBuffer + ) => [ + { + test1bin: ["application/octet-stream", testAttachment1], 
+ test2bin: ["application/octet-stream", testAttachment2], + inputbin: ["application/octet-stream", attachment], /* attachment names must not contain '.'; the keys must match the assertions below */ + input2bin: [ + "application/octet-stream", + new Uint8Array(attachment2), + ], + }, + { val, text, extra }, + ], + processInputs: (inputs) => { + expect(inputs).not.toHaveProperty("attachment"); + expect(inputs).not.toHaveProperty("attachment2"); + return { + ...inputs, + processed_val: (inputs.val as number) * 2, + processed_text: (inputs.text as string).toUpperCase(), + }; + }, + tracingEnabled: true, + } + ); + + const multipartIngestRunsSpy = jest.spyOn( + langsmithClient, + "multipartIngestRuns" + ); + + await traceableWithAttachmentsAndInputs( + 42, + "test input", + "extra data", + new Uint8Array([9, 10, 11, 12]), + testAttachment3 + ); + + await langsmithClient.awaitPendingTraceBatches(); + + expect(multipartIngestRunsSpy).toHaveBeenCalled(); + const callArgs = multipartIngestRunsSpy.mock.calls[0][0]; + + expect(callArgs.runCreates).toBeDefined(); + expect(callArgs.runCreates?.length).toBe(1); + + const runCreate = callArgs.runCreates?.[0]; + expect(runCreate?.id).toBe(runId); + expect(runCreate?.attachments).toBeDefined(); + expect(runCreate?.attachments?.["test1bin"]).toEqual([ + "application/octet-stream", + testAttachment1, + ]); + expect(runCreate?.attachments?.["test2bin"]).toEqual([ + "application/octet-stream", + testAttachment2, + ]); + expect(runCreate?.attachments?.["inputbin"]).toEqual([ + "application/octet-stream", + new Uint8Array([9, 10, 11, 12]), + ]); + expect(runCreate?.attachments?.["input2bin"]).toEqual([ + "application/octet-stream", + new Uint8Array([13, 14, 15, 16]), + ]); + + await waitUntilRunFound(langsmithClient, runId); + const storedRun = await langsmithClient.readRun(runId); + expect(storedRun.id).toEqual(runId); + expect(storedRun.inputs).toEqual({ + val: 42, + text: "test input", + extra: "extra data", + processed_val: 84, + processed_text: "TEST INPUT", + }); + expect(storedRun.outputs).toEqual({ + 
outputs: "Processed: 42, test input, extra data, 4, 4", + }); + + multipartIngestRunsSpy.mockRestore(); + }, + 60000 +); diff --git a/js/src/traceable.ts b/js/src/traceable.ts index b8d48c663..fa64d17b4 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -7,7 +7,7 @@ import { isRunTree, isRunnableConfigLike, } from "./run_trees.js"; -import { InvocationParamsSchema, KVMap } from "./schemas.js"; +import { Attachments, InvocationParamsSchema, KVMap } from "./schemas.js"; import { isTracingEnabled } from "./env.js"; import { ROOT, @@ -29,10 +29,7 @@ AsyncLocalStorageProviderSingleton.initializeGlobalInstance( new AsyncLocalStorage() ); -const handleRunInputs = ( - rawInputs: unknown[], - processInputs: (inputs: Readonly<KVMap>) => KVMap -): KVMap => { +const runInputsToMap = (rawInputs: unknown[]) => { const firstInput = rawInputs[0]; let inputs: KVMap; @@ -45,7 +42,13 @@ const handleRunInputs = ( } else { inputs = { input: firstInput }; } + return inputs; +}; +const handleRunInputs = ( + inputs: KVMap, + processInputs: (inputs: Readonly<KVMap>) => KVMap +): KVMap => { try { return processInputs(inputs); } catch (e) { @@ -79,6 +82,24 @@ const handleRunOutputs = ( return outputs; } }; +const handleRunAttachments = ( /* splits the raw call args into [attachments, inputs map]; never throws */ + rawInputs: unknown[], + extractAttachments?: ( + ...args: unknown[] + ) => [Attachments | undefined, KVMap] +): [Attachments | undefined, KVMap] => { + if (!extractAttachments) { + return [undefined, runInputsToMap(rawInputs)]; /* no extractor: no attachments, inputs still become a map */ + } + + try { + const [attachments, remainingArgs] = extractAttachments(...rawInputs); + return [attachments, remainingArgs]; + } catch (e) { + console.error("Error occurred during extractAttachments:", e); + return [undefined, runInputsToMap(rawInputs)]; /* best effort: drop attachments but keep inputs as a map, not the raw args array */ + } +}; const getTracingRunTree = ( runTree: RunTree, @@ -86,13 +107,23 @@ const getTracingRunTree = ( getInvocationParams: | ((...args: Args) => InvocationParamsSchema | undefined) | undefined, - processInputs: (inputs: Readonly<KVMap>) => KVMap + processInputs: (inputs: Readonly<KVMap>) => KVMap, + extractAttachments: + 
| ((...args: Args) => [Attachments | undefined, KVMap]) + | undefined ): RunTree | undefined => { if (!isTracingEnabled(runTree.tracingEnabled)) { return undefined; } - runTree.inputs = handleRunInputs(inputs, processInputs); + const [attached, args] = handleRunAttachments( + inputs, + extractAttachments as + | ((...args: unknown[]) => [Attachments | undefined, unknown[]]) + | undefined + ); + runTree.attachments = attached; + runTree.inputs = handleRunInputs(args, processInputs); const invocationParams = getInvocationParams?.(...inputs); if (invocationParams != null) { @@ -309,6 +340,15 @@ export function traceable any>( argsConfigPath?: [number] | [number, string]; __finalTracedIteratorKey?: string; + /** + * Extract attachments from args and return remaining args. + * @param args Arguments of the traced function + * @returns Tuple of [Attachments, remaining args] + */ + extractAttachments?: ( + ...args: Parameters + ) => [Attachments | undefined, KVMap]; + /** * Extract invocation parameters from the arguments of the traced function. * This is useful for LangSmith to properly track common metadata like @@ -349,11 +389,14 @@ export function traceable any>( __finalTracedIteratorKey, processInputs, processOutputs, + extractAttachments, ...runTreeConfig } = config ?? {}; const processInputsFn = processInputs ?? ((x) => x); const processOutputsFn = processOutputs ?? ((x) => x); + const extractAttachmentsFn = + extractAttachments ?? 
((...x) => [undefined, runInputsToMap(x)]); const traceableFunc = ( ...args: Inputs | [RunTree, ...Inputs] | [RunnableConfigLike, ...Inputs] @@ -427,7 +470,8 @@ export function traceable any>( RunTree.fromRunnableConfig(firstArg, ensuredConfig), restArgs as Inputs, config?.getInvocationParams, - processInputsFn + processInputsFn, + extractAttachmentsFn ), restArgs as Inputs, ]; @@ -452,7 +496,8 @@ export function traceable any>( : firstArg.createChild(ensuredConfig), restArgs as Inputs, config?.getInvocationParams, - processInputsFn + processInputsFn, + extractAttachmentsFn ); return [currentRunTree, [currentRunTree, ...restArgs] as Inputs]; @@ -467,7 +512,8 @@ export function traceable any>( prevRunFromStore.createChild(ensuredConfig), processedArgs, config?.getInvocationParams, - processInputsFn + processInputsFn, + extractAttachmentsFn ), processedArgs as Inputs, ]; @@ -477,7 +523,8 @@ export function traceable any>( new RunTree(ensuredConfig), processedArgs, config?.getInvocationParams, - processInputsFn + processInputsFn, + extractAttachmentsFn ); // If a context var is set by LangChain outside of a traceable, // it will be an object with a single property and we should copy diff --git a/python/langsmith/_internal/_operations.py b/python/langsmith/_internal/_operations.py index 1ba99a6db..e1e99d6e2 100644 --- a/python/langsmith/_internal/_operations.py +++ b/python/langsmith/_internal/_operations.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +import logging import uuid from typing import Literal, Optional, Union, cast @@ -10,6 +11,8 @@ from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext from langsmith._internal._serde import dumps_json as _dumps_json +logger = logging.getLogger(__name__) + class SerializedRunOperation: operation: Literal["post", "patch"] @@ -245,6 +248,15 @@ def serialized_run_operation_to_multipart_parts_and_context( ) if op.attachments: for n, (content_type, valb) in 
op.attachments.items(): + if "." in n: + logger.warning( + f"Skipping logging of attachment '{n}' " + f"for run {op.id}:" + " Invalid attachment name. Attachment names must not contain" + " periods ('.'). Please rename the attachment and try again." + ) + continue + acc_parts.append( ( f"attachment.{op.id}.{n}",