From c487e8559b7128a38a92a88aae1b79303ca9fa18 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Thu, 24 Oct 2024 15:04:14 +0200 Subject: [PATCH] Add unit tests with input/output fixtures --- js/src/tests/vercel.test.ts | 898 ++++++++++++++++++++++++++++++++++++ js/src/vercel.ts | 13 +- 2 files changed, 910 insertions(+), 1 deletion(-) create mode 100644 js/src/tests/vercel.test.ts diff --git a/js/src/tests/vercel.test.ts b/js/src/tests/vercel.test.ts new file mode 100644 index 000000000..15ddb2f3f --- /dev/null +++ b/js/src/tests/vercel.test.ts @@ -0,0 +1,898 @@ +import { NodeTracerProvider } from "@opentelemetry/sdk-trace-node"; +import { BatchSpanProcessor } from "@opentelemetry/sdk-trace-base"; + +import { + generateText, + streamText, + generateObject, + streamObject, + tool, + LanguageModelV1StreamPart, +} from "ai"; + +import { z } from "zod"; +import { AISDKExporter } from "../vercel.js"; +import { traceable } from "../traceable.js"; +import { toArray } from "./utils.js"; +import { mockClient } from "./utils/mock_client.js"; +import { convertArrayToReadableStream, MockLanguageModelV1 } from "ai/test"; +import { getAssumedTreeFromCalls } from "./utils/tree.js"; + +const { client, callSpy } = mockClient(); +const provider = new NodeTracerProvider(); +provider.addSpanProcessor( + new BatchSpanProcessor(new AISDKExporter({ client })) +); +provider.register(); + +class ExecutionOrderSame { + $$typeof = Symbol.for("jest.asymmetricMatcher"); + + private expectedNs: string; + private expectedDepth: number; + + constructor(depth: number, ns: string) { + this.expectedDepth = depth; + this.expectedNs = ns; + } + + asymmetricMatch(other: unknown) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (!(typeof other === "string" || other instanceof String)) { + return false; + } + + const segments = other.split("."); + if (segments.length !== this.expectedDepth) return false; + + const last = segments.at(-1); + if (!last) return false; + + const nanoseconds = 
last.split("Z").at(0)?.slice(-3); + return nanoseconds === this.expectedNs; + } + + toString() { + return "ExecutionOrderSame"; + } + + getExpectedType() { + return "string"; + } + + toAsymmetricMatcher() { + return `ExecutionOrderSame<${this.expectedDepth}, ${this.expectedNs}>`; + } +} + +class MockMultiStepLanguageModelV1 extends MockLanguageModelV1 { + generateStep = -1; + streamStep = -1; + + constructor(...args: ConstructorParameters<typeof MockLanguageModelV1>) { + super(...args); + + const oldDoGenerate = this.doGenerate; + this.doGenerate = async (...args) => { + this.generateStep += 1; + return await oldDoGenerate(...args); + }; + + const oldDoStream = this.doStream; + this.doStream = async (...args) => { + this.streamStep += 1; + return await oldDoStream(...args); + }; + } +} + +beforeEach(() => callSpy.mockClear()); +afterAll(async () => await provider.shutdown()); + +test("generateText", async () => { + const model = new MockMultiStepLanguageModelV1({ + doGenerate: async () => { + if (model.generateStep === 0) { + return { + rawCall: { rawPrompt: null, rawSettings: {} }, + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + toolCalls: [ + { + toolCallType: "function", + toolName: "listOrders", + toolCallId: "tool-id", + args: JSON.stringify({ userId: "123" }), + }, + ], + }; + } + + return { + rawCall: { rawPrompt: null, rawSettings: {} }, + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + text: `Hello, world!`, + }; + }, + }); + + await generateText({ + model, + messages: [ + { + role: "user", + content: "What are my orders? 
My user ID is 123", + }, + ], + tools: { + listOrders: tool({ + description: "list all orders", + parameters: z.object({ userId: z.string() }), + execute: async ({ userId }) => + `User ${userId} has the following orders: 1`, + }), + }, + experimental_telemetry: AISDKExporter.getSettings({ + functionId: "functionId", + metadata: { userId: "123", language: "english" }, + }), + maxSteps: 10, + }); + + await provider.forceFlush(); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [ + "mock-provider:0", + "mock-provider:1", + "listOrders:2", + "mock-provider:3", + ], + edges: [ + ["mock-provider:0", "mock-provider:1"], + ["mock-provider:0", "listOrders:2"], + ["mock-provider:0", "mock-provider:3"], + ], + data: { + "mock-provider:0": { + inputs: { + messages: [ + { + type: "human", + data: { content: "What are my orders? My user ID is 123" }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { content: "Hello, world!" }, + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(1, "000"), + }, + "mock-provider:1": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { + type: "text", + text: "What are my orders? 
My user ID is 123", + }, + ], + }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { + content: [ + { + type: "tool_use", + name: "listOrders", + id: "tool-id", + input: { userId: "123" }, + }, + ], + additional_kwargs: { + tool_calls: [ + { + id: "tool-id", + type: "function", + function: { + name: "listOrders", + id: "tool-id", + arguments: '{"userId":"123"}', + }, + }, + ], + }, + }, + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "000"), + }, + "listOrders:2": { + inputs: { userId: "123" }, + outputs: { output: "User 123 has the following orders: 1" }, + dotted_order: new ExecutionOrderSame(2, "001"), + }, + "mock-provider:3": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { + type: "text", + text: "What are my orders? My user ID is 123", + }, + ], + }, + }, + { + type: "ai", + data: { + content: [ + { + type: "tool_use", + name: "listOrders", + id: "tool-id", + input: { userId: "123" }, + }, + ], + additional_kwargs: { + tool_calls: [ + { + id: "tool-id", + type: "function", + function: { + name: "listOrders", + id: "tool-id", + arguments: '{"userId":"123"}', + }, + }, + ], + }, + }, + }, + { + type: "tool", + data: { + content: '"User 123 has the following orders: 1"', + name: "listOrders", + tool_call_id: "tool-id", + }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { content: "Hello, world!" 
}, + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "002"), + }, + }, + }); +}); + +test("streamText", async () => { + const model = new MockMultiStepLanguageModelV1({ + doStream: async () => { + if (model.streamStep === 0) { + return { + stream: convertArrayToReadableStream([ + { + type: "tool-call", + toolCallType: "function", + toolName: "listOrders", + toolCallId: "tool-id", + args: JSON.stringify({ userId: "123" }), + }, + { + type: "finish", + finishReason: "stop", + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ] satisfies LanguageModelV1StreamPart[]), + rawCall: { rawPrompt: null, rawSettings: {} }, + }; + } + + return { + stream: convertArrayToReadableStream([ + { type: "text-delta", textDelta: "Hello" }, + { type: "text-delta", textDelta: ", " }, + { type: "text-delta", textDelta: `world!` }, + { + type: "finish", + finishReason: "stop", + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ]), + rawCall: { rawPrompt: null, rawSettings: {} }, + }; + }, + }); + + const result = await streamText({ + model, + messages: [ + { + role: "user", + content: "What are my orders? 
My user ID is 123", + }, + ], + tools: { + listOrders: tool({ + description: "list all orders", + parameters: z.object({ userId: z.string() }), + execute: async ({ userId }) => + `User ${userId} has the following orders: 1`, + }), + }, + experimental_telemetry: AISDKExporter.getSettings({ + functionId: "functionId", + metadata: { userId: "123", language: "english" }, + }), + maxSteps: 10, + }); + + await toArray(result.fullStream); + await provider.forceFlush(); + + const actual = getAssumedTreeFromCalls(callSpy.mock.calls); + expect(actual).toMatchObject({ + nodes: [ + "mock-provider:0", + "mock-provider:1", + "listOrders:2", + "mock-provider:3", + ], + edges: [ + ["mock-provider:0", "mock-provider:1"], + ["mock-provider:0", "listOrders:2"], + ["mock-provider:0", "mock-provider:3"], + ], + data: { + "mock-provider:0": { + inputs: { + messages: [ + { + type: "human", + data: { content: "What are my orders? My user ID is 123" }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { content: "Hello, world!" }, + token_usage: { completion_tokens: 20, prompt_tokens: 6 }, + }, + }, + dotted_order: new ExecutionOrderSame(1, "000"), + }, + "mock-provider:1": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { + type: "text", + text: "What are my orders? 
My user ID is 123", + }, + ], + }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { + content: [ + { + type: "tool_use", + name: "listOrders", + id: "tool-id", + input: { userId: "123" }, + }, + ], + additional_kwargs: { + tool_calls: [ + { + id: "tool-id", + type: "function", + function: { + name: "listOrders", + id: "tool-id", + arguments: '{"userId":"123"}', + }, + }, + ], + }, + }, + token_usage: { completion_tokens: 10, prompt_tokens: 3 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "000"), + }, + "listOrders:2": { + inputs: { userId: "123" }, + outputs: { output: "User 123 has the following orders: 1" }, + dotted_order: new ExecutionOrderSame(2, "001"), + }, + "mock-provider:3": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { + type: "text", + text: "What are my orders? My user ID is 123", + }, + ], + }, + }, + { + type: "ai", + data: { + content: [ + { + type: "tool_use", + name: "listOrders", + id: "tool-id", + input: { userId: "123" }, + }, + ], + additional_kwargs: { + tool_calls: [ + { + id: "tool-id", + type: "function", + function: { + name: "listOrders", + id: "tool-id", + arguments: '{"userId":"123"}', + }, + }, + ], + }, + }, + }, + { + type: "tool", + data: { + content: '"User 123 has the following orders: 1"', + name: "listOrders", + tool_call_id: "tool-id", + }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { content: "Hello, world!" 
}, + token_usage: { completion_tokens: 10, prompt_tokens: 3 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "002"), + }, + }, + }); +}); + +test("generateObject", async () => { + const model = new MockMultiStepLanguageModelV1({ + doGenerate: async () => ({ + rawCall: { rawPrompt: null, rawSettings: {} }, + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + toolCalls: [ + { + toolCallType: "function", + toolName: "json", + toolCallId: "tool-id", + args: JSON.stringify({ + weather: { city: "Prague", unit: "celsius" }, + }), + }, + ], + }), + defaultObjectGenerationMode: "tool", + }); + + await generateObject({ + model, + schema: z.object({ + weather: z.object({ + city: z.string(), + unit: z.union([z.literal("celsius"), z.literal("fahrenheit")]), + }), + }), + prompt: "What's the weather in Prague?", + experimental_telemetry: AISDKExporter.getSettings({ + functionId: "functionId", + metadata: { userId: "123", language: "english" }, + }), + }); + + await provider.forceFlush(); + const actual = getAssumedTreeFromCalls(callSpy.mock.calls); + + expect(actual).toMatchObject({ + nodes: ["mock-provider:0", "mock-provider:1"], + edges: [["mock-provider:0", "mock-provider:1"]], + data: { + "mock-provider:0": { + inputs: { + input: { prompt: "What's the weather in Prague?" }, + }, + outputs: { + output: { weather: { city: "Prague", unit: "celsius" } }, + llm_output: { + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(1, "000"), + }, + "mock-provider:1": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { type: "text", text: "What's the weather in Prague?" 
}, + ], + }, + }, + ], + }, + outputs: { + output: { weather: { city: "Prague", unit: "celsius" } }, + llm_output: { + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "000"), + }, + }, + }); +}); + +test("streamObject", async () => { + const model = new MockMultiStepLanguageModelV1({ + doGenerate: async () => ({ + rawCall: { rawPrompt: null, rawSettings: {} }, + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + toolCalls: [ + { + toolCallType: "function", + toolName: "json", + toolCallId: "tool-id", + args: JSON.stringify({ + weather: { city: "Prague", unit: "celsius" }, + }), + }, + ], + }), + + doStream: async () => { + return { + stream: convertArrayToReadableStream([ + { + type: "tool-call-delta", + toolCallType: "function", + toolName: "json", + toolCallId: "tool-id", + argsTextDelta: JSON.stringify({ + weather: { city: "Prague", unit: "celsius" }, + }), + }, + { + type: "finish", + finishReason: "stop", + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ] satisfies LanguageModelV1StreamPart[]), + rawCall: { rawPrompt: null, rawSettings: {} }, + }; + }, + defaultObjectGenerationMode: "tool", + }); + + const result = await streamObject({ + model, + schema: z.object({ + weather: z.object({ + city: z.string(), + unit: z.union([z.literal("celsius"), z.literal("fahrenheit")]), + }), + }), + prompt: "What's the weather in Prague?", + experimental_telemetry: AISDKExporter.getSettings({ + functionId: "functionId", + metadata: { userId: "123", language: "english" }, + }), + }); + + await toArray(result.partialObjectStream); + await provider.forceFlush(); + + const actual = getAssumedTreeFromCalls(callSpy.mock.calls); + expect(actual).toMatchObject({ + nodes: ["mock-provider:0", "mock-provider:1"], + edges: [["mock-provider:0", "mock-provider:1"]], + data: { + "mock-provider:0": { + inputs: { + input: { prompt: "What's the weather in Prague?" 
}, + }, + outputs: { + output: { weather: { city: "Prague", unit: "celsius" } }, + llm_output: { + token_usage: { completion_tokens: 10, prompt_tokens: 3 }, + }, + }, + dotted_order: new ExecutionOrderSame(1, "000"), + }, + "mock-provider:1": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { type: "text", text: "What's the weather in Prague?" }, + ], + }, + }, + ], + }, + outputs: { + output: { weather: { city: "Prague", unit: "celsius" } }, + llm_output: { + token_usage: { completion_tokens: 10, prompt_tokens: 3 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "000"), + }, + }, + }); +}); + +test("traceable", async () => { + const model = new MockMultiStepLanguageModelV1({ + doGenerate: async () => { + if (model.generateStep === 0) { + return { + rawCall: { rawPrompt: null, rawSettings: {} }, + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + toolCalls: [ + { + toolCallType: "function", + toolName: "listOrders", + toolCallId: "tool-id", + args: JSON.stringify({ userId: "123" }), + }, + ], + }; + } + + return { + rawCall: { rawPrompt: null, rawSettings: {} }, + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + text: `Hello, world!`, + }; + }, + }); + + const wrappedText = traceable( + async (content: string) => { + const { text } = await generateText({ + model, + messages: [{ role: "user", content }], + tools: { + listOrders: tool({ + description: "list all orders", + parameters: z.object({ userId: z.string() }), + execute: async ({ userId }) => + `User ${userId} has the following orders: 1`, + }), + }, + experimental_telemetry: AISDKExporter.getSettings({ + functionId: "functionId", + metadata: { userId: "123", language: "english" }, + }), + maxSteps: 10, + }); + + return { text }; + }, + { name: "wrappedText", client, tracingEnabled: true } + ); + + await wrappedText("What are my orders? 
My user ID is 123"); + await provider.forceFlush(); + + const actual = getAssumedTreeFromCalls(callSpy.mock.calls); + expect(actual).toMatchObject({ + nodes: [ + "wrappedText:0", + "mock-provider:1", + "mock-provider:2", + "listOrders:3", + "mock-provider:4", + ], + edges: [ + ["wrappedText:0", "mock-provider:1"], + ["mock-provider:1", "mock-provider:2"], + ["mock-provider:1", "listOrders:3"], + ["mock-provider:1", "mock-provider:4"], + ], + data: { + "wrappedText:0": { + inputs: { + input: "What are my orders? My user ID is 123", + }, + outputs: { + text: "Hello, world!", + }, + dotted_order: new ExecutionOrderSame(1, "001"), + }, + "mock-provider:1": { + inputs: { + messages: [ + { + type: "human", + data: { content: "What are my orders? My user ID is 123" }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { content: "Hello, world!" }, + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(2, "000"), + }, + "mock-provider:2": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { + type: "text", + text: "What are my orders? My user ID is 123", + }, + ], + }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { + content: [ + { + type: "tool_use", + name: "listOrders", + id: "tool-id", + input: { userId: "123" }, + }, + ], + additional_kwargs: { + tool_calls: [ + { + id: "tool-id", + type: "function", + function: { + name: "listOrders", + id: "tool-id", + arguments: '{"userId":"123"}', + }, + }, + ], + }, + }, + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(3, "000"), + }, + "listOrders:3": { + inputs: { userId: "123" }, + outputs: { output: "User 123 has the following orders: 1" }, + dotted_order: new ExecutionOrderSame(3, "001"), + }, + "mock-provider:4": { + inputs: { + messages: [ + { + type: "human", + data: { + content: [ + { + type: "text", + text: "What are my orders? 
My user ID is 123", + }, + ], + }, + }, + { + type: "ai", + data: { + content: [ + { + type: "tool_use", + name: "listOrders", + id: "tool-id", + input: { userId: "123" }, + }, + ], + additional_kwargs: { + tool_calls: [ + { + id: "tool-id", + type: "function", + function: { + name: "listOrders", + id: "tool-id", + arguments: '{"userId":"123"}', + }, + }, + ], + }, + }, + }, + { + type: "tool", + data: { + content: '"User 123 has the following orders: 1"', + name: "listOrders", + tool_call_id: "tool-id", + }, + }, + ], + }, + outputs: { + llm_output: { + type: "ai", + data: { content: "Hello, world!" }, + token_usage: { completion_tokens: 20, prompt_tokens: 10 }, + }, + }, + dotted_order: new ExecutionOrderSame(3, "002"), + }, + }, + }); +}); diff --git a/js/src/vercel.ts b/js/src/vercel.ts index 9ab1e127f..fe1250381 100644 --- a/js/src/vercel.ts +++ b/js/src/vercel.ts @@ -203,6 +203,14 @@ function convertToTimestamp([seconds, nanoseconds]: [ return Number(String(seconds) + ms); } +function sortByHr( + a: [seconds: number, nanoseconds: number], + b: [seconds: number, nanoseconds: number] +): number { + if (a[0] !== b[0]) return Math.sign(a[0] - b[0]); + return Math.sign(a[1] - b[1]); +} + const ROOT = "$"; const RUN_ID_NAMESPACE = "5c718b20-9078-11ef-9a3d-325096b39f47"; @@ -611,7 +619,10 @@ export class AISDKExporter { spans: unknown[], resultCallback: (result: { code: 0 | 1; error?: Error }) => void ): void { - const typedSpans = spans as AISDKSpan[]; + const typedSpans = (spans as AISDKSpan[]) + .slice() + .sort((a, b) => sortByHr(a.startTime, b.startTime)); + for (const span of typedSpans) { const { traceId, spanId } = span.spanContext(); const parentId = span.parentSpanId ?? undefined;