VisionModels: Added ImageGeneration and ImageEdition #25

Merged
merged 11 commits into from
Feb 26, 2024
81 changes: 80 additions & 1 deletion libs/vertexai/langchain_google_vertexai/_image_utils.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import base64
import os
import re
from typing import Union
from typing import Dict, Union
from urllib.parse import urlparse

import requests
@@ -168,3 +170,80 @@ def image_bytes_to_b64_string(
"""
encoded_bytes = base64.b64encode(image_bytes).decode(encoding)
return f"data:image/{image_format};base64,{encoded_bytes}"


def create_text_content_part(message_str: str) -> Dict:
"""Create a dictionary that can be part of a message content list.

Args:
message_str: Message as a string.

Returns:
Dictionary that can be part of a message content list.
"""
return {"type": "text", "text": message_str}


def create_image_content_part(image_str: str) -> Dict:
"""Create a dictionary that can be part of a message content list.

Args:
image_str: Can be either:
- b64-encoded image data
- a GCS URI
- a URL
- a path to an image file.

Returns:
Dictionary that can be part of a message content list.
"""
return {"type": "image_url", "image_url": {"url": image_str}}


def get_image_str_from_content_part(content_part: str | Dict) -> str | None:
"""Parses an image string from a dictionary with the correct format.

Args:
content_part: String or dictionary.

Returns:
Image string if the dictionary has the correct format, otherwise None.
"""

if isinstance(content_part, str):
return None

if content_part.get("type") != "image_url":
return None

image_str = content_part.get("image_url", {}).get("url")

if isinstance(image_str, str):
return image_str
else:
return None


def get_text_str_from_content_part(content_part: str | Dict) -> str | None:
"""Parses an string from a dictionary or string with the correct format.

Args:
content_part: String or dictionary.

Returns:
String if the dictionary has the correct format or the input is a string,
otherwise None.
"""

if isinstance(content_part, str):
return content_part

if content_part.get("type") != "text":
return None

text = content_part.get("text")

if isinstance(text, str):
return text
else:
return None
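As a usage sketch (not part of the diff), the four helpers can be combined as follows; the GCS URI is a hypothetical placeholder:

```python
from langchain_google_vertexai._image_utils import (
    create_image_content_part,
    create_text_content_part,
    get_image_str_from_content_part,
    get_text_str_from_content_part,
)

# Build the two kinds of content parts used in multimodal messages.
text_part = create_text_content_part("Describe this picture")
image_part = create_image_content_part("gs://my-bucket/picture.png")  # hypothetical URI

# Round-trip: extract the raw strings back out of the parts.
assert get_text_str_from_content_part(text_part) == "Describe this picture"
assert get_image_str_from_content_part(image_part) == "gs://my-bucket/picture.png"

# Parts of the wrong type yield None instead of raising.
assert get_image_str_from_content_part(text_part) is None
assert get_text_str_from_content_part(image_part) is None
```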
247 changes: 219 additions & 28 deletions libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, BaseLLM
@@ -9,9 +9,19 @@
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import BaseModel, Field
from vertexai.preview.vision_models import ( # type: ignore[import-untyped]
GeneratedImage,
ImageGenerationModel,
)
from vertexai.vision_models import Image, ImageTextModel # type: ignore[import-untyped]

from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai._image_utils import (
ImageBytesLoader,
create_image_content_part,
get_image_str_from_content_part,
get_text_str_from_content_part,
image_bytes_to_b64_string,
)


class _BaseImageTextModel(BaseModel):
@@ -23,7 +33,7 @@ class _BaseImageTextModel(BaseModel):
"""Number of results to return from one query"""
language: str = Field(default="en")
"""Language of the query"""
project: str = Field(default=None)
project: Union[str, None] = Field(default=None)
"""Google cloud project"""

def _create_model(self) -> ImageTextModel:
@@ -40,21 +50,15 @@ def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None
Image if successful, otherwise None.
"""

if isinstance(message_part, str):
return None

if message_part.get("type") != "image_url":
return None

image_str = message_part.get("image_url", {}).get("url")
image_str = get_image_str_from_content_part(message_part)

if not isinstance(image_str, str):
if isinstance(image_str, str):
loader = ImageBytesLoader(project=self.project)
image_bytes = loader.load_bytes(image_str)
return Image(image_bytes=image_bytes)
else:
return None

loader = ImageBytesLoader(project=self.project)
image_bytes = loader.load_bytes(image_str)
return Image(image_bytes=image_bytes)

def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
"""Given a message part obtain a text if the part represents it.

@@ -64,19 +68,7 @@ def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
Returns:
str if successful, otherwise None.
"""

if isinstance(message_part, str):
return message_part

if message_part.get("type") != "text":
return None

message_text = message_part.get("text")

if not isinstance(message_text, str):
return None

return message_text
return get_text_str_from_content_part(message_part)

@property
def _llm_type(self) -> str:
@@ -279,3 +271,202 @@ def _ask_questions(self, image: Image, query: str) -> List[str]:
image=image, question=query, number_of_results=self.number_of_results
)
return answers


class _BaseVertexAIImageGenerator(BaseModel):
"""Base class form generation and edition of images."""

model_name: str = Field(default="imagegeneration@002")
"""Name of the base model"""
negative_prompt: Union[str, None] = Field(default=None)
"""A description of what you want to omit in
the generated images"""
number_of_images: int = Field(default=1)
"""Number of images to generate"""
guidance_scale: Union[float, None] = Field(default=None)
"""Controls the strength of the prompt"""
language: Union[str, None] = Field(default=None)
"""Language of the text prompt for the image Supported values are "en" for English,
"hi" for Hindi, "ja" for Japanese, "ko" for Korean, and "auto" for automatic
language detection"""
seed: Union[int, None] = Field(default=None)
"""Random seed for the image generation"""
project: Union[str, None] = Field(default=None)
"""Google cloud project id"""

def _generate_images(self, prompt: str) -> List[str]:
"""Generates images given a prompt.

Args:
prompt: Description of what the image should look like.

Returns:
List of b64 encoded strings.
"""

model = ImageGenerationModel.from_pretrained(self.model_name)

generation_result = model.generate_images(
prompt=prompt,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
)

image_str_list = [
self._to_b64_string(image) for image in generation_result.images
]

return image_str_list

def _edit_images(self, image_str: str, prompt: str) -> List[str]:
"""Edit an image given a image and a prompt.

Args:
image_str: String representation of the image.
prompt: Description of what the image should look like.

Returns:
List of b64 encoded strings.
"""

model = ImageGenerationModel.from_pretrained(self.model_name)

image_loader = ImageBytesLoader(project=self.project)
image_bytes = image_loader.load_bytes(image_str)
image = Image(image_bytes=image_bytes)

generation_result = model.edit_image(
prompt=prompt,
base_image=image,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
)

image_str_list = [
self._to_b64_string(image) for image in generation_result.images
]

return image_str_list

def _to_b64_string(self, image: GeneratedImage) -> str:
"""Transforms a generated image into a b64 encoded string.

Args:
image: Image to convert.

Returns:
b64 encoded string of the image.
"""

# This is a hack because at the moment, GeneratedImage doesn't provide
# a way to get the bytes of the image (or anything else). There are
# only private methods that are not reliable.

from tempfile import NamedTemporaryFile

temp_file = NamedTemporaryFile()
image.save(temp_file.name, include_generation_parameters=False)
temp_file.seek(0)
image_bytes = temp_file.read()
temp_file.close()

return image_bytes_to_b64_string(image_bytes=image_bytes)

@property
def _llm_type(self) -> str:
"""Returns the type of LLM"""
return "vertexai-vision"


class VertexAIImageGeneratorChat(_BaseVertexAIImageGenerator, BaseChatModel):
"""Generates an image from a prompt."""

def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""
Args:
messages: Must be a list containing a single message whose content has
exactly one part: the user prompt.
"""

# Only one message allowed with one text part.
user_query = None
is_valid = len(messages) == 1 and len(messages[0].content) == 1
if is_valid:
user_query = get_text_str_from_content_part(messages[0].content[0])
if user_query is None:
raise ValueError(
"Only one message with one text part allowed for image generation"
" Must The prompt of the image"
)

image_str_list = self._generate_images(prompt=user_query)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
]

generations = [
ChatGeneration(message=AIMessage(content=[content_part]))
for content_part in image_content_part_list
]

return ChatResult(generations=generations)
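A minimal usage sketch for the new generator (not part of the diff), assuming application-default credentials and a configured Google Cloud project:

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai.vision_models import VertexAIImageGeneratorChat

# Assumes application-default credentials and a default project.
generator = VertexAIImageGeneratorChat()

# Exactly one message whose content has exactly one text part: the prompt.
message = HumanMessage(content=["A watercolor painting of a lighthouse at dusk"])
response = generator.invoke([message])

# The reply content is a list of image content parts, each holding a
# base64 data URL, e.g.
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
data_url = response.content[0]["image_url"]["url"]
```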


class VertexAIImageEditorChat(_BaseVertexAIImageGenerator, BaseChatModel):
"""Given an image and a prompt, edits the image.
Currently only supports mask-free editing.
"""

def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""
Args:
messages: Must be a list containing a single message whose content has two parts:
- The image as a dict {
'type': 'image_url', 'image_url': {'url': <message_str>}
}
- The user prompt.
"""

# Only one message allowed with two parts: the image and the text.
user_query = None
is_valid = len(messages) == 1 and len(messages[0].content) == 2
if is_valid:
image_str = get_image_str_from_content_part(messages[0].content[0])
user_query = get_text_str_from_content_part(messages[0].content[1])
if (user_query is None) or (image_str is None):
raise ValueError(
"Only one message allowed for image edition. The message must have"
"two parts: First the image and then the user prompt."
)

image_str_list = self._edit_images(image_str=image_str, prompt=user_query)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
]

generations = [
ChatGeneration(message=AIMessage(content=[content_part]))
for content_part in image_content_part_list
]

return ChatResult(generations=generations)
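And a corresponding sketch for the editor (again not part of the diff); the source image URI is a hypothetical placeholder:

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai.vision_models import VertexAIImageEditorChat

editor = VertexAIImageEditorChat()  # assumes application-default credentials

# Exactly one message with two parts: the image first, then the edit prompt.
message = HumanMessage(
    content=[
        {"type": "image_url", "image_url": {"url": "gs://my-bucket/lighthouse.png"}},  # hypothetical
        "Make the sky look stormy",
    ]
)
response = editor.invoke([message])

# Each edited image is returned as an image content part with a base64 data URL.
edited_data_url = response.content[0]["image_url"]["url"]
```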