Skip to content

Commit

Permalink
feat: Change the typing for the LDAIConfig. (#688)
Browse files Browse the repository at this point in the history
Flatten the AI config moving model/prompt to the top level.
Change the LDAIDefaults to match this change.
Remove the generation config type which was the type of the `config`
field.

Add temperature and maxTokens as optionals to the model config.
Make `modelId` required.

Updated the examples to include temperature and maxTokens.

SDK-912
  • Loading branch information
kinyoklion authored Nov 15, 2024
1 parent 4cf34f9 commit 1f3f54a
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 66 deletions.
37 changes: 18 additions & 19 deletions packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDGenerationConfig } from '../src/api/config';
import { LDAIDefaults } from '../src/api/config';
import { LDAIClientImpl } from '../src/LDAIClientImpl';
import { LDClientMin } from '../src/LDClientMin';

Expand Down Expand Up @@ -32,13 +32,14 @@ it('handles empty variables in template interpolation', () => {
it('returns model config with interpolated prompts', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
enabled: true,
};

const mockVariation = {
model: { modelId: 'example-provider', name: 'imagination' },
model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
prompt: [
{ role: 'system', content: 'Hello {{name}}' },
{ role: 'user', content: 'Score: {{score}}' },
Expand All @@ -55,13 +56,11 @@ it('returns model config with interpolated prompts', async () => {
const result = await client.modelConfig(key, testContext, defaultValue, variables);

expect(result).toEqual({
config: {
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [
{ role: 'system', content: 'Hello John' },
{ role: 'user', content: 'Score: 42' },
],
},
model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
prompt: [
{ role: 'system', content: 'Hello John' },
{ role: 'user', content: 'Score: 42' },
],
tracker: expect.any(Object),
enabled: true,
});
Expand All @@ -70,7 +69,7 @@ it('returns model config with interpolated prompts', async () => {
it('includes context in variables for prompt interpolation', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
};
Expand All @@ -84,13 +83,13 @@ it('includes context in variables for prompt interpolation', async () => {

const result = await client.modelConfig(key, testContext, defaultValue);

expect(result.config.prompt?.[0].content).toBe('User key: test-user');
expect(result.prompt?.[0].content).toBe('User key: test-user');
});

it('handles missing metadata in variation', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
};
Expand All @@ -105,10 +104,8 @@ it('handles missing metadata in variation', async () => {
const result = await client.modelConfig(key, testContext, defaultValue);

expect(result).toEqual({
config: {
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [{ role: 'system', content: 'Hello' }],
},
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [{ role: 'system', content: 'Hello' }],
tracker: expect.any(Object),
enabled: false,
});
Expand All @@ -117,17 +114,19 @@ it('handles missing metadata in variation', async () => {
it('passes the default value to the underlying client', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'non-existent-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'default-model', name: 'default' },
prompt: [{ role: 'system', content: 'Default prompt' }],
enabled: true,
};

mockLdClient.variation.mockResolvedValue(defaultValue);

const result = await client.modelConfig(key, testContext, defaultValue);

expect(result).toEqual({
config: defaultValue,
model: defaultValue.model,
prompt: defaultValue.prompt,
tracker: expect.any(Object),
enabled: false,
});
Expand Down
8 changes: 6 additions & 2 deletions packages/sdk/server-ai/examples/bedrock/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ async function main() {
const completion = tracker.trackBedrockConverse(
await awsClient.send(
new ConverseCommand({
modelId: aiConfig.config.model?.modelId ?? 'no-model',
messages: mapPromptToConversation(aiConfig.config.prompt ?? []),
modelId: aiConfig.model?.modelId ?? 'no-model',
messages: mapPromptToConversation(aiConfig.prompt ?? []),
inferenceConfig: {
temperature: aiConfig.model?.temperature ?? 0.5,
maxTokens: aiConfig.model?.maxTokens ?? 4096,
},
}),
),
);
Expand Down
6 changes: 4 additions & 2 deletions packages/sdk/server-ai/examples/openai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ async function main(): Promise<void> {
const { tracker } = aiConfig;
const completion = await tracker.trackOpenAI(async () =>
client.chat.completions.create({
messages: aiConfig.config.prompt || [],
model: aiConfig.config.model?.modelId || 'gpt-4',
messages: aiConfig.prompt || [],
model: aiConfig.model?.modelId || 'gpt-4',
temperature: aiConfig.model?.temperature ?? 0.5,
max_tokens: aiConfig.model?.maxTokens ?? 4096,
}),
);

Expand Down
34 changes: 17 additions & 17 deletions packages/sdk/server-ai/src/LDAIClientImpl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import * as Mustache from 'mustache';

import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDAIConfig, LDGenerationConfig, LDMessage, LDModelConfig } from './api/config';
import { LDAIConfig, LDAIDefaults, LDMessage, LDModelConfig } from './api/config';
import { LDAIClient } from './api/LDAIClient';
import { LDAIConfigTrackerImpl } from './LDAIConfigTrackerImpl';
import { LDClientMin } from './LDClientMin';
Expand Down Expand Up @@ -32,16 +32,28 @@ export class LDAIClientImpl implements LDAIClient {
return Mustache.render(template, variables, undefined, { escape: (item: any) => item });
}

async modelConfig<TDefault extends LDGenerationConfig>(
async modelConfig(
key: string,
context: LDContext,
defaultValue: TDefault,
defaultValue: LDAIDefaults,
variables?: Record<string, unknown>,
): Promise<LDAIConfig> {
const value: VariationContent = await this._ldClient.variation(key, context, defaultValue);
const tracker = new LDAIConfigTrackerImpl(
this._ldClient,
key,
// eslint-disable-next-line no-underscore-dangle
value._ldMeta?.versionKey ?? '',
context,
);
// eslint-disable-next-line no-underscore-dangle
const enabled = !!value._ldMeta?.enabled;
const config: LDAIConfig = {
tracker,
enabled,
};
// We are going to modify the contents before returning them, so we make a copy.
// This isn't a deep copy and the application developer should not modify the returned content.
const config: LDGenerationConfig = {};
if (value.model) {
config.model = { ...value.model };
}
Expand All @@ -54,18 +66,6 @@ export class LDAIClientImpl implements LDAIClient {
}));
}

return {
config,
// eslint-disable-next-line no-underscore-dangle
tracker: new LDAIConfigTrackerImpl(
this._ldClient,
key,
// eslint-disable-next-line no-underscore-dangle
value._ldMeta?.versionKey ?? '',
context,
),
// eslint-disable-next-line no-underscore-dangle
enabled: !!value._ldMeta?.enabled,
};
return config;
}
}
16 changes: 3 additions & 13 deletions packages/sdk/server-ai/src/api/LDAIClient.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDAIConfig, LDGenerationConfig } from './config/LDAIConfig';

/**
* Interface for default model configuration.
*/
export interface LDAIDefaults extends LDGenerationConfig {
/**
* Whether the configuration is enabled.
*/
enabled?: boolean;
}
import { LDAIConfig, LDAIDefaults } from './config/LDAIConfig';

/**
* Interface for performing AI operations using LaunchDarkly.
Expand Down Expand Up @@ -77,10 +67,10 @@ export interface LDAIClient {
* }
* ```
*/
modelConfig<TDefault extends LDAIDefaults>(
modelConfig(
key: string,
context: LDContext,
defaultValue: TDefault,
defaultValue: LDAIDefaults,
variables?: Record<string, unknown>,
): Promise<LDAIConfig>;
}
29 changes: 16 additions & 13 deletions packages/sdk/server-ai/src/api/config/LDAIConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export interface LDModelConfig {
/**
* The ID of the model.
*/
modelId?: string;
modelId: string;

/**
* Tuning parameter for randomness versus determinism. Exact effect will be determined by the
Expand Down Expand Up @@ -41,9 +41,9 @@ export interface LDMessage {
}

/**
* Configuration which affects generation.
* AI configuration and tracker.
*/
export interface LDGenerationConfig {
export interface LDAIConfig {
/**
* Optional model configuration.
*/
Expand All @@ -52,16 +52,6 @@ export interface LDGenerationConfig {
* Optional prompt data.
*/
prompt?: LDMessage[];
}

/**
* AI Config value and tracker.
*/
export interface LDAIConfig {
/**
* The result of the AI Config customization.
*/
config: LDGenerationConfig;

/**
* A tracker which can be used to generate analytics.
Expand All @@ -73,3 +63,16 @@ export interface LDAIConfig {
*/
enabled: boolean;
}

/**
* Default value for a `modelConfig`. This is the same as the LDAIConfig, but it does not include
* a tracker and `enabled` is optional.
*/
export type LDAIDefaults = Omit<LDAIConfig, 'tracker' | 'enabled'> & {
/**
* Whether the configuration is enabled.
*
* defaults to false
*/
enabled?: boolean;
};

0 comments on commit 1f3f54a

Please sign in to comment.