From 028ea48ce092b95e143bc723532fec59f33b55fc Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Mon, 26 Feb 2024 21:37:22 +0100 Subject: [PATCH] fix gemma_hf --- libs/vertexai/langchain_google_vertexai/__init__.py | 13 +++++++++++++ libs/vertexai/langchain_google_vertexai/gemma.py | 4 ++-- libs/vertexai/pyproject.toml | 2 +- libs/vertexai/tests/unit_tests/test_imports.py | 5 +++++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/__init__.py b/libs/vertexai/langchain_google_vertexai/__init__.py index 858d0f43..9adfcd22 100644 --- a/libs/vertexai/langchain_google_vertexai/__init__.py +++ b/libs/vertexai/langchain_google_vertexai/__init__.py @@ -3,6 +3,7 @@ from langchain_google_vertexai.chat_models import ChatVertexAI from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser from langchain_google_vertexai.gemma import ( + GemmaChatLocalHF, GemmaChatLocalKaggle, GemmaChatVertexAIModelGarden, GemmaLocalHF, @@ -12,6 +13,13 @@ from langchain_google_vertexai.llms import VertexAI from langchain_google_vertexai.model_garden import VertexAIModelGarden from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore +from langchain_google_vertexai.vision_models import ( + VertexAIImageCaptioning, + VertexAIImageCaptioningChat, + VertexAIImageEditorChat, + VertexAIImageGeneratorChat, + VertexAIVisualQnAChat, +) __all__ = [ "ChatVertexAI", @@ -29,4 +37,9 @@ "PydanticFunctionsOutputParser", "create_structured_runnable", "VectorSearchVectorStore", + "VertexAIImageCaptioning", + "VertexAIImageCaptioningChat", + "VertexAIImageEditorChat", + "VertexAIImageGeneratorChat", + "VertexAIVisualQnAChat", ] diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py index 30280de0..c24cb0b1 100644 --- a/libs/vertexai/langchain_google_vertexai/gemma.py +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -197,7 +197,7 @@ def _generate( """Run the LLM on the given prompt and input.""" params = {"max_length": self.max_tokens} if self.max_tokens else {} results = self.client.generate(prompts, **params) - results = results if isinstance(results, str) else [results] + results = [results] if isinstance(results, str) else results if stop: results = [enforce_stop_tokens(text, stop) for text in results] return LLMResult(generations=[[Generation(text=result)] for result in results]) @@ -268,7 +268,7 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} - def _run(self, prompt: str, kwargs: Any) -> str: + def _run(self, prompt: str, **kwargs: Any) -> str: inputs = self.tokenizer(prompt, return_tensors="pt") generate_ids = self.client.generate(inputs.input_ids, **kwargs) return self.tokenizer.batch_decode( diff --git a/libs/vertexai/pyproject.toml b/libs/vertexai/pyproject.toml index ff0f6976..4dbc328e 100644 --- a/libs/vertexai/pyproject.toml +++ b/libs/vertexai/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-google-vertexai" -version = "0.0.6" +version = "0.0.7" description = "An integration package connecting GoogleVertexAI and LangChain" authors = [] readme = "README.md" diff --git a/libs/vertexai/tests/unit_tests/test_imports.py b/libs/vertexai/tests/unit_tests/test_imports.py index ab722808..02f721fe 100644 --- a/libs/vertexai/tests/unit_tests/test_imports.py +++ b/libs/vertexai/tests/unit_tests/test_imports.py @@ -16,6 +16,11 @@ "PydanticFunctionsOutputParser", "create_structured_runnable", "VectorSearchVectorStore", + "VertexAIImageCaptioning", + "VertexAIImageCaptioningChat", + "VertexAIImageEditorChat", + "VertexAIImageGeneratorChat", + "VertexAIVisualQnAChat", ]