Llama maas (#458)
lkuligin authored Aug 21, 2024
1 parent dedfe89 commit 86aade8
Showing 10 changed files with 716 additions and 48 deletions.
2 changes: 2 additions & 0 deletions libs/vertexai/langchain_google_vertexai/__init__.py
@@ -31,6 +31,7 @@
)
from langchain_google_vertexai.llms import VertexAI
from langchain_google_vertexai.model_garden import VertexAIModelGarden
from langchain_google_vertexai.model_garden_maas import get_vertex_maas_model
from langchain_google_vertexai.utils import create_context_cache
from langchain_google_vertexai.vectorstores import (
DataStoreDocumentStorage,
@@ -81,4 +82,5 @@
"VertexPairWiseStringEvaluator",
"VertexStringEvaluator",
"create_context_cache",
"get_vertex_maas_model",
]
24 changes: 24 additions & 0 deletions libs/vertexai/langchain_google_vertexai/model_garden_maas/__init__.py
@@ -0,0 +1,24 @@
from langchain_google_vertexai.model_garden_maas.llama import VertexModelGardenLlama

_MISTRAL_MODELS = [
"mistral-nemo@2407",
"mistral-large@2407",
]
_LLAMA_MODELS = ["meta/llama3-405b-instruct-maas"]
_MAAS_MODELS = _MISTRAL_MODELS + _LLAMA_MODELS


def get_vertex_maas_model(model_name, **kwargs):
"""Return a corresponding Vertex MaaS instance.
A factory method based on model's name.
"""
if model_name not in _MAAS_MODELS:
raise ValueError(f"model name {model_name} is not supported!")
if model_name in _MISTRAL_MODELS:
from langchain_google_vertexai.model_garden_maas.mistral import ( # noqa: F401
VertexModelGardenMistral,
)

return VertexModelGardenMistral(model=model_name, **kwargs)
return VertexModelGardenLlama(model=model_name, **kwargs)
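
For context (not part of this diff), a minimal usage sketch of the new factory; the location value is a placeholder, extra kwargs are forwarded to the underlying model class, and the call assumes the standard LangChain chat-model interface:

from langchain_google_vertexai import get_vertex_maas_model

llm = get_vertex_maas_model(
    model_name="meta/llama3-405b-instruct-maas",  # must be one of _MAAS_MODELS
    location="us-central1",  # placeholder region, forwarded to the model class
)
print(llm.invoke("Hello!"))  # assumes the chat-model invoke interface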
42 changes: 37 additions & 5 deletions libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py
@@ -10,11 +10,11 @@
Union,
)

import httpx # type: ignore[unused-ignore, import-not-found]
import httpx
from google import auth
from google.auth.credentials import Credentials
from google.auth.transport import requests as auth_requests
from httpx_sse import ( # type: ignore[import-not-found]
from httpx_sse import (
EventSource,
aconnect_sse,
connect_sse,
@@ -99,13 +99,36 @@ class _BaseVertexMaasModelGarden(_VertexAIBase):
append_tools_to_system_message: bool = False
"Whether to append tools to the system message or not."
model_family: Optional[VertexMaaSModelFamily] = None
timeout: int = 120

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True
arbitrary_types_allowed = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
token = _get_token(credentials=self.credentials)
endpoint = self.get_url()
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {token}",
"x-goog-api-client": self._library_version,
"user_agent": self._user_agent,
}
self.client = httpx.Client(
base_url=endpoint,
headers=headers,
timeout=self.timeout,
)
self.async_client = httpx.AsyncClient(
base_url=endpoint,
headers=headers,
timeout=self.timeout,
)
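
The constructor calls a module-level _get_token helper that is not shown in this hunk. A plausible sketch of its shape, inferred from the google.auth imports at the top of the file (the actual implementation may differ):

from google import auth
from google.auth.transport import requests as auth_requests

def _get_token(credentials=None):
    # Hypothetical reconstruction, not part of this commit's visible diff:
    # fall back to application-default credentials, refresh if stale, and
    # return the bearer token used in the Authorization header above.
    creds = credentials or auth.default()[0]
    if not creds.valid:
        creds.refresh(auth_requests.Request())
    return creds.token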

@root_validator(pre=True)
def validate_environment_model_garden(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
@@ -132,7 +155,7 @@ def _get_url_part(self, stream: bool = False) -> str:
":streamRawPredict"
)
return f"publishers/mistralai/models/{self.full_model_name}:rawPredict"
return "openapi/chat/completions"
return "endpoints/openapi/chat/completions"

def get_url(self) -> str:
if self.model_family == VertexMaaSModelFamily.LLAMA:
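
For Llama models, the corrected relative path above composes with the Vertex AI regional endpoint into a full URL along these lines (region, project, and API version are illustrative placeholders, not confirmed by this diff):

location, project = "us-central1", "my-project"  # placeholders
base = f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}"
llama_chat_url = f"{base}/endpoints/openapi/chat/completions"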
@@ -173,12 +196,17 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
kwargs["stream"] = False
stream = kwargs["stream"]
if stream:
# Llama and Mistral expect different "Content-Type" for streaming
headers = {"Accept": "text/event-stream"}
if headers_content_type := kwargs.pop("headers_content_type", None):
headers["Content-Type"] = headers_content_type

event_source = aconnect_sse(
llm.async_client,
"POST",
llm._get_url_part(stream=True),
json=kwargs,
headers={"Accept": "text/event-stream"},
headers=headers,
)
return _aiter_sse(event_source)
else:
@@ -197,14 +225,18 @@ def completion_with_retry(llm: _BaseVertexMaasModelGarden, **kwargs):
kwargs = llm._enrich_params(kwargs)

if stream:
# Llama and Mistral expect different "Content-Type" for streaming
headers = {"Accept": "text/event-stream"}
if headers_content_type := kwargs.pop("headers_content_type", None):
headers["Content-Type"] = headers_content_type

def iter_sse():
with connect_sse(
llm.client,
"POST",
llm._get_url_part(stream=True),
json=kwargs,
headers={"Accept": "text/event-stream"},
headers=headers,
) as event_source:
_raise_on_error(event_source.response)
for event in event_source.iter_sse():
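
The streaming-header logic added to both the async and sync paths can be read in isolation as the following helper; which concrete Content-Type each model family passes through headers_content_type is not shown in this diff:

from typing import Dict, Optional

def _stream_headers(headers_content_type: Optional[str] = None) -> Dict[str, str]:
    # Always request server-sent events; only override Content-Type when a
    # model family supplies one, since Llama and Mistral expect different
    # values for streaming requests.
    headers = {"Accept": "text/event-stream"}
    if headers_content_type:
        headers["Content-Type"] = headers_content_type
    return headers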
