-
Notifications
You must be signed in to change notification settings - Fork 127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support model_arn in AmazonBedrockGenerator #1244
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks quite good to me already. Just add model_family param to the docstring of the get_model_adapter method and I'd suggest to set model_family to None by default.
@@ -250,17 +271,28 @@ def run( | |||
return {"replies": replies} | |||
|
|||
@classmethod | |||
def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: | |||
def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[BedrockModelAdapter]: | |||
""" | |||
Gets the model adapter for the given model. | |||
|
|||
:param model: The model name. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a docstring for model_family here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not set model_family
to None by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
], | ||
) | ||
def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): | ||
""" | ||
Test that the correct model adapter is returned for a given model | ||
""" | ||
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model) | ||
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model, model_family=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If model_family is set to None by default, we don't need to pass it here explicitly.
Test that an error is raised when auto-detection of model_family fails | ||
""" | ||
with pytest.raises(AmazonBedrockConfigurationError): | ||
AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can leave out model_family=None
if we set it to None by default
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! 👍
Related Issues
Proposed Changes:
model_family
param which explicitly chooses model adapter if it can't be infered from model nameHow did you test it?
Notes for the reviewer
Checklist
fix:
,feat:
,build:
,chore:
,ci:
,docs:
,style:
,refactor:
,perf:
,test:
.