diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py
new file mode 100644
index 00000000..f4c384e6
--- /dev/null
+++ b/libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel, BaseLLM
+from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.outputs import ChatResult, LLMResult
+from langchain_core.outputs.chat_generation import ChatGeneration
+from langchain_core.outputs.generation import Generation
+from langchain_core.pydantic_v1 import BaseModel, Field
+from vertexai.vision_models import Image, ImageTextModel  # type: ignore[import-untyped]
+
+from langchain_google_vertexai._image_utils import ImageBytesLoader
+
+
+class _BaseImageTextModel(BaseModel):
+    """Base class for all integrations that use ImageTextModel."""
+
+    model_name: str = Field(default="imagetext@001")
+    """Name of the model to use."""
+    number_of_results: int = Field(default=1)
+    """Number of results to return from one query."""
+    language: str = Field(default="en")
+    """Language of the query."""
+    project: Optional[str] = Field(default=None)
+    """Google Cloud project."""
+
+    def _create_model(self) -> ImageTextModel:
+        """Builds the model object from the class attributes."""
+        return ImageTextModel.from_pretrained(model_name=self.model_name)
+
+    def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
+        """Given a message part, obtain an image if the part represents one.
+
+        Args:
+            message_part: Item of a message content.
+
+        Returns:
+            Image if successful, otherwise None.
+        """
+
+        if isinstance(message_part, str):
+            return None
+
+        if message_part.get("type") != "image_url":
+            return None
+
+        image_str = message_part.get("image_url", {}).get("url")
+
+        if not isinstance(image_str, str):
+            return None
+
+        loader = ImageBytesLoader(project=self.project)
+        image_bytes = loader.load_bytes(image_str)
+        return Image(image_bytes=image_bytes)
+
+    def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
+        """Given a message part, obtain its text if the part represents text.
+
+        Args:
+            message_part: Item of a message content.
+
+        Returns:
+            str if successful, otherwise None.
+        """
+
+        if isinstance(message_part, str):
+            return message_part
+
+        if message_part.get("type") != "text":
+            return None
+
+        message_text = message_part.get("text")
+
+        if not isinstance(message_text, str):
+            return None
+
+        return message_text
+
+    @property
+    def _llm_type(self) -> str:
+        """Returns the type of LLM."""
+        return "vertexai-vision"
+
+
+class _BaseVertexAIImageCaptioning(_BaseImageTextModel):
+    """Base class for Image Captioning models."""
+
+    def _get_captions(self, image: Image) -> List[str]:
+        """Uses the SDK methods to generate a list of captions.
+
+        Args:
+            image: Image to get the captions for.
+
+        Returns:
+            List of captions obtained from the image.
+ """ + model = self._create_model() + captions = model.get_captions( + image=image, + number_of_results=self.number_of_results, + language=self.language, + ) + return captions + + +class VertexAIImageCaptioning(_BaseVertexAIImageCaptioning, BaseLLM): + """Implementation of the Image Captioning model as an LLM.""" + + def _generate( + self, + prompts: List[str], + stop: List[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> LLMResult: + """Generates the captions. + + Args: + prompts: List of prompts to use. Each prompt must be a string + that represents an image. Currently supported are: + - Google Cloud Storage URI + - B64 encoded string + - Local file path + - Remote url + + Returns: + Captions generated from every prompt. + """ + + generations = [self._generate_one(prompt=prompt) for prompt in prompts] + + return LLMResult(generations=generations) + + def _generate_one(self, prompt: str) -> List[Generation]: + """Generates the captions for a single prompt. + + Args: + prompt: Image url for the generation. + + Returns: + List of generations + """ + image_loader = ImageBytesLoader(project=self.project) + image_bytes = image_loader.load_bytes(prompt) + image = Image(image_bytes=image_bytes) + caption_list = self._get_captions(image=image) + return [Generation(text=caption) for caption in caption_list] + + +class VertexAIImageCaptioningChat(_BaseVertexAIImageCaptioning, BaseChatModel): + """Implementation of the Image Captioning model as a chat.""" + + def _generate( + self, + messages: List[BaseMessage], + stop: List[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + """Generates the results. + + Args: + messages: List of messages. Currently only one message is supported. + The message content must be a list with only one element with + a dict with format: + { + 'type': 'image_url', + 'image_url': { + 'url' + } + } + Currently supported image strings are: + - Google Cloud Storage URI + - B64 encoded string + - Local file path + - Remote url + """ + + image = None + + is_valid = ( + len(messages) == 1 + and isinstance(messages[0].content, List) + and len(messages[0].content) == 1 + ) + + if is_valid: + content = messages[0].content[0] + image = self._get_image_from_message_part(content) + + if image is None: + raise ValueError( + f"{self.__class__.__name__} messages should be a list with " + "only one message. This message content must be a list with " + "one dictionary with the format: " + "{'type': 'image_url', 'image_url': {'image': }}" + ) + + captions = self._get_captions(image) + + generations = [ + ChatGeneration(message=AIMessage(content=caption)) for caption in captions + ] + + return ChatResult(generations=generations) + + +class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel): + """Chat implementation of a visual QnA model""" + + def _generate( + self, + messages: List[BaseMessage], + stop: List[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + """Generates the results. + + Args: + messages: List of messages. The first message should contain a + string representation of the image. + Currently supported are: + - Google Cloud Storage URI + - B64 encoded string + - Local file path + - Remote url + There has to be at least other message with the first question. 
+ """ + + image = None + user_question = None + + is_valid = ( + len(messages) == 1 + and isinstance(messages[0].content, List) + and len(messages[0].content) == 2 + ) + + if is_valid: + image_part = messages[0].content[0] + user_question_part = messages[0].content[1] + image = self._get_image_from_message_part(image_part) + user_question = self._get_text_from_message_part(user_question_part) + + if (image is None) or (user_question is None): + raise ValueError( + f"{self.__class__.__name__} messages should be a list with " + "only one message. The message content should be a list with " + "two elements. The first element should be the image, a dictionary " + "with format" + "{'type': 'image_url', 'image_url': {'image': }}." + "The second one should be the user question. Either a simple string" + "or a dictionary with format {'type': 'text', 'text': }" + ) + + answers = self._ask_questions(image=image, query=user_question) + + generations = [ + ChatGeneration(message=AIMessage(content=answer)) for answer in answers + ] + + return ChatResult(generations=generations) + + def _ask_questions(self, image: Image, query: str) -> List[str]: + """Interfaces with the sdk to get the question. + + Args: + image: Image to question about. + query: User query. + + Returns: + List of responses to the query. + """ + model = self._create_model() + answers = model.ask_question( + image=image, question=query, number_of_results=self.number_of_results + ) + return answers diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py new file mode 100644 index 00000000..afda3a50 --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_vision_models.py @@ -0,0 +1,135 @@ +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from langchain_google_vertexai.vision_models import ( + VertexAIImageCaptioning, + VertexAIImageCaptioningChat, + VertexAIVisualQnAChat, +) + + +def test_vertex_ai_image_captioning_chat(base64_image: str): + # This should work + model = VertexAIImageCaptioningChat() + response = model.invoke( + input=[ + HumanMessage( + content=[{"type": "image_url", "image_url": {"url": base64_image}}] + ), + ] + ) + + assert isinstance(response, AIMessage) + + # Content should be an image + with pytest.raises(ValueError): + model = VertexAIImageCaptioningChat() + response = model.invoke( + input=[ + HumanMessage(content="Text message"), + ] + ) + + # Not more than one message allowed + with pytest.raises(ValueError): + model = VertexAIImageCaptioningChat() + response = model.invoke( + input=[ + HumanMessage(content=base64_image), + HumanMessage(content="Follow up"), + ] + ) + + +def test_vertex_ai_image_captioning(base64_image: str): + model = VertexAIImageCaptioning() + response = model.invoke(base64_image) + assert isinstance(response, str) + + +def test_vertex_ai_visual_qna_chat(base64_image: str): + model = VertexAIVisualQnAChat() + + # This should work + response = model.invoke( + input=[ + HumanMessage( + content=[ + {"type": "image_url", "image_url": {"url": base64_image}}, + "What color is the image?", + ] + ) + ] + ) + assert isinstance(response, AIMessage) + + response = model.invoke( + input=[ + HumanMessage( + content=[ + {"type": "image_url", "image_url": {"url": base64_image}}, + {"type": "text", "text": "What color is the image?"}, + ] + ) + ] + ) + assert isinstance(response, AIMessage) + + # This should not work, the image must be first + + with pytest.raises(ValueError): + response = 
+            input=[
+                HumanMessage(
+                    content=[
+                        {"type": "text", "text": "What color is the image?"},
+                        {"type": "image_url", "image_url": {"url": base64_image}},
+                    ]
+                )
+            ]
+        )
+
+    # This should not work, only one multipart message is allowed
+    with pytest.raises(ValueError):
+        response = model.invoke(
+            input=[
+                HumanMessage(content=base64_image),
+                HumanMessage(content="What color is the image?"),
+            ]
+        )
+
+    # This should not work either, only one multipart message is allowed
+    with pytest.raises(ValueError):
+        response = model.invoke(
+            input=[
+                HumanMessage(content=base64_image),
+                HumanMessage(content="What color is the image?"),
+                AIMessage(content="yellow"),
+                HumanMessage(content="And the eyes?"),
+            ]
+        )
+
+
+@pytest.fixture
+def base64_image() -> str:
+    return (
+        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
+        "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
+        "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
+        "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
+        "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
+        "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
+        "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
+        "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
+        "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
+        "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
+        "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
+        "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
+        "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
+        "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
+        "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
+        "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
+        "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
+        "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
+        "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
+    )
diff --git a/libs/vertexai/tests/unit_tests/test_vision_models.py b/libs/vertexai/tests/unit_tests/test_vision_models.py
new file mode 100644
index 00000000..87880a20
--- /dev/null
+++ b/libs/vertexai/tests/unit_tests/test_vision_models.py
@@ -0,0 +1,68 @@
+import pytest
+from vertexai.vision_models import Image  # type: ignore[import-untyped]
+
+from langchain_google_vertexai.vision_models import _BaseImageTextModel
+
+
+def test_get_image_from_message_part(base64_image: str):
+    model = _BaseImageTextModel()
+
+    # Should work with a well formatted dictionary:
+    message = {"type": "image_url", "image_url": {"url": base64_image}}
+    image = model._get_image_from_message_part(message)
+    assert isinstance(image, Image)
+
+    # Should not work with a simple string
+    simple_string = base64_image
+    image = model._get_image_from_message_part(simple_string)
+    assert image is None
+
+    # Should not work with a text message
+    message = {"type": "text", "text": "I'm a text message"}
+    image = model._get_image_from_message_part(message)
+    assert image is None
+
+
+def test_get_text_from_message_part(base64_image: str):
+    DUMMY_MESSAGE = "Some message"
+    model = _BaseImageTextModel()
+
+    # Should not work with an image
+    message = {"type": "image_url", "image_url": {"url": base64_image}}
+    text = model._get_text_from_message_part(message)
+    assert text is None
+
+    # Should work with a simple string
+    simple_message = DUMMY_MESSAGE
+    text = model._get_text_from_message_part(simple_message)
+    assert text == DUMMY_MESSAGE
+
+    # Should work with a text message
+    message = {"type": "text", "text": DUMMY_MESSAGE}
+    text = model._get_text_from_message_part(message)
+    assert text == DUMMY_MESSAGE
+
+
+@pytest.fixture
+def base64_image() -> str:
+    return (
+        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
+        "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
+        "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
+        "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
+        "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
+        "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
+        "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
+        "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
+        "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
+        "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
+        "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
+        "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
+        "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
+        "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
+        "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
+        "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
+        "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
+        "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
+        "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
+    )
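For reviewers, a minimal usage sketch of the three new classes, mirroring the integration tests above. It assumes default Google Cloud credentials and project resolution; the gs:// URI is a hypothetical placeholder, and any of the supported image strings (GCS URI, base64 string, local file path, remote URL) works in its place.

from langchain_core.messages import HumanMessage

from langchain_google_vertexai.vision_models import (
    VertexAIImageCaptioning,
    VertexAIImageCaptioningChat,
    VertexAIVisualQnAChat,
)

image_str = "gs://my-bucket/my-image.png"  # hypothetical placeholder

# LLM interface: each prompt is an image string; the result is a caption.
caption = VertexAIImageCaptioning().invoke(image_str)

# Chat interface: a single message whose content is one image part.
caption_message = VertexAIImageCaptioningChat().invoke(
    [HumanMessage(content=[{"type": "image_url", "image_url": {"url": image_str}}])]
)

# Visual QnA: a single message with exactly two parts, image first,
# question second (a plain string or a {'type': 'text', ...} dict).
answer_message = VertexAIVisualQnAChat().invoke(
    [
        HumanMessage(
            content=[
                {"type": "image_url", "image_url": {"url": image_str}},
                {"type": "text", "text": "What color is the image?"},
            ]
        )
    ]
)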