Skip to content

Commit

Permalink
feat: Change the typing for the AIConfig (#16)
Browse files Browse the repository at this point in the history
This commit inlines the `AIConfigData` class into the existing
`AIConfig`.

It also introduces a new type, `ModelConfig` to replace the much looser
dictionary config that was originally in play. This new model contains
specific, typed properties for the model id, temperature, and max
tokens. Additional model-specific attributes can be accessed as well.
  • Loading branch information
keelerm84 authored Nov 15, 2024
1 parent c752739 commit daf9537
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 42 deletions.
76 changes: 69 additions & 7 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,69 @@ class LDMessage:
content: str


@dataclass
class AIConfigData:
model: Optional[dict]
prompt: Optional[List[LDMessage]]
class ModelConfig:
    """
    Configuration related to the model.
    """

    def __init__(self, id: str, temperature: Optional[float] = None,
                 max_tokens: Optional[int] = None, attributes: Optional[dict] = None):
        """
        :param id: The ID of the model.
        :param temperature: Tuning parameter for randomness versus determinism. Exact effect will be determined by the model.
        :param max_tokens: The maximum number of tokens.
        :param attributes: Additional model-specific attributes.
        """
        self._id = id
        self._temperature = temperature
        self._max_tokens = max_tokens
        # Use a fresh dict per instance. The previous default (`attributes: dict = {}`)
        # was a single mutable object shared by every call that omitted the argument.
        self._attributes = attributes if attributes is not None else {}

    @property
    def id(self) -> str:
        """
        The ID of the model.
        """
        return self._id

    @property
    def temperature(self) -> Optional[float]:
        """
        Tuning parameter for randomness versus determinism. Exact effect will be determined by the model.
        """
        return self._temperature

    @property
    def max_tokens(self) -> Optional[int]:
        """
        The maximum number of tokens.
        """
        return self._max_tokens

    def get_attribute(self, key: str) -> Any:
        """
        Retrieve model-specific attributes.

        Accessing a named, typed attribute (e.g. id) will result in the call
        being delegated to the appropriate property.
        """
        # The camelCase key 'maxTokens' matches the wire format used by the
        # flag variation payload, not the snake_case constructor parameter.
        if key == 'id':
            return self.id
        if key == 'temperature':
            return self.temperature
        if key == 'maxTokens':
            return self.max_tokens

        return self._attributes.get(key)


class AIConfig:
    """
    An AI configuration variation: its tracker, enabled state, model
    configuration, and prompt messages.
    """

    def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]):
        """
        :param tracker: Tracker associated with this configuration.
        :param enabled: Whether this configuration is enabled.
        :param model: Configuration related to the model, if any.
        :param prompt: The list of prompt messages, if any.
        """
        # NOTE: the rendered diff showed a stale `__init__(self, config: AIConfigData, ...)`
        # interleaved here; only this post-commit signature is kept.
        self.tracker = tracker
        self.enabled = enabled
        self.model = model
        self.prompt = prompt


class LDAIClient:
Expand Down Expand Up @@ -71,16 +123,26 @@ def model_config(
for entry in variation['prompt']
]

model = None
if 'model' in variation:
model = ModelConfig(
id=variation['model']['modelId'],
temperature=variation['model'].get('temperature'),
max_tokens=variation['model'].get('maxTokens'),
attributes=variation['model'],
)

enabled = variation.get('_ldMeta', {}).get('enabled', False)
return AIConfig(
config=AIConfigData(model=variation['model'], prompt=prompt),
tracker=LDAIConfigTracker(
self.client,
variation.get('_ldMeta', {}).get('versionKey', ''),
key,
context,
),
enabled=bool(enabled),
model=model,
prompt=prompt
)

def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
Expand Down
98 changes: 63 additions & 35 deletions ldai/testing/test_model_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from ldclient import Config, Context, LDClient
from ldclient.integrations.test_data import TestData
from ldclient.testing.builders import *

from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig
from ldai.tracker import LDAIConfigTracker


Expand All @@ -14,7 +13,7 @@ def td() -> TestData:
td.flag('model-config')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
},
Expand All @@ -27,7 +26,7 @@ def td() -> TestData:
td.flag('multiple-prompt')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192},
'prompt': [
{'role': 'system', 'content': 'Hello, {{name}}!'},
{'role': 'user', 'content': 'The day is, {{day}}!'},
Expand All @@ -43,7 +42,7 @@ def td() -> TestData:
td.flag('ctx-interpolation')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'},
'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
}
Expand All @@ -55,7 +54,7 @@ def td() -> TestData:
td.flag('off-config')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'temperature': 0.1},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': False, 'versionKey': 'abcd'},
}
Expand All @@ -82,81 +81,110 @@ def ldai_client(client: LDClient) -> LDAIClient:
return LDAIClient(client)


def test_model_config_delegates_to_properties():
    """Typed properties and get_attribute() must agree on well-known keys."""
    model = ModelConfig(
        'fakeModel',
        temperature=0.5,
        max_tokens=4096,
        attributes={'extra-attribute': 'value'},
    )

    # Constructor arguments surface directly through the typed properties.
    assert model.id == 'fakeModel'
    assert model.temperature == 0.5
    assert model.max_tokens == 4096

    # Unknown keys fall back to the attributes dictionary.
    assert model.get_attribute('extra-attribute') == 'value'
    assert model.get_attribute('non-existent') is None

    # The well-known (camelCase) keys are delegated to the properties.
    assert model.get_attribute('id') == model.id
    assert model.get_attribute('temperature') == model.temperature
    assert model.get_attribute('maxTokens') == model.max_tokens
    # The snake_case spelling is not a delegated key.
    assert model.get_attribute('max_tokens') != model.max_tokens


def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
    """Prompt variables are interpolated and model settings surface as typed properties."""
    context = Context.create('user-key')
    # Post-commit construction; the stale `config=AIConfigData(...)` diff residue
    # that previously shadowed these arguments is removed.
    default_value = AIConfig(
        tracker=tracker,
        enabled=True,
        model=ModelConfig('fakeModel'),
        prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
    )
    variables = {'name': 'World'}

    config = ldai_client.model_config('model-config', context, default_value, variables)

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, World!'
    assert config.enabled is True

    # Values defined on the 'model-config' flag variation in the td fixture.
    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.5
    assert config.model.max_tokens == 4096


def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
    """With no variables supplied, the template placeholder interpolates to empty."""
    context = Context.create('user-key')
    # Post-commit construction; stale AIConfigData diff residue removed.
    default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])

    config = ldai_client.model_config('model-config', context, default_value, {})

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, !'
    assert config.enabled is True

    # Values defined on the 'model-config' flag variation in the td fixture.
    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.5
    assert config.model.max_tokens == 4096


def test_context_interpolation(ldai_client: LDAIClient, tracker):
    """`ldctx.*` placeholders interpolate from the evaluation context, and
    unknown model attributes are reachable via get_attribute()."""
    context = Context.builder('user-key').name("Sandy").build()
    # Post-commit construction; stale AIConfigData diff residue removed.
    default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
    variables = {'name': 'World'}

    config = ldai_client.model_config(
        'ctx-interpolation', context, default_value, variables
    )

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, Sandy!'
    assert config.enabled is True

    # The 'ctx-interpolation' variation sets no temperature/maxTokens,
    # only an arbitrary extra attribute.
    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature is None
    assert config.model.max_tokens is None
    assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to'


def test_model_config_multiple(ldai_client: LDAIClient, tracker):
    """All messages in a multi-message prompt are interpolated independently."""
    context = Context.create('user-key')
    # Post-commit construction; stale AIConfigData diff residue removed.
    default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
    variables = {'name': 'World', 'day': 'Monday'}

    config = ldai_client.model_config(
        'multiple-prompt', context, default_value, variables
    )

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, World!'
    assert config.prompt[1].content == 'The day is, Monday!'
    assert config.enabled is True

    # Values defined on the 'multiple-prompt' flag variation in the td fixture.
    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.7
    assert config.model.max_tokens == 8192


def test_model_config_disabled(ldai_client: LDAIClient, tracker):
    """A disabled variation still yields its model config, with enabled=False."""
    context = Context.create('user-key')
    # Post-commit construction; stale AIConfigData diff residue removed.
    default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[])

    config = ldai_client.model_config('off-config', context, default_value, {})

    assert config.model is not None
    assert config.enabled is False
    # Values defined on the 'off-config' flag variation in the td fixture.
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.1
    assert config.model.max_tokens is None

0 comments on commit daf9537

Please sign in to comment.