added gemma on HF #26

Merged
merged 2 commits on Feb 26, 2024
4 changes: 3 additions & 1 deletion libs/vertexai/langchain_google_vertexai/__init__.py
@@ -1,11 +1,11 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.gemma import (
GemmaChatLocalHF,
GemmaChatLocalKaggle,
GemmaChatVertexAIModelGarden,
GemmaLocalHF,
GemmaLocalKaggle,
GemmaVertexAIModelGarden,
)
@@ -18,6 +18,8 @@
"GemmaChatVertexAIModelGarden",
"GemmaLocalKaggle",
"GemmaChatLocalKaggle",
"GemmaLocalHF",
"GemmaChatLocalHF",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
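The practical effect of this __init__.py change is that the two new HuggingFace-backed classes become importable from the package root. A minimal check, assuming the updated package is installed:

```python
from langchain_google_vertexai import GemmaChatLocalHF, GemmaLocalHF
```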
8 changes: 7 additions & 1 deletion libs/vertexai/langchain_google_vertexai/_utils.py
@@ -1,8 +1,9 @@
"""Utilities to init Vertex AI."""

import dataclasses
import re
from importlib import metadata
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import google.api_core
import proto # type: ignore[import-untyped]
@@ -162,3 +163,8 @@ def get_generation_info(
info.pop("is_blocked")

return info


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
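To illustrate the new helper: the stop strings are joined into a regex alternation, so the text is cut at the first occurrence of any stop string (this sketch assumes plain stop strings; stop strings containing regex metacharacters would need escaping).

```python
# Illustrative sketch of the helper's behavior, assuming the package is installed.
from langchain_google_vertexai._utils import enforce_stop_tokens

text = "The answer is 42.<end_of_turn>\nAnd some trailing text."
print(enforce_stop_tokens(text, stop=["<end_of_turn>"]))
# -> "The answer is 42."
```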
118 changes: 112 additions & 6 deletions libs/vertexai/langchain_google_vertexai/gemma.py
@@ -24,6 +24,7 @@
from langchain_core.pydantic_v1 import BaseModel, root_validator

from langchain_google_vertexai._base import _BaseVertexAIModelGarden
from langchain_google_vertexai._utils import enforce_stop_tokens
from langchain_google_vertexai.model_garden import VertexAIModelGarden

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
@@ -118,9 +119,12 @@ def _generate(
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
text = output.predictions[0]
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
ChatGeneration(
message=AIMessage(content=output.predictions[0]),
message=AIMessage(content=text),
)
]
return ChatResult(generations=generations)
@@ -135,19 +139,22 @@ async def _agenerate(
"""Top Level call"""
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = await self.async_client.predict_(
output = await self.async_client.predict(
endpoint=self.endpoint_path, instances=[request]
)
text = output.predictions[0]
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
ChatGeneration(
message=AIMessage(content=output.predictions[0]),
message=AIMessage(content=text),
)
]
return ChatResult(generations=generations)
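To show how the stop handling added here is used from the caller's side, a hedged usage sketch follows; the endpoint_id, project, and location values are placeholders, not taken from this diff.

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import GemmaChatVertexAIModelGarden

# Placeholders: point these at your own deployed Gemma endpoint.
llm = GemmaChatVertexAIModelGarden(
    endpoint_id="your-endpoint-id",
    project="your-gcp-project",
    location="us-central1",
)

# With this change, the prediction is truncated at the first stop string found.
result = llm.invoke(
    [HumanMessage(content="How much is 2 + 2?")],
    stop=["<end_of_turn>"],
)
print(result.content)
```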


class _GemmaLocalKaggleBase(_GemmaBase):
"""Local gemma model."""
"""Local gemma model loaded from Kaggle."""

client: Any = None #: :meta private:
keras_backend: str = "jax"
@@ -178,6 +185,8 @@ def _default_params(self) -> Dict[str, Any]:


class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
"""Local gemma chat model loaded from Kaggle."""

def _generate(
self,
prompts: List[str],
@@ -189,6 +198,8 @@ def _generate(
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = self.client.generate(prompts, **params)
results = [results] if isinstance(results, str) else results
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=result)] for result in results])

@property
@@ -207,11 +218,106 @@ def _generate(
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
output = self.client.generate(prompt, **params)
generation = ChatGeneration(message=AIMessage(content=output))
text = self.client.generate(prompt, **params)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_chat_kaggle"


class _GemmaLocalHFBase(_GemmaBase):
"""Local gemma model loaded from HuggingFace."""

tokenizer: Any = None #: :meta private:
client: Any = None #: :meta private:
hf_access_token: str
cache_dir: Optional[str] = None
model_name: str = "gemma_2b_en"
"""Gemma model name."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
try:
from transformers import AutoTokenizer, GemmaForCausalLM # type: ignore
except ImportError:
raise ImportError(
"Could not import GemmaForCausalLM library. "
"Please install the GemmaForCausalLM library to "
"use this model: pip install transformers>=4.38.1"
)

values["tokenizer"] = AutoTokenizer.from_pretrained(
values["model_name"], token=values["hf_access_token"]
)
values["client"] = GemmaForCausalLM.from_pretrained(
values["model_name"],
token=values["hf_access_token"],
cache_dir=values["cache_dir"],
)
return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling gemma."""
params = {"max_length": self.max_tokens}
return {k: v for k, v in params.items() if v is not None}

def _run(self, prompt: str, **kwargs: Any) -> str:
inputs = self.tokenizer(prompt, return_tensors="pt")
generate_ids = self.client.generate(inputs.input_ids, **kwargs)
return self.tokenizer.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
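The _run helper above follows the usual transformers pattern of tokenize, generate, decode. A standalone sketch of the same flow, with the model id and token as placeholders rather than values from this PR:

```python
# Illustrative only; mirrors the tokenize -> generate -> decode flow in _run.
from transformers import AutoTokenizer, GemmaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token="hf_...")
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", token="hf_...")

inputs = tokenizer("Tell me a joke.", return_tensors="pt")
generate_ids = model.generate(inputs.input_ids, max_length=64)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```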


class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM):
"""Local gemma model loaded from HuggingFace."""

def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = [self._run(prompt, **params) for prompt in prompts]
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=text)] for text in results])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_hf"


class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
"""Local gemma chat model loaded from HuggingFace."""

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
text = self._run(prompt, **params)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_chat_hf"
@@ -202,11 +202,11 @@ def _get_default_embeddings(cls) -> Embeddings:
)

# TODO: Change to vertexai embeddings
from langchain_community import (
embeddings, # type: ignore[import-not-found, unused-ignore]
from langchain_community.embeddings import ( # type: ignore[import-not-found, unused-ignore]
TensorflowHubEmbeddings,
)

return embeddings.TensorflowHubEmbeddings()
return TensorflowHubEmbeddings()

def _generate_unique_ids(self, number: int) -> List[str]:
"""Generates a list of unique ids of length `number`
1 change: 0 additions & 1 deletion libs/vertexai/tests/integration_tests/test_gemma.py
@@ -14,7 +14,6 @@
)


@pytest.mark.skip("CI testing not set up")
@pytest.mark.skip("CI testing not set up")
def test_gemma_model_garden() -> None:
"""In order to run this test, you should provide endpoint names.
2 changes: 2 additions & 0 deletions libs/vertexai/tests/unit_tests/test_imports.py
@@ -6,6 +6,8 @@
"GemmaChatVertexAIModelGarden",
"GemmaLocalKaggle",
"GemmaChatLocalKaggle",
"GemmaChatLocalHF",
"GemmaLocalHF",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",