From ad813957db7543b5d6f848bcf8e734e14a676fff Mon Sep 17 00:00:00 2001
From: Jorge
Date: Wed, 21 Feb 2024 22:03:40 +0100
Subject: [PATCH 1/4] Added support for imagetext models

---
 .../vision_models.py                          | 274 ++++++++++++++++++
 .../integration_tests/test_vision_models.py   |  82 ++++++
 .../tests/unit_tests/test_vision_models.py    |  62 ++++
 3 files changed, 418 insertions(+)
 create mode 100644 libs/vertexai/langchain_google_vertexai/vision_models.py
 create mode 100644 libs/vertexai/tests/integration_tests/test_vision_models.py
 create mode 100644 libs/vertexai/tests/unit_tests/test_vision_models.py

diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py
new file mode 100644
index 00000000..4f9c41a6
--- /dev/null
+++ b/libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -0,0 +1,274 @@
+from typing import Any, Dict, List
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel, BaseLLM
+from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.outputs import ChatResult, LLMResult
+from langchain_core.outputs.chat_generation import ChatGeneration
+from langchain_core.outputs.generation import Generation
+from langchain_core.pydantic_v1 import BaseModel, Field
+from vertexai.vision_models import Image, ImageTextModel  # type: ignore[import-untyped]
+
+from langchain_google_vertexai._image_utils import ImageBytesLoader
+
+
+class BaseImageTextModel(BaseModel):
+    """Base class for all integrations that use ImageTextModel"""
+
+    model_name: str = Field(default="imagetext@001")
+    """Name of the model to use"""
+    number_of_results: int = Field(default=1)
+    """Number of results to return from one query"""
+    language: str = Field(default="en")
+    """Language of the query"""
+    project: str = Field(default=None)
+    """Google cloud project"""
+
+    def _create_model(self) -> ImageTextModel:
+        """Builds the model object from the class attributes."""
+        return ImageTextModel.from_pretrained(model_name=self.model_name)
+
+    @staticmethod
+    def _get_image_from_message(message: BaseMessage) -> Image:
+        """Extracts an image from a message.
+
+        Args:
+            message: Message to extract the image from.
+
+        Returns:
+            Image extracted from the message.
+        """
+
+        loader = ImageBytesLoader()
+
+        if isinstance(message.content, str):
+            return Image(loader.load_bytes(image_string=message.content))
+
+        if isinstance(message.content, List):
+            if len(message.content) > 1:
+                raise ValueError(
+                    "Expected message content to have only one part"
+                    f"but found {len(message.content)}."
+                )
+
+            content = message.content[0]
+
+            if isinstance(content, str):
+                return Image(loader.load_bytes(content))
+
+            if isinstance(content, Dict):
+                image_url = content.get("image_url", {}).get("url")
+
+                if image_url is not None:
+                    return Image(loader.load_bytes(image_url))
+
+                raise ValueError(f"Message content: {content} is not an image.")
+
+            raise ValueError(
+                "Expected message content part to be either a str or a "
+                f"list, but found a {content.__class__} instance"
+            )
+
+        raise ValueError(
+            "Message content must be either a str or a List, but found"
+            f"an instance of {message.content.__class__}."
+        )
+
+    @property
+    def _llm_type(self) -> str:
+        """Returns the type of LLM"""
+        return "vertexai-vision"
+
+
+class BaseVertexAIImageCaptioning(BaseImageTextModel):
+    """Base class for Image Captioning models."""
+
+    def _get_captions(self, image: Image) -> List[str]:
+        """Uses the sdk methods to generate a list of captions.
+
+        Args:
+            image: Image to get the captions for.
+
+        Returns:
+            List of captions obtained from the image.
+        """
+        model = self._create_model()
+        captions = model.get_captions(
+            image=image,
+            number_of_results=self.number_of_results,
+            language=self.language,
+        )
+        return captions
+
+
+class VertexAIImageCaptioning(BaseVertexAIImageCaptioning, BaseLLM):
+    """Implementation of the Image Captioning model as an LLM."""
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: List[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Gnerates the captions.
+
+        Args:
+            prompts: List of prompts to use. Each prompt must be a string
+                that represents an image. Currently suported are:
+                - Google Cloud Storage URI
+                - B64 encoded string
+                - Local file path
+                - Remote url
+
+        Returns:
+            Captions generated from every prompt.
+        """
+
+        generations = [self._generate_one(prompt=prompt) for prompt in prompts]
+
+        return LLMResult(generations=generations)
+
+    def _generate_one(self, prompt: str) -> List[Generation]:
+        """Generates the captions for a single prompt.
+
+        Args:
+            prompt: Image url for the generation.
+
+        Returns:
+            List of generations
+        """
+        image_loader = ImageBytesLoader(project=self.project)
+        image_bytes = image_loader.load_bytes(prompt)
+        image = Image(image_bytes=image_bytes)
+        caption_list = self._get_captions(image=image)
+        return [Generation(text=caption) for caption in caption_list]
+
+
+class VertexAIImageCaptioningChat(BaseVertexAIImageCaptioning, BaseChatModel):
+    """Implementation of the Image Captioning model as a chat."""
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: List[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Generates the results.
+
+        Args:
+            messages: List of messages. Currently only one message is supported.
+                The message must contain a string representation of the image.
+                Currently supported are:
+                - Google Cloud Storage URI
+                - B64 encoded string
+                - Local file path
+                - Remote url
+        """
+
+        if len(messages) != 1:
+            raise ValueError(
+                "Image captioning only works with one message: the image. "
+                f"instead got {len(messages)}"
+            )
+
+        message = messages[0]
+        image = self._get_image_from_message(message)
+        captions = self._get_captions(image)
+
+        generations = [
+            ChatGeneration(message=AIMessage(content=caption)) for caption in captions
+        ]
+
+        return ChatResult(generations=generations)
+
+
+class VertexAIVisualQnAChat(BaseImageTextModel, BaseChatModel):
+    """Chat implementation of a visual QnA model"""
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: List[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Generates the results.
+
+        Args:
+            messages: List of messages. The first message should contain a
+                string representation of the image.
+                Currently supported are:
+                - Google Cloud Storage URI
+                - B64 encoded string
+                - Local file path
+                - Remote url
+                There has to be at least one other message with the first question.
+        """
+
+        if len(messages) < 2:
+            raise ValueError(
+                "Image QnA must have at least two messages: First the"
+                "image and then the question and answers. Instead got "
+                f"{len(messages)} messages."
+            )
+
+        image = self._get_image_from_message(messages[0])
+
+        query = self._build_query(messages=messages[1:])
+
+        answers = self._ask_questions(image=image, query=query)
+
+        generations = [
+            ChatGeneration(message=AIMessage(content=answer)) for answer in answers
+        ]
+
+        return ChatResult(generations=generations)
+
+    def _build_query(self, messages: List[BaseMessage]) -> str:
+        """Builds the query from the messages.
+
+        Args:
+            messages: List of text messages.
+
+        Retunrs:
+            Composed query.
+        """
+
+        query = ""
+
+        for message in messages:
+            content = message.content
+
+            if isinstance(content, str):
+                content = [
+                    content,
+                ]
+
+            full_message_content = ""
+            for content_part in content:
+                if isinstance(content_part, str):
+                    full_message_content += content_part
+                else:
+                    raise ValueError("All query message content parts must be str.")
+
+            query += f"{message.type}: {full_message_content}\n"
+
+        return query
+
+    def _ask_questions(self, image: Image, query: str) -> List[str]:
+        """Interfaces with the sdk to ask the question.
+
+        Args:
+            image: Image to question about.
+            query: User query.
+
+        Returns:
+            List of responses to the query.
+        """
+        model = self._create_model()
+        answers = model.ask_question(
+            image=image, question=query, number_of_results=self.number_of_results
+        )
+        return answers
diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py
new file mode 100644
index 00000000..c3c23936
--- /dev/null
+++ b/libs/vertexai/tests/integration_tests/test_vision_models.py
@@ -0,0 +1,82 @@
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_google_vertexai.vision_models import (
+    VertexAIImageCaptioning,
+    VertexAIImageCaptioningChat,
+    VertexAIVisualQnAChat,
+)
+
+
+def test_vertex_ai_image_captioning_chat(base64_image: str):
+    model = VertexAIImageCaptioningChat()
+    response = model.invoke(
+        input=[
+            HumanMessage(content=base64_image),
+        ]
+    )
+    assert isinstance(response, AIMessage)
+
+    # Not more than one message allowed
+    with pytest.raises(ValueError):
+        model = VertexAIImageCaptioningChat()
+        response = model.invoke(
+            input=[
+                HumanMessage(content=base64_image),
+                HumanMessage(content="Follow up"),
+            ]
+        )
+
+
+def test_vertex_ai_image_captioning(base64_image: str):
+    model = VertexAIImageCaptioning()
+    response = model.invoke(base64_image)
+    assert isinstance(response, str)
+
+
+def test_vertex_ai_visual_qna_chat(base64_image: str):
+    model = VertexAIVisualQnAChat()
+
+    response = model.invoke(
+        input=[
+            HumanMessage(content=base64_image),
+            HumanMessage(content="What color is the image?"),
+        ]
+    )
+
+    assert isinstance(response, AIMessage)
+
+    response = model.invoke(
+        input=[
+            HumanMessage(content=base64_image),
+            HumanMessage(content="What color is the image?"),
+            AIMessage(content="yellow"),
+            HumanMessage(content="And the eyes?"),
+        ]
+    )
+    assert isinstance(response, AIMessage)
+
+
+@pytest.fixture
+def base64_image() -> str:
+    return (
+        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
+        "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
+        "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
+        "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
+        "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
+        "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
+        "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
+        "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
+        "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
+        "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
+        "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
+        "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
+        "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
+        "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
+        "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
+        "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
+        "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
+        "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
+        "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
+    )
diff --git a/libs/vertexai/tests/unit_tests/test_vision_models.py b/libs/vertexai/tests/unit_tests/test_vision_models.py
new file mode 100644
index 00000000..fcff11bf
--- /dev/null
+++ b/libs/vertexai/tests/unit_tests/test_vision_models.py
@@ -0,0 +1,62 @@
+import pytest
+from langchain_core.messages import HumanMessage
+from vertexai.vision_models import Image  # type: ignore[import-untyped]
+
+from langchain_google_vertexai.vision_models import BaseImageTextModel
+
+
+def test_image_from_message(base64_image: str):
+    message = HumanMessage(content=base64_image)
+    image = BaseImageTextModel._get_image_from_message(message)
+    assert isinstance(image, Image)
+
+    message = HumanMessage(
+        content=[
+            base64_image,
+        ]
+    )
+    image = BaseImageTextModel._get_image_from_message(message)
+    assert isinstance(image, Image)
+
+    message = HumanMessage(
+        content=[{"type": "image_url", "image_url": {"url": base64_image}}]
+    )
+    image = BaseImageTextModel._get_image_from_message(message)
+    assert isinstance(image, Image)
+
+    # Doesn't work with multiple message parts
+    with pytest.raises(ValueError):
+        message = HumanMessage(content=[base64_image, base64_image])
+        image = BaseImageTextModel._get_image_from_message(message)
+
+    # Doesn't work with malformed dicts
+    with pytest.raises(ValueError):
+        message = HumanMessage(
+            content=[{"bar": "image_url", "foo": {"url": base64_image}}]
+        )
+        image = BaseImageTextModel._get_image_from_message(message)
+
+
+@pytest.fixture
+def base64_image() -> str:
+    return (
+        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
+        "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
+        "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
+        "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
+        "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
+        "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
+        "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
+        "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
+        "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
+        "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
+        "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
+        "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
+        "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
+        "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
+        "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
+        "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
+        "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
+        "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
+        "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
+    )
""" From 9851f7feb486917323c7885c649a8b37bdb43f7c Mon Sep 17 00:00:00 2001 From: Jorge Date: Wed, 21 Feb 2024 22:23:53 +0100 Subject: [PATCH 3/4] Made base classes private --- .../langchain_google_vertexai/vision_models.py | 10 +++++----- libs/vertexai/tests/unit_tests/test_vision_models.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py index 6b26c315..4ff91adc 100644 --- a/libs/vertexai/langchain_google_vertexai/vision_models.py +++ b/libs/vertexai/langchain_google_vertexai/vision_models.py @@ -14,7 +14,7 @@ from langchain_google_vertexai._image_utils import ImageBytesLoader -class BaseImageTextModel(BaseModel): +class _BaseImageTextModel(BaseModel): """Base class for all integrations that use ImageTextModel""" model_name: str = Field(default="imagetext@001") @@ -82,7 +82,7 @@ def _llm_type(self) -> str: return "vertexai-vision" -class BaseVertexAIImageCaptioning(BaseImageTextModel): +class _BaseVertexAIImageCaptioning(_BaseImageTextModel): """Base class for Image Captioning models.""" def _get_captions(self, image: Image) -> List[str]: @@ -103,7 +103,7 @@ def _get_captions(self, image: Image) -> List[str]: return captions -class VertexAIImageCaptioning(BaseVertexAIImageCaptioning, BaseLLM): +class VertexAIImageCaptioning(_BaseVertexAIImageCaptioning, BaseLLM): """Implementation of the Image Captioning model as an LLM.""" def _generate( @@ -147,7 +147,7 @@ def _generate_one(self, prompt: str) -> List[Generation]: return [Generation(text=caption) for caption in caption_list] -class VertexAIImageCaptioningChat(BaseVertexAIImageCaptioning, BaseChatModel): +class VertexAIImageCaptioningChat(_BaseVertexAIImageCaptioning, BaseChatModel): """Implementation of the Image Captioning model as a chat.""" def _generate( @@ -186,7 +186,7 @@ def _generate( return ChatResult(generations=generations) -class VertexAIVisualQnAChat(BaseImageTextModel, BaseChatModel): +class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel): """Chat implementation of a visual QnA model""" def _generate( diff --git a/libs/vertexai/tests/unit_tests/test_vision_models.py b/libs/vertexai/tests/unit_tests/test_vision_models.py index fcff11bf..d5bd65ec 100644 --- a/libs/vertexai/tests/unit_tests/test_vision_models.py +++ b/libs/vertexai/tests/unit_tests/test_vision_models.py @@ -2,12 +2,12 @@ from langchain_core.messages import HumanMessage from vertexai.vision_models import Image # type: ignore[import-untyped] -from langchain_google_vertexai.vision_models import BaseImageTextModel +from langchain_google_vertexai.vision_models import _BaseImageTextModel def test_image_from_message(base64_image: str): message = HumanMessage(content=base64_image) - image = BaseImageTextModel._get_image_from_message(message) + image = _BaseImageTextModel._get_image_from_message(message) assert isinstance(image, Image) message = HumanMessage( @@ -15,26 +15,26 @@ def test_image_from_message(base64_image: str): base64_image, ] ) - image = BaseImageTextModel._get_image_from_message(message) + image = _BaseImageTextModel._get_image_from_message(message) assert isinstance(image, Image) message = HumanMessage( content=[{"type": "image_url", "image_url": {"url": base64_image}}] ) - image = BaseImageTextModel._get_image_from_message(message) + image = _BaseImageTextModel._get_image_from_message(message) assert isinstance(image, Image) # Doesn't work with multiple message parts with 
From 9851f7feb486917323c7885c649a8b37bdb43f7c Mon Sep 17 00:00:00 2001
From: Jorge
Date: Wed, 21 Feb 2024 22:23:53 +0100
Subject: [PATCH 3/4] Made base classes private

---
 .../langchain_google_vertexai/vision_models.py        | 10 +++++-----
 libs/vertexai/tests/unit_tests/test_vision_models.py  | 12 ++++++------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py
index 6b26c315..4ff91adc 100644
--- a/libs/vertexai/langchain_google_vertexai/vision_models.py
+++ b/libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -14,7 +14,7 @@
 from langchain_google_vertexai._image_utils import ImageBytesLoader
 
 
-class BaseImageTextModel(BaseModel):
+class _BaseImageTextModel(BaseModel):
     """Base class for all integrations that use ImageTextModel"""
 
     model_name: str = Field(default="imagetext@001")
@@ -82,7 +82,7 @@ def _llm_type(self) -> str:
         return "vertexai-vision"
 
 
-class BaseVertexAIImageCaptioning(BaseImageTextModel):
+class _BaseVertexAIImageCaptioning(_BaseImageTextModel):
     """Base class for Image Captioning models."""
 
     def _get_captions(self, image: Image) -> List[str]:
@@ -103,7 +103,7 @@ def _get_captions(self, image: Image) -> List[str]:
         return captions
 
 
-class VertexAIImageCaptioning(BaseVertexAIImageCaptioning, BaseLLM):
+class VertexAIImageCaptioning(_BaseVertexAIImageCaptioning, BaseLLM):
     """Implementation of the Image Captioning model as an LLM."""
 
     def _generate(
@@ -147,7 +147,7 @@ def _generate_one(self, prompt: str) -> List[Generation]:
         return [Generation(text=caption) for caption in caption_list]
 
 
-class VertexAIImageCaptioningChat(BaseVertexAIImageCaptioning, BaseChatModel):
+class VertexAIImageCaptioningChat(_BaseVertexAIImageCaptioning, BaseChatModel):
     """Implementation of the Image Captioning model as a chat."""
 
     def _generate(
@@ -186,7 +186,7 @@ def _generate(
     return ChatResult(generations=generations)
 
 
-class VertexAIVisualQnAChat(BaseImageTextModel, BaseChatModel):
+class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel):
     """Chat implementation of a visual QnA model"""
 
     def _generate(
diff --git a/libs/vertexai/tests/unit_tests/test_vision_models.py b/libs/vertexai/tests/unit_tests/test_vision_models.py
index fcff11bf..d5bd65ec 100644
--- a/libs/vertexai/tests/unit_tests/test_vision_models.py
+++ b/libs/vertexai/tests/unit_tests/test_vision_models.py
@@ -2,12 +2,12 @@
 from langchain_core.messages import HumanMessage
 from vertexai.vision_models import Image  # type: ignore[import-untyped]
 
-from langchain_google_vertexai.vision_models import BaseImageTextModel
+from langchain_google_vertexai.vision_models import _BaseImageTextModel
 
 
 def test_image_from_message(base64_image: str):
     message = HumanMessage(content=base64_image)
-    image = BaseImageTextModel._get_image_from_message(message)
+    image = _BaseImageTextModel._get_image_from_message(message)
     assert isinstance(image, Image)
 
     message = HumanMessage(
@@ -15,26 +15,26 @@ def test_image_from_message(base64_image: str):
         base64_image,
         ]
     )
-    image = BaseImageTextModel._get_image_from_message(message)
+    image = _BaseImageTextModel._get_image_from_message(message)
     assert isinstance(image, Image)
 
     message = HumanMessage(
         content=[{"type": "image_url", "image_url": {"url": base64_image}}]
     )
-    image = BaseImageTextModel._get_image_from_message(message)
+    image = _BaseImageTextModel._get_image_from_message(message)
     assert isinstance(image, Image)
 
     # Doesn't work with multiple message parts
     with pytest.raises(ValueError):
         message = HumanMessage(content=[base64_image, base64_image])
-        image = BaseImageTextModel._get_image_from_message(message)
+        image = _BaseImageTextModel._get_image_from_message(message)
 
     # Doesn't work with malformed dicts
     with pytest.raises(ValueError):
         message = HumanMessage(
             content=[{"bar": "image_url", "foo": {"url": base64_image}}]
         )
-        image = BaseImageTextModel._get_image_from_message(message)
+        image = _BaseImageTextModel._get_image_from_message(message)
 
 
 @pytest.fixture
From 86b1fdafe845909b61c3e0414db6bfabb23b49b6 Mon Sep 17 00:00:00 2001
From: Jorge
Date: Sat, 24 Feb 2024 13:08:36 +0100
Subject: [PATCH 4/4] Only allow one message with multiparts

---
 .../vision_models.py                          | 159 +++++++++---------
 .../integration_tests/test_vision_models.py   |  69 +++++++-
 .../tests/unit_tests/test_vision_models.py    |  60 ++++---
 3 files changed, 176 insertions(+), 112 deletions(-)

diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py
index 4ff91adc..f4c384e6 100644
--- a/libs/vertexai/langchain_google_vertexai/vision_models.py
+++ b/libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -30,51 +30,53 @@ def _create_model(self) -> ImageTextModel:
         """Builds the model object from the class attributes."""
         return ImageTextModel.from_pretrained(model_name=self.model_name)
 
-    @staticmethod
-    def _get_image_from_message(message: BaseMessage) -> Image:
-        """Extracts an image from a message.
+    def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
+        """Given a message part, obtain an image if the part represents one.
 
         Args:
-            message: Message to extract the image from.
+            message_part: Item of a message content.
 
         Returns:
-            Image extracted from the message.
+            Image if successful, otherwise None.
         """
 
-        loader = ImageBytesLoader()
+        if isinstance(message_part, str):
+            return None
 
-        if isinstance(message.content, str):
-            return Image(loader.load_bytes(image_string=message.content))
+        if message_part.get("type") != "image_url":
+            return None
 
-        if isinstance(message.content, List):
-            if len(message.content) > 1:
-                raise ValueError(
-                    "Expected message content to have only one part"
-                    f"but found {len(message.content)}."
-                )
+        image_str = message_part.get("image_url", {}).get("url")
 
-            content = message.content[0]
+        if not isinstance(image_str, str):
+            return None
 
-            if isinstance(content, str):
-                return Image(loader.load_bytes(content))
+        loader = ImageBytesLoader(project=self.project)
+        image_bytes = loader.load_bytes(image_str)
+        return Image(image_bytes=image_bytes)
 
-            if isinstance(content, Dict):
-                image_url = content.get("image_url", {}).get("url")
+    def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
+        """Given a message part, obtain its text if the part represents text.
 
-                if image_url is not None:
-                    return Image(loader.load_bytes(image_url))
+        Args:
+            message_part: Item of a message content.
 
-            raise ValueError(f"Message content: {content} is not an image.")
+        Returns:
+            str if successful, otherwise None.
+        """
 
-            raise ValueError(
-                "Expected message content part to be either a str or a "
-                f"list, but found a {content.__class__} instance"
-            )
+        if isinstance(message_part, str):
+            return message_part
 
-        raise ValueError(
-            "Message content must be either a str or a List, but found"
-            f"an instance of {message.content.__class__}."
-        )
+        if message_part.get("type") != "text":
+            return None
+
+        message_text = message_part.get("text")
+
+        if not isinstance(message_text, str):
+            return None
+
+        return message_text
 
     @property
     def _llm_type(self) -> str:
@@ -161,22 +163,41 @@ def _generate(
 
         Args:
             messages: List of messages. Currently only one message is supported.
-                The message must contain a string representation of the image.
-                Currently supported are:
+                The message content must be a list with only one element with
+                a dict with format:
+                {
+                    'type': 'image_url',
+                    'image_url': {
+                        'url'
+                    }
+                }
+                Currently supported image strings are:
                 - Google Cloud Storage URI
                 - B64 encoded string
                 - Local file path
                 - Remote url
         """
 
-        if len(messages) != 1:
+        image = None
+
+        is_valid = (
+            len(messages) == 1
+            and isinstance(messages[0].content, List)
+            and len(messages[0].content) == 1
+        )
+
+        if is_valid:
+            content = messages[0].content[0]
+            image = self._get_image_from_message_part(content)
+
+        if image is None:
             raise ValueError(
-                "Image captioning only works with one message: the image. "
-                f"instead got {len(messages)}"
+                f"{self.__class__.__name__} messages should be a list with "
+                "only one message. This message content must be a list with "
+                "one dictionary with the format: "
+                "{'type': 'image_url', 'image_url': {'url': }}"
             )
 
-        message = messages[0]
-        image = self._get_image_from_message(message)
         captions = self._get_captions(image)
 
         generations = [
@@ -209,18 +230,33 @@ def _generate(
                 - Remote url
                 There has to be at least one other message with the first question.
         """
 
-        if len(messages) < 2:
-            raise ValueError(
-                "Image QnA must have at least two messages: First the"
-                "image and then the question and answers. Instead got "
-                f"{len(messages)} messages."
-            )
+        image = None
+        user_question = None
 
-        image = self._get_image_from_message(messages[0])
+        is_valid = (
+            len(messages) == 1
+            and isinstance(messages[0].content, List)
+            and len(messages[0].content) == 2
+        )
+
+        if is_valid:
+            image_part = messages[0].content[0]
+            user_question_part = messages[0].content[1]
+            image = self._get_image_from_message_part(image_part)
+            user_question = self._get_text_from_message_part(user_question_part)
 
-        query = self._build_query(messages=messages[1:])
+        if (image is None) or (user_question is None):
+            raise ValueError(
+                f"{self.__class__.__name__} messages should be a list with "
+                "only one message. The message content should be a list with "
+                "two elements. The first element should be the image, a dictionary "
+                "with format "
+                "{'type': 'image_url', 'image_url': {'url': }}. "
+                "The second one should be the user question, either a simple string "
+                "or a dictionary with format {'type': 'text', 'text': }"
+            )
 
-        answers = self._ask_questions(image=image, query=query)
+        answers = self._ask_questions(image=image, query=user_question)
 
         generations = [
             ChatGeneration(message=AIMessage(content=answer)) for answer in answers
@@ -228,37 +264,6 @@
         ]
 
         return ChatResult(generations=generations)
 
-    def _build_query(self, messages: List[BaseMessage]) -> str:
-        """Builds the query from the messages.
-
-        Args:
-            messages: List of text messages.
-
-        Returns:
-            Composed query.
-        """
-
-        query = ""
-
-        for message in messages:
-            content = message.content
-
-            if isinstance(content, str):
-                content = [
-                    content,
-                ]
-
-            full_message_content = ""
-            for content_part in content:
-                if isinstance(content_part, str):
-                    full_message_content += content_part
-                else:
-                    raise ValueError("All query message content parts must be str.")
-
-            query += f"{message.type}: {full_message_content}\n"
-
-        return query
-
     def _ask_questions(self, image: Image, query: str) -> List[str]:
         """Interfaces with the sdk to ask the question.
 
         Args:
             image: Image to question about.
             query: User query.
 
         Returns:
             List of responses to the query.
         """
         model = self._create_model()
         answers = model.ask_question(
             image=image, question=query, number_of_results=self.number_of_results
         )
         return answers
diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py
index c3c23936..afda3a50 100644
--- a/libs/vertexai/tests/integration_tests/test_vision_models.py
+++ b/libs/vertexai/tests/integration_tests/test_vision_models.py
@@ -9,14 +9,27 @@
 
 
 def test_vertex_ai_image_captioning_chat(base64_image: str):
+    # This should work
     model = VertexAIImageCaptioningChat()
     response = model.invoke(
         input=[
-            HumanMessage(content=base64_image),
+            HumanMessage(
+                content=[{"type": "image_url", "image_url": {"url": base64_image}}]
+            ),
         ]
     )
+
     assert isinstance(response, AIMessage)
 
+    # Content should be an image
+    with pytest.raises(ValueError):
+        model = VertexAIImageCaptioningChat()
+        response = model.invoke(
+            input=[
+                HumanMessage(content="Text message"),
+            ]
+        )
+
     # Not more than one message allowed
     with pytest.raises(ValueError):
         model = VertexAIImageCaptioningChat()
         response = model.invoke(
             input=[
                 HumanMessage(content=base64_image),
                 HumanMessage(content="Follow up"),
             ]
         )
@@ -37,25 +50,65 @@ def test_vertex_ai_image_captioning(base64_image: str):
     model = VertexAIImageCaptioning()
     response = model.invoke(base64_image)
     assert isinstance(response, str)
 
 
 def test_vertex_ai_visual_qna_chat(base64_image: str):
     model = VertexAIVisualQnAChat()
 
+    # This should work
     response = model.invoke(
         input=[
-            HumanMessage(content=base64_image),
-            HumanMessage(content="What color is the image?"),
+            HumanMessage(
+                content=[
+                    {"type": "image_url", "image_url": {"url": base64_image}},
+                    "What color is the image?",
+                ]
+            )
         ]
     )
-
     assert isinstance(response, AIMessage)
 
     response = model.invoke(
         input=[
-            HumanMessage(content=base64_image),
-            HumanMessage(content="What color is the image?"),
-            AIMessage(content="yellow"),
-            HumanMessage(content="And the eyes?"),
+            HumanMessage(
+                content=[
+                    {"type": "image_url", "image_url": {"url": base64_image}},
+                    {"type": "text", "text": "What color is the image?"},
+                ]
+            )
         ]
     )
     assert isinstance(response, AIMessage)
 
+    # This should not work, the image must be first
+
+    with pytest.raises(ValueError):
+        response = model.invoke(
+            input=[
+                HumanMessage(
+                    content=[
+                        {"type": "text", "text": "What color is the image?"},
+                        {"type": "image_url", "image_url": {"url": base64_image}},
+                    ]
+                )
+            ]
+        )
+
+    # This should not work, only one message with multiparts allowed
+    with pytest.raises(ValueError):
+        response = model.invoke(
+            input=[
+                HumanMessage(content=base64_image),
+                HumanMessage(content="What color is the image?"),
+            ]
+        )
+
+    # This should not work, only one message with multiparts allowed
+    with pytest.raises(ValueError):
+        response = model.invoke(
+            input=[
+                HumanMessage(content=base64_image),
+                HumanMessage(content="What color is the image?"),
+                AIMessage(content="yellow"),
+                HumanMessage(content="And the eyes?"),
+            ]
+        )
 
 
 @pytest.fixture
 def base64_image() -> str:
     return (
diff --git a/libs/vertexai/tests/unit_tests/test_vision_models.py b/libs/vertexai/tests/unit_tests/test_vision_models.py
index d5bd65ec..87880a20 100644
--- a/libs/vertexai/tests/unit_tests/test_vision_models.py
+++ b/libs/vertexai/tests/unit_tests/test_vision_models.py
@@ -1,40 +1,46 @@
 import pytest
-from langchain_core.messages import HumanMessage
 from vertexai.vision_models import Image  # type: ignore[import-untyped]
 
 from langchain_google_vertexai.vision_models import _BaseImageTextModel
 
 
-def test_image_from_message(base64_image: str):
-    message = HumanMessage(content=base64_image)
-    image = _BaseImageTextModel._get_image_from_message(message)
-    assert isinstance(image, Image)
+def test_get_image_from_message_part(base64_image: str):
+    model = _BaseImageTextModel()
 
-    message = HumanMessage(
-        content=[
-            base64_image,
-        ]
-    )
-    image = _BaseImageTextModel._get_image_from_message(message)
+    # Should work with a well formatted dictionary:
+    message = {"type": "image_url", "image_url": {"url": base64_image}}
+    image = model._get_image_from_message_part(message)
     assert isinstance(image, Image)
 
-    message = HumanMessage(
-        content=[{"type": "image_url", "image_url": {"url": base64_image}}]
-    )
-    image = _BaseImageTextModel._get_image_from_message(message)
-    assert isinstance(image, Image)
+    # Should not work with a simple string
+    simple_string = base64_image
+    image = model._get_image_from_message_part(simple_string)
+    assert image is None
+
+    # Should not work with a string message
+    message = {"type": "text", "text": "I'm a text message"}
+    image = model._get_image_from_message_part(message)
+    assert image is None
+
+
+def test_get_text_from_message_part(base64_image: str):
+    DUMMY_MESSAGE = "Some message"
+    model = _BaseImageTextModel()
+
+    # Should not work with an image
+    message = {"type": "image_url", "image_url": {"url": base64_image}}
+    text = model._get_text_from_message_part(message)
+    assert text is None
+
+    # Should work with a simple string
+    simple_message = DUMMY_MESSAGE
+    text = model._get_text_from_message_part(simple_message)
+    assert text == DUMMY_MESSAGE
 
-    # Doesn't work with multiple message parts
-    with pytest.raises(ValueError):
-        message = HumanMessage(content=[base64_image, base64_image])
-        image = _BaseImageTextModel._get_image_from_message(message)
-
-    # Doesn't work with malformed dicts
-    with pytest.raises(ValueError):
-        message = HumanMessage(
-            content=[{"bar": "image_url", "foo": {"url": base64_image}}]
-        )
-        image = _BaseImageTextModel._get_image_from_message(message)
+    # Should work with a text message
+    message = {"type": "text", "text": DUMMY_MESSAGE}
+    text = model._get_text_from_message_part(message)
+    assert text == DUMMY_MESSAGE
 
 
 @pytest.fixture
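
A usage sketch of the message format as it stands after this final patch (illustrative, not part of the series; it mirrors the integration tests above). It assumes Google Cloud credentials and a default project are configured; "gs://my-bucket/image.png" is a hypothetical image reference, and any supported image string works in its place:

    from langchain_core.messages import HumanMessage

    from langchain_google_vertexai.vision_models import (
        VertexAIImageCaptioningChat,
        VertexAIVisualQnAChat,
    )

    image_str = "gs://my-bucket/image.png"  # hypothetical; base64, local path, or URL also work
    image_part = {"type": "image_url", "image_url": {"url": image_str}}

    # Captioning chat: exactly one message whose content is the single image part.
    captioning = VertexAIImageCaptioningChat()
    caption = captioning.invoke(input=[HumanMessage(content=[image_part])])

    # Visual QnA: one message with exactly two parts — the image first, then the
    # question, given as a plain string or a {"type": "text", "text": ...} dict.
    qna = VertexAIVisualQnAChat()
    answer = qna.invoke(
        input=[HumanMessage(content=[image_part, "What color is the image?"])]
    )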