
Support for Custom and Imported models in ChatBedrock/ChatBedrockConverse #253

Open
3coins opened this issue Oct 24, 2024 · 2 comments

3coins (Collaborator) commented Oct 24, 2024

Problem

The current Bedrock API supports 5 different types of model ids, as documented here. While some of these follow a defined pattern from which the provider and model can be extracted, others lack the components needed to derive these values.

Both the ChatBedrockConverse and ChatBedrock classes currently rely on the provider and model values to drive provider- and model-specific logic. While our current approach of parsing the model id to determine the provider and model works for foundation models and cross-region model ids, it does not work for custom and imported models, whose ids might not contain the components needed to derive these values. Here are examples of the code used in the 2 implementations.

def set_disable_streaming(cls, values: Dict) -> Any:
    model_id = values.get("model_id", values.get("model"))
    model_parts = model_id.split(".")
    values["provider"] = values.get("provider") or (
        model_parts[-2] if len(model_parts) > 1 else model_parts[0]
    )
    # As of 09/15/24 Anthropic and Cohere models support streamed tool calling
    if "disable_streaming" not in values:
        values["disable_streaming"] = (
            False
            if values["provider"] in ["anthropic", "cohere"]
            else "tool_calling"
        )
    return values

def _get_provider(self) -> str:
    # If provider supplied by user, return as-is
    if self.provider:
        return self.provider
    # If model_id is an arn, can't extract provider from model_id,
    # so this requires passing in the provider by user
    if self.model_id.startswith("arn"):
        raise ValueError(
            "Model provider should be supplied when passing a model ARN as "
            "model_id"
        )
    # If model_id has region prefixed to them,
    # for example eu.anthropic.claude-3-haiku-20240307-v1:0,
    # provider is the second part, otherwise, the first part
    parts = self.model_id.split(".", maxsplit=2)
    return (
        parts[1]
        if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"})
        else parts[0]
    )
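To illustrate the failure mode, here is a standalone sketch of the same parsing logic and how it behaves for the different id formats. The imported-model ARN below is invented for illustration, not a real resource.

```python
from typing import Optional

def get_provider(model_id: str, provider: Optional[str] = None) -> str:
    """Standalone mirror of the _get_provider logic shown above."""
    if provider:
        return provider
    # A custom/imported model ARN has no provider component to parse
    if model_id.startswith("arn"):
        raise ValueError(
            "Model provider should be supplied when passing a model ARN as "
            "model_id"
        )
    parts = model_id.split(".", maxsplit=2)
    return (
        parts[1]
        if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"})
        else parts[0]
    )

# Foundation and cross-region model ids parse fine:
print(get_provider("anthropic.claude-3-haiku-20240307-v1:0"))     # anthropic
print(get_provider("eu.anthropic.claude-3-haiku-20240307-v1:0"))  # anthropic

# An imported-model ARN (made-up account/resource) raises instead:
try:
    get_provider("arn:aws:bedrock:us-west-2:123456789012:imported-model/abc123")
except ValueError as e:
    print(e)
```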

Solution

Short Term

  1. The current implementations do have a provider field that users can use to pass in the provider value. This does need a slight change in ChatBedrockConverse: we should raise an exception if the user passes a custom or imported model id or ARN and the provider value is missing.
  2. We also need a base_model or base_model_id field that users can pass in to identify the base model used to create the custom or imported model.
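As a sketch of the short-term validation, a hypothetical helper might look like the following. The field names (provider, base_model_id) come from the proposal above; the helper itself is illustrative, not actual library code.

```python
from typing import Optional

def validate_custom_model_fields(
    model_id: str,
    provider: Optional[str] = None,
    base_model_id: Optional[str] = None,
) -> None:
    """Raise if a custom/imported model is passed without a provider.

    Custom and imported models are typically referenced by ARN, which
    carries no provider component that could be parsed out of model_id.
    """
    if model_id.startswith("arn") and provider is None:
        raise ValueError(
            "provider must be supplied when model_id is a custom or "
            "imported model ARN"
        )

# A foundation-model id needs no extra fields:
validate_custom_model_fields("anthropic.claude-3-haiku-20240307-v1:0")

# An ARN without a provider would raise; with one, it passes:
validate_custom_model_fields(
    "arn:aws:bedrock:us-west-2:123456789012:imported-model/abc123",
    provider="meta",
    base_model_id="meta.llama3-8b-instruct-v1:0",
)
```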

Long Term

Longer term, the model info should be built into the Bedrock API, so that when new types of models or providers are introduced, we don't have to update logic in the chat model implementations to accommodate them. Having this in the Bedrock API would also be useful in other places, not just the LangChain implementation. Such an API should provide a comprehensive set of capabilities for the model, not just the provider or base model (for example streaming, tool support, streamed tool calls, and multi-modal support), and should work for all kinds of models, not just foundation models.
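As a sketch only, a capabilities response from such an API might look like the following. Every field name here is invented for illustration; no such Bedrock API exists today.

```json
{
  "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
  "provider": "anthropic",
  "baseModelId": null,
  "capabilities": {
    "streaming": true,
    "toolCalling": true,
    "streamedToolCalling": true,
    "inputModalities": ["TEXT", "IMAGE"],
    "outputModalities": ["TEXT"]
  }
}
```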

While the current Bedrock API get-foundation-model provides some info about foundation models, it is unclear whether it works for custom and imported models. It also doesn't indicate whether a model supports tool calls, streamed tool calls, or other features. Here is an example of running the API call for the Anthropic Haiku model.

% aws bedrock get-foundation-model --model-identifier anthropic.claude-3-haiku-20240307-v1:0
{
    "modelDetails": {
        "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-haiku-20240307-v1:0",
        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
        "modelName": "Claude 3 Haiku",
        "providerName": "Anthropic",
        "inputModalities": [
            "TEXT",
            "IMAGE"
        ],
        "outputModalities": [
            "TEXT"
        ],
        "responseStreamingSupported": true,
        "customizationsSupported": [],
        "inferenceTypesSupported": [
            "ON_DEMAND"
        ],
        "modelLifecycle": {
            "status": "ACTIVE"
        }
    }
}

Note: Bedrock does seem to have a get-custom-model API as well, but I don't have a custom or imported model set up to verify it. There are also list-foundation-models and list-custom-models APIs that return the same info for all models present in a user's account; however, these seem more heavy-handed than the get APIs, which return info for a specific model.

@rsgrewal-aws

can we link this to the #153

@rsgrewal-aws

The model provider values should be input to the ChatBedrock and ChatBedrockConverse classes
