Skip to content

Commit

Permalink
feat: Add custom parameter support to model config
Browse files Browse the repository at this point in the history
  • Loading branch information
keelerm84 committed Nov 22, 2024
1 parent a2cc966 commit 95015f1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
17 changes: 15 additions & 2 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class ModelConfig:
Configuration related to the model.
"""

def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None):
def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None):
"""
:param id: The ID of the model.
:param parameters: Additional model-specific parameters.
:param custom: Additional customer provided data.
"""
self._id = id
self._parameters = parameters
self._custom = custom

@property
def id(self) -> str:
Expand All @@ -51,6 +53,15 @@ def get_parameter(self, key: str) -> Any:

return self._parameters.get(key)

def get_custom(self, key: str) -> Any:
"""
Retrieve customer provided data.
"""
if self._custom is None:
return None

return self._custom.get(key)


class ProviderConfig:
"""
Expand Down Expand Up @@ -128,9 +139,11 @@ def config(
model = None
if 'model' in variation and isinstance(variation['model'], dict):
parameters = variation['model'].get('parameters', None)
custom = variation['model'].get('custom', None)
model = ModelConfig(
id=variation['model']['id'],
parameters=parameters
parameters=parameters,
custom=custom
)

enabled = variation.get('_ldMeta', {}).get('enabled', False)
Expand Down
10 changes: 9 additions & 1 deletion ldai/testing/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def td() -> TestData:
td.flag('model-config')
.variations(
{
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}},
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}},
'provider': {'id': 'fakeProvider'},
'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
Expand Down Expand Up @@ -117,6 +117,14 @@ def test_model_config_delegates_to_properties():
assert model.id == model.get_parameter('id')


def test_model_config_handles_custom():
model = ModelConfig('fakeModel', custom={'extra-attribute': 'value'})
assert model.id == 'fakeModel'
assert model.get_parameter('extra-attribute') is None
assert model.get_custom('non-existent') is None
assert model.get_custom('id') is None


def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(
Expand Down

0 comments on commit 95015f1

Please sign in to comment.