Skip to content

Commit

Permalink
apply feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
tstadel committed Dec 11, 2024
1 parent 7b626ae commit 44e0c82
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,17 @@ def run(
return {"replies": replies}

@classmethod
def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[BedrockModelAdapter]:
def get_model_adapter(cls, model: str, model_family: Optional[str] = None) -> Type[BedrockModelAdapter]:
"""
Gets the model adapter for the given model.
If `model_family` is provided, the adapter for the model family is returned.
If `model_family` is not provided, the adapter is auto-detected based on the model name.
:param model: The model name.
:param model_family: The model family.
:returns: The model adapter class, or None if no adapter is found.
:raises AmazonBedrockConfigurationError: If the model family is not supported or the model cannot be auto-detected.
"""
if model_family:
if model_family not in cls.SUPPORTED_MODEL_FAMILIES:
Expand Down
4 changes: 2 additions & 2 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed
"""
Test that the correct model adapter is returned for a given model
"""
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model, model_family=None)
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model)
assert model_adapter == expected_model_adapter


Expand Down Expand Up @@ -345,7 +345,7 @@ def test_get_model_adapter_auto_detect_family_fails():
Test that an error is raised when auto-detection of model_family fails
"""
with pytest.raises(AmazonBedrockConfigurationError):
AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None)
AmazonBedrockGenerator.get_model_adapter(model="arn:123435423")


def test_get_model_adapter_model_family_over_auto_detection():
Expand Down

0 comments on commit 44e0c82

Please sign in to comment.