Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: add Integration for OpenAI image gen with v1 sdk #17771

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Tool to generate an image using DALLE OpenAI V1 SDK."""

from langchain_community.tools.openai_dalle_image_generation.tool import (
OpenAIDALLEImageGenerationTool,
)

__all__ = ["OpenAIDALLEImageGenerationTool"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Tool for the OpenAI DALLE V1 Image Generation SDK."""

from typing import Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool

from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper


class OpenAIDALLEImageGenerationTool(BaseTool):
"""Tool that generates an image using OpenAI DALLE."""

name: str = "openai_dalle"
description: str = (
"A wrapper around OpenAI DALLE Image Generation. "
"Useful for when you need to generate an image of"
"people, places, paintings, animals, or other subjects. "
"Input should be a text prompt to generate an image."
)
api_wrapper: DallEAPIWrapper

def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the OpenAI DALLE Image Generation tool."""
return self.api_wrapper.run(query)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest.mock import MagicMock

from langchain_community.tools.openai_dalle_image_generation import (
OpenAIDALLEImageGenerationTool,
)


def test_generate_image() -> None:
"""Test OpenAI DALLE Image Generation."""
mock_api_resource = MagicMock()
# bypass pydantic validation as openai is not a package dependency
tool = OpenAIDALLEImageGenerationTool.construct(api_wrapper=mock_api_resource)
tool_input = {"query": "parrot on a branch"}
result = tool.run(tool_input)
assert result.startswith("https://")
Loading