From 89ce6dbbb2889af66ca53dd546c5977953dea972 Mon Sep 17 00:00:00 2001
From: Ryan Lamb <4955475+kinyoklion@users.noreply.github.com>
Date: Fri, 8 Nov 2024 08:44:39 -0800
Subject: [PATCH] fix: Do not include _ldMeta in returned config. (#668)

The primary purpose of this was to add unit tests, but it ended up with a bug fix, so that is the PR title.
---
 .../__tests__/LDAIClientImpl.test.ts         | 136 +++++++++
 .../__tests__/LDAIConfigTrackerImpl.test.ts  | 270 ++++++++++++++++++
 .../server-ai/__tests__/TokenUsage.test.ts   |  78 +++++
 packages/sdk/server-ai/jest.config.js        |   7 +
 packages/sdk/server-ai/package.json          |   2 +-
 packages/sdk/server-ai/src/LDAIClientImpl.ts |   5 +-
 6 files changed, 496 insertions(+), 2 deletions(-)
 create mode 100644 packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
 create mode 100644 packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts
 create mode 100644 packages/sdk/server-ai/__tests__/TokenUsage.test.ts
 create mode 100644 packages/sdk/server-ai/jest.config.js

diff --git a/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts b/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
new file mode 100644
index 000000000..a396c2c1a
--- /dev/null
+++ b/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
@@ -0,0 +1,136 @@
+import { LDContext } from '@launchdarkly/js-server-sdk-common';
+
+import { LDGenerationConfig } from '../src/api/config';
+import { LDAIClientImpl } from '../src/LDAIClientImpl';
+import { LDClientMin } from '../src/LDClientMin';
+
+const mockLdClient: jest.Mocked<LDClientMin> = {
+  variation: jest.fn(),
+  track: jest.fn(),
+};
+
+const testContext: LDContext = { kind: 'user', key: 'test-user' };
+
+it('interpolates template variables', () => {
+  const client = new LDAIClientImpl(mockLdClient);
+  const template = 'Hello {{name}}, your score is {{score}}';
+  const variables = { name: 'John', score: 42 };
+
+  const result = client.interpolateTemplate(template, variables);
+  expect(result).toBe('Hello John, your score is 42');
+});
+
+it('handles empty variables in template interpolation', () => {
+  const client = new LDAIClientImpl(mockLdClient);
+  const template = 'Hello {{name}}';
+  const variables = {};
+
+  const result = client.interpolateTemplate(template, variables);
+  expect(result).toBe('Hello ');
+});
+
+it('returns model config with interpolated prompts', async () => {
+  const client = new LDAIClientImpl(mockLdClient);
+  const key = 'test-flag';
+  const defaultValue: LDGenerationConfig = {
+    model: { modelId: 'test', name: 'test-model' },
+    prompt: [],
+  };
+
+  const mockVariation = {
+    model: { modelId: 'example-provider', name: 'imagination' },
+    prompt: [
+      { role: 'system', content: 'Hello {{name}}' },
+      { role: 'user', content: 'Score: {{score}}' },
+    ],
+    _ldMeta: {
+      versionKey: 'v1',
+      enabled: true,
+    },
+  };
+
+  mockLdClient.variation.mockResolvedValue(mockVariation);
+
+  const variables = { name: 'John', score: 42 };
+  const result = await client.modelConfig(key, testContext, defaultValue, variables);
+
+  expect(result).toEqual({
+    config: {
+      model: { modelId: 'example-provider', name: 'imagination' },
+      prompt: [
+        { role: 'system', content: 'Hello John' },
+        { role: 'user', content: 'Score: 42' },
+      ],
+    },
+    tracker: expect.any(Object),
+    enabled: true,
+  });
+});
+
+it('includes context in variables for prompt interpolation', async () => {
+  const client = new LDAIClientImpl(mockLdClient);
+  const key = 'test-flag';
+  const defaultValue: LDGenerationConfig = {
+    model: { modelId: 'test', name: 'test-model' },
+    prompt: [],
+  };
+
+  const mockVariation = {
+    prompt: [{ role: 'system', content: 'User key: {{ldctx.key}}' }],
+    _ldMeta: { versionKey: 'v1', enabled: true },
+  };
+
+  mockLdClient.variation.mockResolvedValue(mockVariation);
+
+  const result = await client.modelConfig(key, testContext, defaultValue);
+
+  expect(result.config.prompt?.[0].content).toBe('User key: test-user');
+});
+
+it('handles missing metadata in variation', async () => {
+  const client = new LDAIClientImpl(mockLdClient);
+  const key = 'test-flag';
+  const defaultValue: LDGenerationConfig = {
+    model: { modelId: 'test', name: 'test-model' },
+    prompt: [],
+  };
+
+  const mockVariation = {
+    model: { modelId: 'example-provider', name: 'imagination' },
+    prompt: [{ role: 'system', content: 'Hello' }],
+  };
+
+  mockLdClient.variation.mockResolvedValue(mockVariation);
+
+  const result = await client.modelConfig(key, testContext, defaultValue);
+
+  expect(result).toEqual({
+    config: {
+      model: { modelId: 'example-provider', name: 'imagination' },
+      prompt: [{ role: 'system', content: 'Hello' }],
+    },
+    tracker: expect.any(Object),
+    enabled: false,
+  });
+});
+
+it('passes the default value to the underlying client', async () => {
+  const client = new LDAIClientImpl(mockLdClient);
+  const key = 'non-existent-flag';
+  const defaultValue: LDGenerationConfig = {
+    model: { modelId: 'default-model', name: 'default' },
+    prompt: [{ role: 'system', content: 'Default prompt' }],
+  };
+
+  mockLdClient.variation.mockResolvedValue(defaultValue);
+
+  const result = await client.modelConfig(key, testContext, defaultValue);
+
+  expect(result).toEqual({
+    config: defaultValue,
+    tracker: expect.any(Object),
+    enabled: false,
+  });
+
+  expect(mockLdClient.variation).toHaveBeenCalledWith(key, testContext, defaultValue);
+});
diff --git a/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts b/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts
new file mode 100644
index 000000000..6572577c6
--- /dev/null
+++ b/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts
@@ -0,0 +1,270 @@
+import { LDContext } from '@launchdarkly/js-server-sdk-common';
+
+import { LDFeedbackKind } from '../src/api/metrics';
+import { LDAIConfigTrackerImpl } from '../src/LDAIConfigTrackerImpl';
+import { LDClientMin } from '../src/LDClientMin';
+
+const mockTrack = jest.fn();
+const mockVariation = jest.fn();
+const mockLdClient: LDClientMin = {
+  track: mockTrack,
+  variation: mockVariation,
+};
+
+const testContext: LDContext = { kind: 'user', key: 'test-user' };
+const configKey = 'test-config';
+const versionKey = 'v1';
+
+beforeEach(() => {
+  jest.clearAllMocks();
+});
+
+it('tracks duration', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+  tracker.trackDuration(1000);
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:duration:total',
+    testContext,
+    { configKey, versionKey },
+    1000,
+  );
+});
+
+it('tracks duration of async function', async () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+  jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000);
+
+  const result = await tracker.trackDurationOf(async () => 'test-result');
+
+  expect(result).toBe('test-result');
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:duration:total',
+    testContext,
+    { configKey, versionKey },
+    1000,
+  );
+});
+
+it('tracks positive feedback', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+  tracker.trackFeedback({ kind: LDFeedbackKind.Positive });
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:feedback:user:positive',
+    testContext,
+    { configKey, versionKey },
+    1,
+  );
+});
+
+it('tracks negative feedback', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+  tracker.trackFeedback({ kind: LDFeedbackKind.Negative });
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:feedback:user:negative',
+    testContext,
+    { configKey, versionKey },
+    1,
+  );
+});
+
+it('tracks success', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+  tracker.trackSuccess();
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:generation',
+    testContext,
+    { configKey, versionKey },
+    1,
+  );
+});
+
+it('tracks OpenAI usage', async () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+  jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000);
+
+  const TOTAL_TOKENS = 100;
+  const PROMPT_TOKENS = 49;
+  const COMPLETION_TOKENS = 51;
+
+  await tracker.trackOpenAI(async () => ({
+    usage: {
+      total_tokens: TOTAL_TOKENS,
+      prompt_tokens: PROMPT_TOKENS,
+      completion_tokens: COMPLETION_TOKENS,
+    },
+  }));
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:duration:total',
+    testContext,
+    { configKey, versionKey },
+    1000,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:generation',
+    testContext,
+    { configKey, versionKey },
+    1,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:total',
+    testContext,
+    { configKey, versionKey },
+    TOTAL_TOKENS,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:input',
+    testContext,
+    { configKey, versionKey },
+    PROMPT_TOKENS,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:output',
+    testContext,
+    { configKey, versionKey },
+    COMPLETION_TOKENS,
+  );
+});
+
+it('tracks Bedrock conversation with successful response', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+
+  const TOTAL_TOKENS = 100;
+  const PROMPT_TOKENS = 49;
+  const COMPLETION_TOKENS = 51;
+
+  const response = {
+    $metadata: { httpStatusCode: 200 },
+    metrics: { latencyMs: 500 },
+    usage: {
+      inputTokens: PROMPT_TOKENS,
+      outputTokens: COMPLETION_TOKENS,
+      totalTokens: TOTAL_TOKENS,
+    },
+  };
+
+  tracker.trackBedrockConverse(response);
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:generation',
+    testContext,
+    { configKey, versionKey },
+    1,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:duration:total',
+    testContext,
+    { configKey, versionKey },
+    500,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:total',
+    testContext,
+    { configKey, versionKey },
+    TOTAL_TOKENS,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:input',
+    testContext,
+    { configKey, versionKey },
+    PROMPT_TOKENS,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:output',
+    testContext,
+    { configKey, versionKey },
+    COMPLETION_TOKENS,
+  );
+});
+
+it('tracks Bedrock conversation with error response', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+
+  const response = {
+    $metadata: { httpStatusCode: 400 },
+  };
+
+  // TODO: We may want a track failure.
+
+  tracker.trackBedrockConverse(response);
+
+  expect(mockTrack).not.toHaveBeenCalled();
+});
+
+it('tracks tokens', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+
+  const TOTAL_TOKENS = 100;
+  const PROMPT_TOKENS = 49;
+  const COMPLETION_TOKENS = 51;
+
+  tracker.trackTokens({
+    total: TOTAL_TOKENS,
+    input: PROMPT_TOKENS,
+    output: COMPLETION_TOKENS,
+  });
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:total',
+    testContext,
+    { configKey, versionKey },
+    TOTAL_TOKENS,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:input',
+    testContext,
+    { configKey, versionKey },
+    PROMPT_TOKENS,
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:output',
+    testContext,
+    { configKey, versionKey },
+    COMPLETION_TOKENS,
+  );
+});
+
+it('only tracks non-zero token counts', () => {
+  const tracker = new LDAIConfigTrackerImpl(mockLdClient, configKey, versionKey, testContext);
+
+  tracker.trackTokens({
+    total: 0,
+    input: 50,
+    output: 0,
+  });
+
+  expect(mockTrack).not.toHaveBeenCalledWith(
+    '$ld:ai:tokens:total',
+    expect.anything(),
+    expect.anything(),
+    expect.anything(),
+  );
+
+  expect(mockTrack).toHaveBeenCalledWith(
+    '$ld:ai:tokens:input',
+    testContext,
+    { configKey, versionKey },
+    50,
+  );
+
+  expect(mockTrack).not.toHaveBeenCalledWith(
+    '$ld:ai:tokens:output',
+    expect.anything(),
+    expect.anything(),
+    expect.anything(),
+  );
+});
diff --git a/packages/sdk/server-ai/__tests__/TokenUsage.test.ts b/packages/sdk/server-ai/__tests__/TokenUsage.test.ts
new file mode 100644
index 000000000..3dbc8bf6b
--- /dev/null
+++ b/packages/sdk/server-ai/__tests__/TokenUsage.test.ts
@@ -0,0 +1,78 @@
+import { createBedrockTokenUsage } from '../src/api/metrics/BedrockTokenUsage';
+import { createOpenAiUsage } from '../src/api/metrics/OpenAiUsage';
+
+it('createBedrockTokenUsage should create token usage with all values provided', () => {
+  const usage = createBedrockTokenUsage({
+    totalTokens: 100,
+    inputTokens: 40,
+    outputTokens: 60,
+  });
+
+  expect(usage).toEqual({
+    total: 100,
+    input: 40,
+    output: 60,
+  });
+});
+
+it('createBedrockTokenUsage should default to 0 for missing values', () => {
+  const usage = createBedrockTokenUsage({});
+
+  expect(usage).toEqual({
+    total: 0,
+    input: 0,
+    output: 0,
+  });
+});
+
+it('createBedrockTokenUsage should handle explicitly undefined values', () => {
+  const usage = createBedrockTokenUsage({
+    totalTokens: undefined,
+    inputTokens: 40,
+    outputTokens: undefined,
+  });
+
+  expect(usage).toEqual({
+    total: 0,
+    input: 40,
+    output: 0,
+  });
+});
+
+it('createOpenAiUsage should create token usage with all values provided', () => {
+  const usage = createOpenAiUsage({
+    total_tokens: 100,
+    prompt_tokens: 40,
+    completion_tokens: 60,
+  });
+
+  expect(usage).toEqual({
+    total: 100,
+    input: 40,
+    output: 60,
+  });
+});
+
+it('createOpenAiUsage should default to 0 for missing values', () => {
+  const usage = createOpenAiUsage({});
+
+  expect(usage).toEqual({
+    total: 0,
+    input: 0,
+    output: 0,
+  });
+});
+
+it('createOpenAiUsage should handle explicitly undefined values', () => {
+  const usage = createOpenAiUsage({
+    total_tokens: undefined,
+    prompt_tokens: 40,
+    completion_tokens: undefined,
+  });
+
+  expect(usage).toEqual({
+    total: 0,
+    input: 40,
+    output: 0,
+  });
+});
diff --git a/packages/sdk/server-ai/jest.config.js b/packages/sdk/server-ai/jest.config.js
new file mode 100644
index 000000000..f106eb3bc
--- /dev/null
+++ b/packages/sdk/server-ai/jest.config.js
@@ -0,0 +1,7 @@
+module.exports = {
+  transform: { '^.+\\.ts?$': 'ts-jest' },
+  testMatch: ['**/__tests__/**/*test.ts?(x)'],
+  testEnvironment: 'node',
+  moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
+  collectCoverageFrom: ['src/**/*.ts'],
+};
diff --git a/packages/sdk/server-ai/package.json b/packages/sdk/server-ai/package.json
index 6e4b431d1..b8bfa5d54 100644
--- a/packages/sdk/server-ai/package.json
+++ b/packages/sdk/server-ai/package.json
@@ -16,7 +16,7 @@
     "prettier": "prettier --write '**/*.@(js|ts|tsx|json|css)' --ignore-path ../../../.prettierignore",
     "lint:fix": "yarn run lint --fix",
     "check": "yarn prettier && yarn lint && yarn build && yarn test",
-    "test": "echo No tests added yet."
+    "test": "jest"
   },
   "keywords": [
     "launchdarkly",
diff --git a/packages/sdk/server-ai/src/LDAIClientImpl.ts b/packages/sdk/server-ai/src/LDAIClientImpl.ts
index dd0dead9d..dc87d270e 100644
--- a/packages/sdk/server-ai/src/LDAIClientImpl.ts
+++ b/packages/sdk/server-ai/src/LDAIClientImpl.ts
@@ -41,7 +41,10 @@ export class LDAIClientImpl implements LDAIClient {
     const value: VariationContent = await this._ldClient.variation(key, context, defaultValue);
     // We are going to modify the contents before returning them, so we make a copy.
     // This isn't a deep copy and the application developer should not modify the returned content.
-    const config: LDGenerationConfig = { ...value };
+    const config: LDGenerationConfig = {};
+    if (value.model) {
+      config.model = { ...value.model };
+    }
     const allVariables = { ...variables, ldctx: context };
 
     if (value.prompt) {
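
The hunk in LDAIClientImpl.ts above is the actual bug fix: spreading the variation value (`{ ...value }`) copied every field of the flag payload into the returned config, including the internal `_ldMeta` bookkeeping, whereas the new code builds the config from the known fields only. The sketch below restates that copy behavior in isolation so the difference is easy to see. The names (`GenerationConfig`, `VariationValue`, `copyConfigBefore`, `copyConfigAfter`) are illustrative stand-ins, not the SDK's real types, and the real method also interpolates prompt content before returning it, which is omitted here.

// Illustrative stand-in types; the SDK's real shapes are VariationContent / LDGenerationConfig.
interface LDMessage {
  role: string;
  content: string;
}

interface GenerationConfig {
  model?: { modelId: string; name: string };
  prompt?: LDMessage[];
}

type VariationValue = GenerationConfig & {
  _ldMeta?: { versionKey: string; enabled: boolean };
};

// Before the fix: spreading the variation copies every field, so _ldMeta leaks to the caller.
function copyConfigBefore(value: VariationValue): GenerationConfig {
  return { ...value };
}

// After the fix: copy only the fields that belong to the generation config,
// leaving _ldMeta as internal metadata (prompt interpolation omitted in this sketch).
function copyConfigAfter(value: VariationValue): GenerationConfig {
  const config: GenerationConfig = {};
  if (value.model) {
    config.model = { ...value.model };
  }
  if (value.prompt) {
    config.prompt = value.prompt.map((message) => ({ ...message }));
  }
  return config;
}

This is consistent with the new tests above, which expect the returned `config` to contain only `model` and `prompt`, with `enabled` and the tracker derived from `_ldMeta` separately.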