From 925f524ea8f0ac0be5c11add2bf9220da654d73c Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Mon, 26 Feb 2024 18:58:24 +0100
Subject: [PATCH] added gemma on HF

---
 .../langchain_google_vertexai/__init__.py     |   4 +-
 .../langchain_google_vertexai/_utils.py       |   8 +-
 .../langchain_google_vertexai/gemma.py        | 118 +++++++++++++++++-
 .../tests/integration_tests/test_gemma.py     |   1 -
 .../vertexai/tests/unit_tests/test_imports.py |   2 +
 5 files changed, 124 insertions(+), 9 deletions(-)

diff --git a/libs/vertexai/langchain_google_vertexai/__init__.py b/libs/vertexai/langchain_google_vertexai/__init__.py
index 9f4b90ee..81aec35a 100644
--- a/libs/vertexai/langchain_google_vertexai/__init__.py
+++ b/libs/vertexai/langchain_google_vertexai/__init__.py
@@ -1,11 +1,11 @@
 from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
 from langchain_google_vertexai.chains import create_structured_runnable
 from langchain_google_vertexai.chat_models import ChatVertexAI
-from langchain_google_vertexai.embeddings import VertexAIEmbeddings
 from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
 from langchain_google_vertexai.gemma import (
     GemmaChatLocalKaggle,
     GemmaChatVertexAIModelGarden,
+    GemmaLocalHF,
     GemmaLocalKaggle,
     GemmaVertexAIModelGarden,
 )
@@ -18,6 +18,8 @@
     "GemmaChatVertexAIModelGarden",
     "GemmaLocalKaggle",
     "GemmaChatLocalKaggle",
+    "GemmaLocalHF",
+    "GemmaChatLocalHF",
     "VertexAIEmbeddings",
     "VertexAI",
     "VertexAIModelGarden",
diff --git a/libs/vertexai/langchain_google_vertexai/_utils.py b/libs/vertexai/langchain_google_vertexai/_utils.py
index 641721af..e2487940 100644
--- a/libs/vertexai/langchain_google_vertexai/_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/_utils.py
@@ -1,8 +1,9 @@
 """Utilities to init Vertex AI."""
 import dataclasses
+import re
 from importlib import metadata
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import google.api_core
 import proto  # type: ignore[import-untyped]
