Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Change the typing for the AIConfig #16

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,69 @@ class LDMessage:
content: str


@dataclass
class AIConfigData:
model: Optional[dict]
prompt: Optional[List[LDMessage]]
class ModelConfig:
"""
Configuration related to the model.
"""

def __init__(self, id: str, temperature: Optional[float] = None,
max_tokens: Optional[int] = None, attributes: dict = {}):
"""
:param id: The ID of the model.
:param temperature: Turning parameter for randomness versus determinism. Exact effect will be determined by the model.
:param max_tokens: The maximum number of tokens.
:param attributes: Additional model-specific attributes.
"""
self._id = id
self._temperature = temperature
self._max_tokens = max_tokens
self._attributes = attributes

@property
def id(self) -> str:
"""
The ID of the model.
"""
return self._id

@property
def temperature(self) -> Optional[float]:
""""
Turning parameter for randomness versus determinism. Exact effect will be determined by the model.
"""
return self._temperature

@property
def max_tokens(self) -> Optional[int]:
"""
The maximum number of tokens.
"""

return self._max_tokens

def get_attribute(self, key: str) -> Any:
"""
Retrieve model-specific attributes.

Accessing a named, typed attribute (e.g. id) will result in the call
being delegated to the appropriate property.
"""
if key == 'id':
return self.id
if key == 'temperature':
return self.temperature
if key == 'maxTokens':
return self.max_tokens

return self._attributes.get(key)


class AIConfig:
def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool):
self.config = config
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]):
self.tracker = tracker
self.enabled = enabled
self.model = model
self.prompt = prompt


class LDAIClient:
Expand Down Expand Up @@ -71,16 +123,26 @@ def model_config(
for entry in variation['prompt']
]

model = None
if 'model' in variation:
model = ModelConfig(
id=variation['model']['modelId'],
temperature=variation['model'].get('temperature'),
max_tokens=variation['model'].get('maxTokens'),
attributes=variation['model'],
)

enabled = variation.get('_ldMeta', {}).get('enabled', False)
return AIConfig(
config=AIConfigData(model=variation['model'], prompt=prompt),
tracker=LDAIConfigTracker(
self.client,
variation.get('_ldMeta', {}).get('versionKey', ''),
key,
context,
),
enabled=bool(enabled),
model=model,
prompt=prompt
)

def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
Expand Down
98 changes: 63 additions & 35 deletions ldai/testing/test_model_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from ldclient import Config, Context, LDClient
from ldclient.integrations.test_data import TestData
from ldclient.testing.builders import *

from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig
from ldai.tracker import LDAIConfigTracker


Expand All @@ -14,7 +13,7 @@ def td() -> TestData:
td.flag('model-config')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
},
Expand All @@ -27,7 +26,7 @@ def td() -> TestData:
td.flag('multiple-prompt')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192},
'prompt': [
{'role': 'system', 'content': 'Hello, {{name}}!'},
{'role': 'user', 'content': 'The day is, {{day}}!'},
Expand All @@ -43,7 +42,7 @@ def td() -> TestData:
td.flag('ctx-interpolation')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'},
'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
}
Expand All @@ -55,7 +54,7 @@ def td() -> TestData:
td.flag('off-config')
.variations(
{
'model': {'modelId': 'fakeModel'},
'model': {'modelId': 'fakeModel', 'temperature': 0.1},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': False, 'versionKey': 'abcd'},
}
Expand All @@ -82,81 +81,110 @@ def ldai_client(client: LDClient) -> LDAIClient:
return LDAIClient(client)


def test_model_config_delegates_to_properties():
model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'})
assert model.id == 'fakeModel'
assert model.temperature == 0.5
assert model.max_tokens == 4096
assert model.get_attribute('extra-attribute') == 'value'
assert model.get_attribute('non-existent') is None

assert model.id == model.get_attribute('id')
assert model.temperature == model.get_attribute('temperature')
assert model.max_tokens == model.get_attribute('maxTokens')
assert model.max_tokens != model.get_attribute('max_tokens')


def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(
config=AIConfigData(
model={'modelId': 'fakeModel'},
prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
),
tracker=tracker,
enabled=True,
model=ModelConfig('fakeModel'),
prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
)
variables = {'name': 'World'}

config = ldai_client.model_config('model-config', context, default_value, variables)

assert config.config.prompt is not None
assert len(config.config.prompt) > 0
assert config.config.prompt[0].content == 'Hello, World!'
assert config.prompt is not None
assert len(config.prompt) > 0
assert config.prompt[0].content == 'Hello, World!'
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.5
assert config.model.max_tokens == 4096


def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
)
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])

config = ldai_client.model_config('model-config', context, default_value, {})

assert config.config.prompt is not None
assert len(config.config.prompt) > 0
assert config.config.prompt[0].content == 'Hello, !'
assert config.prompt is not None
assert len(config.prompt) > 0
assert config.prompt[0].content == 'Hello, !'
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.5
assert config.model.max_tokens == 4096


def test_context_interpolation(ldai_client: LDAIClient, tracker):
context = Context.builder('user-key').name("Sandy").build()
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
)
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
variables = {'name': 'World'}

config = ldai_client.model_config(
'ctx-interpolation', context, default_value, variables
)

assert config.config.prompt is not None
assert len(config.config.prompt) > 0
assert config.config.prompt[0].content == 'Hello, Sandy!'
assert config.prompt is not None
assert len(config.prompt) > 0
assert config.prompt[0].content == 'Hello, Sandy!'
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature is None
assert config.model.max_tokens is None
assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to'


def test_model_config_multiple(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
)
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
variables = {'name': 'World', 'day': 'Monday'}

config = ldai_client.model_config(
'multiple-prompt', context, default_value, variables
)

assert config.config.prompt is not None
assert len(config.config.prompt) > 0
assert config.config.prompt[0].content == 'Hello, World!'
assert config.config.prompt[1].content == 'The day is, Monday!'
assert config.prompt is not None
assert len(config.prompt) > 0
assert config.prompt[0].content == 'Hello, World!'
assert config.prompt[1].content == 'The day is, Monday!'
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.7
assert config.model.max_tokens == 8192


def test_model_config_disabled(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False
)
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[])

config = ldai_client.model_config('off-config', context, default_value, {})

assert config.model is not None
assert config.enabled is False
assert config.model.id == 'fakeModel'
assert config.model.temperature == 0.1
assert config.model.max_tokens is None