diff --git a/ldai/client.py b/ldai/client.py index b3708da..e2d53a2 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -21,13 +21,15 @@ class ModelConfig: Configuration related to the model. """ - def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None): + def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None): """ :param id: The ID of the model. :param parameters: Additional model-specific parameters. + :param custom: Additional customer provided data. """ self._id = id self._parameters = parameters + self._custom = custom @property def id(self) -> str: @@ -51,6 +53,15 @@ def get_parameter(self, key: str) -> Any: return self._parameters.get(key) + def get_custom(self, key: str) -> Any: + """ + Retrieve customer provided data. + """ + if self._custom is None: + return None + + return self._custom.get(key) + class ProviderConfig: """ @@ -128,9 +139,11 @@ def config( model = None if 'model' in variation and isinstance(variation['model'], dict): parameters = variation['model'].get('parameters', None) + custom = variation['model'].get('custom', None) model = ModelConfig( id=variation['model']['id'], - parameters=parameters + parameters=parameters, + custom=custom ) enabled = variation.get('_ldMeta', {}).get('enabled', False) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 8bf902a..6f97a4d 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -13,7 +13,7 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}}, + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}}, 'provider': {'id': 'fakeProvider'}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, @@ -117,6 +117,14 @@ def test_model_config_delegates_to_properties(): assert model.id == model.get_parameter('id') +def test_model_config_handles_custom(): + model = ModelConfig('fakeModel', custom={'extra-attribute': 'value'}) + assert model.id == 'fakeModel' + assert model.get_parameter('extra-attribute') is None + assert model.get_custom('non-existent') is None + assert model.get_custom('id') is None + + def test_model_config_interpolation(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig(