Skip to content

Commit

Permalink
added gemma on HF
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Feb 26, 2024
1 parent 8d13520 commit 925f524
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 9 deletions.
4 changes: 3 additions & 1 deletion libs/vertexai/langchain_google_vertexai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.gemma import (
GemmaChatLocalKaggle,
GemmaChatVertexAIModelGarden,
GemmaLocalHF,
GemmaLocalKaggle,
GemmaVertexAIModelGarden,
)
Expand All @@ -18,6 +18,8 @@
"GemmaChatVertexAIModelGarden",
"GemmaLocalKaggle",
"GemmaChatLocalKaggle",
"GemmaLocalHF",
"GemmaChatLocalHF",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
Expand Down
8 changes: 7 additions & 1 deletion libs/vertexai/langchain_google_vertexai/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Utilities to init Vertex AI."""

import dataclasses
import re
from importlib import metadata
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import google.api_core
import proto # type: ignore[import-untyped]
Expand Down Expand Up @@ -162,3 +163,8 @@ def get_generation_info(
info.pop("is_blocked")

return info


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
118 changes: 112 additions & 6 deletions libs/vertexai/langchain_google_vertexai/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from langchain_core.pydantic_v1 import BaseModel, root_validator

from langchain_google_vertexai._base import _BaseVertexAIModelGarden
from langchain_google_vertexai._utils import enforce_stop_tokens
from langchain_google_vertexai.model_garden import VertexAIModelGarden

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
Expand Down Expand Up @@ -118,9 +119,12 @@ def _generate(
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
text = output.predictions[0]
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
ChatGeneration(
message=AIMessage(content=output.predictions[0]),
message=AIMessage(content=text),
)
]
return ChatResult(generations=generations)
Expand All @@ -135,19 +139,22 @@ async def _agenerate(
"""Top Level call"""
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = await self.async_client.predict_(
output = await self.async_client.predict(
endpoint=self.endpoint_path, instances=[request]
)
text = output.predictions[0]
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
ChatGeneration(
message=AIMessage(content=output.predictions[0]),
message=AIMessage(content=text),
)
]
return ChatResult(generations=generations)


class _GemmaLocalKaggleBase(_GemmaBase):
"""Local gemma model."""
"""Local gemma model loaded from Kaggle."""

client: Any = None #: :meta private:
keras_backend: str = "jax"
Expand Down Expand Up @@ -178,6 +185,8 @@ def _default_params(self) -> Dict[str, Any]:


class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
"""Local gemma chat model loaded from Kaggle."""

def _generate(
self,
prompts: List[str],
Expand All @@ -189,6 +198,8 @@ def _generate(
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = self.client.generate(prompts, **params)
results = results if isinstance(results, str) else [results]
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=result)] for result in results])

@property
Expand All @@ -207,11 +218,106 @@ def _generate(
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
output = self.client.generate(prompt, **params)
generation = ChatGeneration(message=AIMessage(content=output))
text = self.client.generate(prompt, **params)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_chat_kaggle"


class _GemmaLocalHFBase(_GemmaBase):
"""Local gemma model loaded from HuggingFace."""

tokenizer: Any = None #: :meta private:
client: Any = None #: :meta private:
hf_access_token: str
cache_dir: Optional[str] = None
model_name: str = "gemma_2b_en"
"""Gemma model name."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
try:
from transformers import AutoTokenizer, GemmaForCausalLM # type: ignore
except ImportError:
raise ImportError(
"Could not import GemmaForCausalLM library. "
"Please install the GemmaForCausalLM library to "
"use this model: pip install transformers>=4.38.1"
)

values["tokenizer"] = AutoTokenizer.from_pretrained(
values["model_name"], token=values["hf_access_token"]
)
values["client"] = GemmaForCausalLM.from_pretrained(
values["model_name"],
token=values["hf_access_token"],
cache_dir=values["cache_dir"],
)
return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling gemma."""
params = {"max_length": self.max_tokens}
return {k: v for k, v in params.items() if v is not None}

def _run(self, prompt: str, kwargs: Any) -> str:
inputs = self.tokenizer(prompt, return_tensors="pt")
generate_ids = self.client.generate(inputs.input_ids, **kwargs)
return self.tokenizer.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]


class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM):
"""Local gemma model loaded from HuggingFace."""

def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = [self._run(prompt, **params) for prompt in prompts]
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=text)] for text in results])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_hf"


class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
"""Local gemma chat model loaded from HuggingFace."""

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
text = self._run(prompt, **params)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_chat_hf"
1 change: 0 additions & 1 deletion libs/vertexai/tests/integration_tests/test_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)


@pytest.mark.skip("CI testing not set up")
@pytest.mark.skip("CI testing not set up")
def test_gemma_model_garden() -> None:
"""In order to run this test, you should provide endpoint names.
Expand Down
2 changes: 2 additions & 0 deletions libs/vertexai/tests/unit_tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"GemmaChatVertexAIModelGarden",
"GemmaLocalKaggle",
"GemmaChatLocalKaggle",
"GemmaChatLocalHF",
"GemmaLocalHF",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
Expand Down

0 comments on commit 925f524

Please sign in to comment.