Skip to content

Commit

Permalink
fix gemma_hf
Browse files — browse the repository at this point in the history
  • Loading branch information
lkuligin committed Feb 27, 2024
1 parent 078b25c commit 028ea48
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
13 changes: 13 additions & 0 deletions libs/vertexai/langchain_google_vertexai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.gemma import (
GemmaChatLocalHF,
GemmaChatLocalKaggle,
GemmaChatVertexAIModelGarden,
GemmaLocalHF,
Expand All @@ -12,6 +13,13 @@
from langchain_google_vertexai.llms import VertexAI
from langchain_google_vertexai.model_garden import VertexAIModelGarden
from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore
from langchain_google_vertexai.vision_models import (
VertexAIImageCaptioning,
VertexAIImageCaptioningChat,
VertexAIImageEditorChat,
VertexAIImageGeneratorChat,
VertexAIVisualQnAChat,
)

__all__ = [
"ChatVertexAI",
Expand All @@ -29,4 +37,9 @@
"PydanticFunctionsOutputParser",
"create_structured_runnable",
"VectorSearchVectorStore",
"VertexAIImageCaptioning",
"VertexAIImageCaptioningChat",
"VertexAIImageEditorChat",
"VertexAIImageGeneratorChat",
"VertexAIVisualQnAChat",
]
4 changes: 2 additions & 2 deletions libs/vertexai/langchain_google_vertexai/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _generate(
"""Run the LLM on the given prompt and input."""
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = self.client.generate(prompts, **params)
results = results if isinstance(results, str) else [results]
results = [results] if isinstance(results, str) else results
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=result)] for result in results])
Expand Down Expand Up @@ -268,7 +268,7 @@ def _default_params(self) -> Dict[str, Any]:
params = {"max_length": self.max_tokens}
return {k: v for k, v in params.items() if v is not None}

def _run(self, prompt: str, kwargs: Any) -> str:
def _run(self, prompt: str, **kwargs: Any) -> str:
inputs = self.tokenizer(prompt, return_tensors="pt")
generate_ids = self.client.generate(inputs.input_ids, **kwargs)
return self.tokenizer.batch_decode(
Expand Down
2 changes: 1 addition & 1 deletion libs/vertexai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-google-vertexai"
version = "0.0.6"
version = "0.0.7"
description = "An integration package connecting GoogleVertexAI and LangChain"
authors = []
readme = "README.md"
Expand Down
5 changes: 5 additions & 0 deletions libs/vertexai/tests/unit_tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
"PydanticFunctionsOutputParser",
"create_structured_runnable",
"VectorSearchVectorStore",
"VertexAIImageCaptioning",
"VertexAIImageCaptioningChat",
"VertexAIImageEditorChat",
"VertexAIImageGeneratorChat",
"VertexAIVisualQnAChat",
]


Expand Down

0 comments on commit 028ea48

Please sign in to comment.