@@ -162,3 +163,8 @@ def get_generation_info(
         info.pop("is_blocked")
     return info
+
+
+def enforce_stop_tokens(text: str, stop: List[str]) -> str:
+    """Cut off the text as soon as any stop words occur."""
+    return re.split("|".join(stop), text, maxsplit=1)[0]
diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py
index 5b90d0bc..30280de0 100644
--- a/libs/vertexai/langchain_google_vertexai/gemma.py
+++ b/libs/vertexai/langchain_google_vertexai/gemma.py
@@ -24,6 +24,7 @@
 from langchain_core.pydantic_v1 import BaseModel, root_validator
 
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden
+from langchain_google_vertexai._utils import enforce_stop_tokens
 from langchain_google_vertexai.model_garden import VertexAIModelGarden
 
 USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
@@ -118,9 +119,12 @@ def _generate(
         request = self._get_params(**kwargs)
         request["prompt"] = gemma_messages_to_prompt(messages)
         output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
+        text = output.predictions[0]
+        if stop:
+            text = enforce_stop_tokens(text, stop)
         generations = [
             ChatGeneration(
-                message=AIMessage(content=output.predictions[0]),
+                message=AIMessage(content=text),
             )
         ]
         return ChatResult(generations=generations)
@@ -135,19 +139,22 @@ async def _agenerate(
         """Top Level call"""
         request = self._get_params(**kwargs)
         request["prompt"] = gemma_messages_to_prompt(messages)
-        output = await self.async_client.predict_(
+        output = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=[request]
         )
+        text = output.predictions[0]
+        if stop:
+            text = enforce_stop_tokens(text, stop)
         generations = [
             ChatGeneration(
-                message=AIMessage(content=output.predictions[0]),
+                message=AIMessage(content=text),
             )
         ]
         return ChatResult(generations=generations)
 
 
 class _GemmaLocalKaggleBase(_GemmaBase):
-    """Local gemma model."""
+    """Local gemma model loaded from Kaggle."""
 
     client: Any = None  #: :meta private:
     keras_backend: str = "jax"
@@ -178,6 +185,8 @@ def _default_params(self) -> Dict[str, Any]:
 
 
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
+    """Local gemma model loaded from Kaggle."""
+
     def _generate(
         self,
         prompts: List[str],
@@ -189,6 +198,8 @@ def _generate(
         params = {"max_length": self.max_tokens} if self.max_tokens else {}
         results = self.client.generate(prompts, **params)
         results = [results] if isinstance(results, str) else results
+        if stop:
+            results = [enforce_stop_tokens(text, stop) for text in results]
         return LLMResult(generations=[[Generation(text=result)] for result in results])
 
     @property
@@ -207,11 +218,106 @@ def _generate(
     ) -> ChatResult:
         params = {"max_length": self.max_tokens} if self.max_tokens else {}
         prompt = gemma_messages_to_prompt(messages)
-        output = self.client.generate(prompt, **params)
-        generation = ChatGeneration(message=AIMessage(content=output))
+        text = self.client.generate(prompt, **params)
+        if stop:
+            text = enforce_stop_tokens(text, stop)
+        generation = ChatGeneration(message=AIMessage(content=text))
         return ChatResult(generations=[generation])
 
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "gemma_local_chat_kaggle"
+
+
+class _GemmaLocalHFBase(_GemmaBase):
+    """Local gemma model loaded from HuggingFace."""
+
+    tokenizer: Any = None  #: :meta private:
+    client: Any = None  #: :meta private:
+    hf_access_token: str
+    cache_dir: Optional[str] = None
+    model_name: str = "gemma_2b_en"
+    """Gemma model name."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the transformers package is installed."""
+        try:
+            from transformers import AutoTokenizer, GemmaForCausalLM  # type: ignore
+        except ImportError:
+            raise ImportError(
+                "Could not import GemmaForCausalLM library. "
" + "Please install the GemmaForCausalLM library to " + "use this model: pip install transformers>=4.38.1" + ) + + values["tokenizer"] = AutoTokenizer.from_pretrained( + values["model_name"], token=values["hf_access_token"] + ) + values["client"] = GemmaForCausalLM.from_pretrained( + values["model_name"], + token=values["hf_access_token"], + cache_dir=values["cache_dir"], + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling gemma.""" + params = {"max_length": self.max_tokens} + return {k: v for k, v in params.items() if v is not None} + + def _run(self, prompt: str, kwargs: Any) -> str: + inputs = self.tokenizer(prompt, return_tensors="pt") + generate_ids = self.client.generate(inputs.input_ids, **kwargs) + return self.tokenizer.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + +class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM): + """Local gemma model loaded from HuggingFace.""" + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + params = {"max_length": self.max_tokens} if self.max_tokens else {} + results = [self._run(prompt, **params) for prompt in prompts] + if stop: + results = [enforce_stop_tokens(text, stop) for text in results] + return LLMResult(generations=[[Generation(text=text)] for text in results]) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "gemma_local_hf" + + +class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel): + """Local gemma chat model loaded from HuggingFace.""" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + params = {"max_length": self.max_tokens} if self.max_tokens else {} + prompt = gemma_messages_to_prompt(messages) + text = self._run(prompt, **params) + if stop: + text = enforce_stop_tokens(text, stop) + generation = ChatGeneration(message=AIMessage(content=text)) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "gemma_local_chat_hf" diff --git a/libs/vertexai/tests/integration_tests/test_gemma.py b/libs/vertexai/tests/integration_tests/test_gemma.py index de2d27ec..cb8705c9 100644 --- a/libs/vertexai/tests/integration_tests/test_gemma.py +++ b/libs/vertexai/tests/integration_tests/test_gemma.py @@ -14,7 +14,6 @@ ) -@pytest.mark.skip("CI testing not set up") @pytest.mark.skip("CI testing not set up") def test_gemma_model_garden() -> None: """In order to run this test, you should provide endpoint names. diff --git a/libs/vertexai/tests/unit_tests/test_imports.py b/libs/vertexai/tests/unit_tests/test_imports.py index e207e7e6..e1b1ab1e 100644 --- a/libs/vertexai/tests/unit_tests/test_imports.py +++ b/libs/vertexai/tests/unit_tests/test_imports.py @@ -6,6 +6,8 @@ "GemmaChatVertexAIModelGarden", "GemmaLocalKaggle", "GemmaChatLocalKaggle", + "GemmaChatLocalHF", + "GemmaLocalHF", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden",