def create_text_content_part(message_str: str) -> Dict:
    """Create a text dictionary that can be part of a message content list.

    Args:
        message_str: Message as a string.

    Returns:
        Dictionary of the form ``{"type": "text", "text": message_str}``.
    """
    return {"type": "text", "text": message_str}


def create_image_content_part(image_str: str) -> Dict:
    """Create an image dictionary that can be part of a message content list.

    Args:
        image_str: Can be either:
            - b64 encoded image data
            - GCS uri
            - Url
            - Path to an image.

    Returns:
        Dictionary of the form
        ``{"type": "image_url", "image_url": {"url": image_str}}``.
    """
    return {"type": "image_url", "image_url": {"url": image_str}}


def get_image_str_from_content_part(content_part: str | Dict) -> str | None:
    """Parse an image string from a content part with the correct format.

    The accepted format is the one built by ``create_image_content_part``:
    ``{"type": "image_url", "image_url": {"url": <str>}}``.

    Args:
        content_part: String or dictionary.

    Returns:
        Image string if the dictionary has the correct format, otherwise None.
        Plain strings and any malformed part (wrong ``type``, ``image_url``
        that is not a dictionary, ``url`` that is not a string) yield None
        instead of raising.
    """
    # Strings are text parts, never images; anything else that is not a
    # dictionary cannot carry the expected keys either.
    if not isinstance(content_part, dict):
        return None

    if content_part.get("type") != "image_url":
        return None

    # Guard the nested lookup: "image_url" may be missing, None or a bare
    # string, in which case calling ``.get`` on it would raise.
    image_url = content_part.get("image_url")
    if not isinstance(image_url, dict):
        return None

    image_str = image_url.get("url")
    return image_str if isinstance(image_str, str) else None


def get_text_str_from_content_part(content_part: str | Dict) -> str | None:
    """Parse a string from a dictionary or string with the correct format.

    Args:
        content_part: String or dictionary.

    Returns:
        The input itself if it is a string, the ``text`` value if the input
        is a dictionary of the form ``{"type": "text", "text": <str>}``,
        otherwise None.
    """
    if isinstance(content_part, str):
        return content_part

    # Anything that is neither a string nor a dictionary has no text to offer.
    if not isinstance(content_part, dict):
        return None

    if content_part.get("type") != "text":
        return None

    text = content_part.get("text")
    return text if isinstance(text, str) else None
image_bytes_to_b64_string, +) class _BaseImageTextModel(BaseModel): @@ -23,7 +33,7 @@ class _BaseImageTextModel(BaseModel): """Number of results to return from one query""" language: str = Field(default="en") """Language of the query""" - project: str = Field(default=None) + project: Union[str, None] = Field(default=None) """Google cloud project""" def _create_model(self) -> ImageTextModel: @@ -40,21 +50,15 @@ def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None Image is successful otherwise None. """ - if isinstance(message_part, str): - return None - - if message_part.get("type") != "image_url": - return None - - image_str = message_part.get("image_url", {}).get("url") + image_str = get_image_str_from_content_part(message_part) - if not isinstance(image_str, str): + if isinstance(image_str, str): + loader = ImageBytesLoader(project=self.project) + image_bytes = loader.load_bytes(image_str) + return Image(image_bytes=image_bytes) + else: return None - loader = ImageBytesLoader(project=self.project) - image_bytes = loader.load_bytes(image_str) - return Image(image_bytes=image_bytes) - def _get_text_from_message_part(self, message_part: str | Dict) -> str | None: """Given a message part obtain a text if the part represents it. @@ -64,19 +68,7 @@ def _get_text_from_message_part(self, message_part: str | Dict) -> str | None: Returns: str is successful otherwise None. 
""" - - if isinstance(message_part, str): - return message_part - - if message_part.get("type") != "text": - return None - - message_text = message_part.get("text") - - if not isinstance(message_text, str): - return None - - return message_text + return get_text_str_from_content_part(message_part) @property def _llm_type(self) -> str: @@ -279,3 +271,202 @@ def _ask_questions(self, image: Image, query: str) -> List[str]: image=image, question=query, number_of_results=self.number_of_results ) return answers + + +class _BaseVertexAIImageGenerator(BaseModel): + """Base class form generation and edition of images.""" + + model_name: str = Field(default="imagegeneration@002") + """Name of the base model""" + negative_prompt: Union[str, None] = Field(default=None) + """A description of what you want to omit in + the generated images""" + number_of_images: int = Field(default=1) + """Number of images to generate""" + guidance_scale: Union[float, None] = Field(default=None) + """Controls the strength of the prompt""" + language: Union[str, None] = Field(default=None) + """Language of the text prompt for the image Supported values are "en" for English, + "hi" for Hindi, "ja" for Japanese, "ko" for Korean, and "auto" for automatic + language detection""" + seed: Union[int, None] = Field(default=None) + """Random seed for the image generation""" + project: Union[str, None] = Field(default=None) + """Google cloud project id""" + + def _generate_images(self, prompt: str) -> List[str]: + """Generates images given a prompt. + + Args: + prompt: Description of what the image should look like. + + Returns: + List of b64 encoded strings. 
+ """ + + model = ImageGenerationModel.from_pretrained(self.model_name) + + generation_result = model.generate_images( + prompt=prompt, + negative_prompt=self.negative_prompt, + number_of_images=self.number_of_images, + language=self.language, + guidance_scale=self.guidance_scale, + seed=self.seed, + ) + + image_str_list = [ + self._to_b64_string(image) for image in generation_result.images + ] + + return image_str_list + + def _edit_images(self, image_str: str, prompt: str) -> List[str]: + """Edit an image given a image and a prompt. + + Args: + image_str: String representation of the image. + prompt: Description of what the image should look like. + + Returns: + List of b64 encoded strings. + """ + + model = ImageGenerationModel.from_pretrained(self.model_name) + + image_loader = ImageBytesLoader(project=self.project) + image_bytes = image_loader.load_bytes(image_str) + image = Image(image_bytes=image_bytes) + + generation_result = model.edit_image( + prompt=prompt, + base_image=image, + negative_prompt=self.negative_prompt, + number_of_images=self.number_of_images, + language=self.language, + guidance_scale=self.guidance_scale, + seed=self.seed, + ) + + image_str_list = [ + self._to_b64_string(image) for image in generation_result.images + ] + + return image_str_list + + def _to_b64_string(self, image: GeneratedImage) -> str: + """Transforms a generated image into a b64 encoded string. + + Args: + image: Image to convert. + + Returns: + b64 encoded string of the image. + """ + + # This is a hack because at the moment, GeneratedImage doesn't provide + # a way to get the bytes of the image (or anything else). There is + # only private methods that are not reliable. 
+ + from tempfile import NamedTemporaryFile + + temp_file = NamedTemporaryFile() + image.save(temp_file.name, include_generation_parameters=False) + temp_file.seek(0) + image_bytes = temp_file.read() + temp_file.close() + + return image_bytes_to_b64_string(image_bytes=image_bytes) + + @property + def _llm_type(self) -> str: + """Returns the type of LLM""" + return "vertexai-vision" + + +class VertexAIImageGeneratorChat(_BaseVertexAIImageGenerator, BaseChatModel): + """Generates an image from a prompt.""" + + def _generate( + self, + messages: List[BaseMessage], + stop: List[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + """ + Args: + messages: The message must be a list of only one element with one part: + The user prompt. + """ + + # Only one message allowed with one text part. + user_query = None + is_valid = len(messages) == 1 and len(messages[0].content) == 1 + if is_valid: + user_query = get_text_str_from_content_part(messages[0].content[0]) + if user_query is None: + raise ValueError( + "Only one message with one text part allowed for image generation" + " Must The prompt of the image" + ) + + image_str_list = self._generate_images(prompt=user_query) + image_content_part_list = [ + create_image_content_part(image_str=image_str) + for image_str in image_str_list + ] + + generations = [ + ChatGeneration(message=AIMessage(content=[content_part])) + for content_part in image_content_part_list + ] + + return ChatResult(generations=generations) + + +class VertexAIImageEditorChat(_BaseVertexAIImageGenerator, BaseChatModel): + """Given an image and a prompt, edits the image. + Currently only supports mask free editing. 
+ """ + + def _generate( + self, + messages: List[BaseMessage], + stop: List[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + """ + Args: + messages: The message must be a list of only one element with two part: + - The image as a dict { + 'type': 'image_url', 'image_url': {'url': } + } + - The user prompt. + """ + + # Only one message allowed with two parts: the image and the text. + user_query = None + is_valid = len(messages) == 1 and len(messages[0].content) == 2 + if is_valid: + image_str = get_image_str_from_content_part(messages[0].content[0]) + user_query = get_text_str_from_content_part(messages[0].content[1]) + if (user_query is None) or (image_str is None): + raise ValueError( + "Only one message allowed for image edition. The message must have" + "two parts: First the image and then the user prompt." + ) + + image_str_list = self._edit_images(image_str=image_str, prompt=user_query) + image_content_part_list = [ + create_image_content_part(image_str=image_str) + for image_str in image_str_list + ] + + generations = [ + ChatGeneration(message=AIMessage(content=[content_part])) + for content_part in image_content_part_list + ] + + return ChatResult(generations=generations) diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py index afda3a50..a3edaa04 100644 --- a/libs/vertexai/tests/integration_tests/test_vision_models.py +++ b/libs/vertexai/tests/integration_tests/test_vision_models.py @@ -4,6 +4,8 @@ from langchain_google_vertexai.vision_models import ( VertexAIImageCaptioning, VertexAIImageCaptioningChat, + VertexAIImageEditorChat, + VertexAIImageGeneratorChat, VertexAIVisualQnAChat, ) @@ -110,6 +112,23 @@ def test_vertex_ai_visual_qna_chat(base64_image: str): ) +def test_vertex_ai_image_generation_and_edition(): + generator = VertexAIImageGeneratorChat() + + messages = [HumanMessage(content=["Generate a dog 
def test_vertex_ai_image_generation_and_edition():
    """Generate an image from a prompt, then feed it to the editor."""
    generation_request = [
        HumanMessage(content=["Generate a dog reading the newspaper"])
    ]
    generation_response = VertexAIImageGeneratorChat().invoke(generation_request)
    assert isinstance(generation_response, AIMessage)

    # The first content part of the answer is the generated image.
    generated_image = generation_response.content[0]

    edition_request = [
        HumanMessage(content=[generated_image, "Change the dog for a cat"])
    ]
    edition_response = VertexAIImageEditorChat().invoke(edition_request)
    assert isinstance(edition_response, AIMessage)


def test_get_text_str_from_content_part():
    """Plain strings and text parts yield their text; other parts yield None."""
    plain_text = "This is a text"
    assert get_text_str_from_content_part(plain_text) == plain_text

    text_part = {"type": "text", "text": "This is a text"}
    assert get_text_str_from_content_part(text_part) == text_part["text"]

    rejected_parts = (
        {"type": "image", "text": "This is a text"},
        {"foo": "image", "bar": "This is a text"},
    )
    for rejected_part in rejected_parts:
        assert get_text_str_from_content_part(rejected_part) is None


def test_get_image_str_from_content_part():
    """Only well-formed image parts yield their url; everything else is None."""
    assert get_image_str_from_content_part("This is a text") is None

    image_part = {"type": "image_url", "image_url": {"url": "img_url"}}
    parsed_url = get_image_str_from_content_part(image_part)
    assert isinstance(image_part["image_url"], dict)
    assert parsed_url == image_part["image_url"]["url"]

    rejected_parts = (
        {"type": "image", "text": "This is a text"},
        {"foo": "image", "bar": "This is a text"},
    )
    for rejected_part in rejected_parts:
        assert get_image_str_from_content_part(rejected_part) is None


def test_create_content_parts():
    """Parts built by the create helpers round-trip through the getters."""
    message_str = "This is a message str"
    text_part = create_text_content_part(message_str)
    assert get_text_str_from_content_part(text_part) == message_str

    image_str = "This is a image str"
    image_part = create_image_content_part(image_str)
    assert get_image_str_from_content_part(image_part) == image_str