Commit
VisionModels: Added ImageGeneration and ImageEdition (#25)
* Added support for imagetext models

---------

Co-authored-by: Jorge <[email protected]>
jzaldi and Jorge authored Feb 26, 2024
1 parent d0a7848 commit ba64883
Showing 4 changed files with 371 additions and 29 deletions.
81 changes: 80 additions & 1 deletion libs/vertexai/langchain_google_vertexai/_image_utils.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import base64
import os
import re
from typing import Union
from typing import Dict, Union
from urllib.parse import urlparse

import requests
@@ -168,3 +170,80 @@ def image_bytes_to_b64_string(
"""
encoded_bytes = base64.b64encode(image_bytes).decode(encoding)
return f"data:image/{image_format};base64,{encoded_bytes}"


def create_text_content_part(message_str: str) -> Dict:
"""Create a dictionary that can be part of a message content list.
Args:
message_str: Message as a string.
Returns:
Dictionary that can be part of a message content list.
"""
return {"type": "text", "text": message_str}


def create_image_content_part(image_str: str) -> Dict:
"""Create a dictionary that can be part of a message content list.
Args:
image_str: Can be either:
- b64 encoded image data
- GCS URI
- URL
- Path to an image.
Returns:
Dictionary that can be part of a message content list.
"""
return {"type": "image_url", "image_url": {"url": image_str}}


def get_image_str_from_content_part(content_part: str | Dict) -> str | None:
"""Parses an image string from a dictionary with the correct format.
Args:
content_part: String or dictionary.
Returns:
Image string if the dictionary has the correct format, otherwise None.
"""

if isinstance(content_part, str):
return None

if content_part.get("type") != "image_url":
return None

image_str = content_part.get("image_url", {}).get("url")

if isinstance(image_str, str):
return image_str
else:
return None


def get_text_str_from_content_part(content_part: str | Dict) -> str | None:
"""Parses an string from a dictionary or string with the correct format.
Args:
content_part: String or dictionary.
Returns:
String if the dictionary has the correct format or the input is a string,
otherwise None.
"""

if isinstance(content_part, str):
return content_part

if content_part.get("type") != "text":
return None

text = content_part.get("text")

if isinstance(text, str):
return text
else:
return None
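
A minimal sketch (not part of this commit) of how these content-part helpers compose and parse a multimodal content list; the GCS URI is a placeholder:

from langchain_google_vertexai._image_utils import (
    create_image_content_part,
    create_text_content_part,
    get_image_str_from_content_part,
    get_text_str_from_content_part,
)

# Build a multimodal content list out of one text part and one image part.
content = [
    create_text_content_part("Describe this picture"),
    create_image_content_part("gs://my-bucket/picture.png"),  # placeholder GCS URI
]

# Each getter only returns a value for its matching part type, otherwise None.
assert [get_text_str_from_content_part(p) for p in content] == [
    "Describe this picture",
    None,
]
assert [get_image_str_from_content_part(p) for p in content] == [
    None,
    "gs://my-bucket/picture.png",
]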
247 changes: 219 additions & 28 deletions libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, BaseLLM
@@ -9,9 +9,19 @@
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import BaseModel, Field
from vertexai.preview.vision_models import ( # type: ignore[import-untyped]
GeneratedImage,
ImageGenerationModel,
)
from vertexai.vision_models import Image, ImageTextModel # type: ignore[import-untyped]

from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai._image_utils import (
ImageBytesLoader,
create_image_content_part,
get_image_str_from_content_part,
get_text_str_from_content_part,
image_bytes_to_b64_string,
)


class _BaseImageTextModel(BaseModel):
@@ -23,7 +33,7 @@ class _BaseImageTextModel(BaseModel):
"""Number of results to return from one query"""
language: str = Field(default="en")
"""Language of the query"""
project: str = Field(default=None)
project: Union[str, None] = Field(default=None)
"""Google cloud project"""

def _create_model(self) -> ImageTextModel:
@@ -40,21 +50,15 @@ def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None
Image if successful, otherwise None.
"""

if isinstance(message_part, str):
return None

if message_part.get("type") != "image_url":
return None

image_str = message_part.get("image_url", {}).get("url")
image_str = get_image_str_from_content_part(message_part)

if not isinstance(image_str, str):
if isinstance(image_str, str):
loader = ImageBytesLoader(project=self.project)
image_bytes = loader.load_bytes(image_str)
return Image(image_bytes=image_bytes)
else:
return None

loader = ImageBytesLoader(project=self.project)
image_bytes = loader.load_bytes(image_str)
return Image(image_bytes=image_bytes)

def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
"""Given a message part obtain a text if the part represents it.
@@ -64,19 +68,7 @@ def _get_text_from_message_part(self, message_part: str | Dict) -> str | None:
Returns:
str if successful, otherwise None.
"""

if isinstance(message_part, str):
return message_part

if message_part.get("type") != "text":
return None

message_text = message_part.get("text")

if not isinstance(message_text, str):
return None

return message_text
return get_text_str_from_content_part(message_part)

@property
def _llm_type(self) -> str:
@@ -279,3 +271,202 @@ def _ask_questions(self, image: Image, query: str) -> List[str]:
image=image, question=query, number_of_results=self.number_of_results
)
return answers


class _BaseVertexAIImageGenerator(BaseModel):
"""Base class form generation and edition of images."""

model_name: str = Field(default="imagegeneration@002")
"""Name of the base model"""
negative_prompt: Union[str, None] = Field(default=None)
"""A description of what you want to omit in
the generated images"""
number_of_images: int = Field(default=1)
"""Number of images to generate"""
guidance_scale: Union[float, None] = Field(default=None)
"""Controls the strength of the prompt"""
language: Union[str, None] = Field(default=None)
"""Language of the text prompt for the image Supported values are "en" for English,
"hi" for Hindi, "ja" for Japanese, "ko" for Korean, and "auto" for automatic
language detection"""
seed: Union[int, None] = Field(default=None)
"""Random seed for the image generation"""
project: Union[str, None] = Field(default=None)
"""Google cloud project id"""

def _generate_images(self, prompt: str) -> List[str]:
"""Generates images given a prompt.
Args:
prompt: Description of what the image should look like.
Returns:
List of b64 encoded strings.
"""

model = ImageGenerationModel.from_pretrained(self.model_name)

generation_result = model.generate_images(
prompt=prompt,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
)

image_str_list = [
self._to_b64_string(image) for image in generation_result.images
]

return image_str_list

def _edit_images(self, image_str: str, prompt: str) -> List[str]:
"""Edit an image given a image and a prompt.
Args:
image_str: String representation of the image.
prompt: Description of what the image should look like.
Returns:
List of b64 encoded strings.
"""

model = ImageGenerationModel.from_pretrained(self.model_name)

image_loader = ImageBytesLoader(project=self.project)
image_bytes = image_loader.load_bytes(image_str)
image = Image(image_bytes=image_bytes)

generation_result = model.edit_image(
prompt=prompt,
base_image=image,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
)

image_str_list = [
self._to_b64_string(image) for image in generation_result.images
]

return image_str_list

def _to_b64_string(self, image: GeneratedImage) -> str:
"""Transforms a generated image into a b64 encoded string.
Args:
image: Image to convert.
Returns:
b64 encoded string of the image.
"""

# This is a hack because at the moment, GeneratedImage doesn't provide
# a way to get the bytes of the image (or anything else). There are
# only private methods, which are not reliable.

from tempfile import NamedTemporaryFile

temp_file = NamedTemporaryFile()
image.save(temp_file.name, include_generation_parameters=False)
temp_file.seek(0)
image_bytes = temp_file.read()
temp_file.close()

return image_bytes_to_b64_string(image_bytes=image_bytes)

@property
def _llm_type(self) -> str:
"""Returns the type of LLM"""
return "vertexai-vision"


class VertexAIImageGeneratorChat(_BaseVertexAIImageGenerator, BaseChatModel):
"""Generates an image from a prompt."""

def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""
Args:
messages: Must be a list with a single message whose content has a
single text part: the user prompt.
"""

# Only one message allowed with one text part.
user_query = None
is_valid = len(messages) == 1 and len(messages[0].content) == 1
if is_valid:
user_query = get_text_str_from_content_part(messages[0].content[0])
if user_query is None:
raise ValueError(
"Only one message with one text part allowed for image generation"
" Must The prompt of the image"
)

image_str_list = self._generate_images(prompt=user_query)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
]

generations = [
ChatGeneration(message=AIMessage(content=[content_part]))
for content_part in image_content_part_list
]

return ChatResult(generations=generations)


class VertexAIImageEditorChat(_BaseVertexAIImageGenerator, BaseChatModel):
"""Given an image and a prompt, edits the image.
Currently only supports mask-free editing.
"""

def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""
Args:
messages: Must be a list with a single message whose content has two parts:
- The image as a dict {
'type': 'image_url', 'image_url': {'url': <message_str>}
}
- The user prompt.
"""

# Only one message allowed with two parts: the image and the text.
user_query = None
is_valid = len(messages) == 1 and len(messages[0].content) == 2
if is_valid:
image_str = get_image_str_from_content_part(messages[0].content[0])
user_query = get_text_str_from_content_part(messages[0].content[1])
if (user_query is None) or (image_str is None):
raise ValueError(
"Only one message allowed for image edition. The message must have"
"two parts: First the image and then the user prompt."
)

image_str_list = self._edit_images(image_str=image_str, prompt=user_query)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
]

generations = [
ChatGeneration(message=AIMessage(content=[content_part]))
for content_part in image_content_part_list
]

return ChatResult(generations=generations)
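
A minimal usage sketch (not part of this commit) of the two new chat models; the project id and prompts are placeholders, and chaining the generator's output into the editor assumes the image loader accepts the returned b64 data URL:

from langchain_core.messages import HumanMessage

from langchain_google_vertexai.vision_models import (
    VertexAIImageEditorChat,
    VertexAIImageGeneratorChat,
)

# Generation: exactly one message with a single text part (the prompt).
generator = VertexAIImageGeneratorChat(project="my-gcp-project")  # placeholder project
generated = generator.invoke(
    [HumanMessage(content=["A watercolor painting of a lighthouse at dusk"])]
)
# The response content holds one image_url part per generated image.
image_url = generated.content[0]["image_url"]["url"]

# Editing: exactly one message with two parts, the image first and the prompt second.
editor = VertexAIImageEditorChat(project="my-gcp-project")
edited = editor.invoke(
    [
        HumanMessage(
            content=[
                {"type": "image_url", "image_url": {"url": image_url}},
                "Make the sky look stormy",
            ]
        )
    ]
)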