Skip to content

Commit

Permalink
draw tool supports local storage.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mgrsc committed Dec 22, 2024
1 parent 8d68710 commit 04a1478
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 61 deletions.
2 changes: 2 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ services:
image: bitfennec/llmq-horizon:latest
container_name: llmq-horizon
restart: always
ports:
- "40000:5000"
volumes:
- ./config.toml:/app/config.toml
- ./config-tools.toml:/app/config-tools.toml
Expand Down
108 changes: 48 additions & 60 deletions tools/draw.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,53 @@
import fal_client
from openai import OpenAI
import os
from pathlib import Path
import base64
import requests
from datetime import datetime
from .config import config
from .prompt.prompt import prompt_all

draw_config = config.get("draw", {})
root_path = Path(__file__).resolve().parents[1]
temp_server_dir = root_path / "temp_server"
temp_server_dir.mkdir(parents=True, exist_ok=True)

os.environ["FAL_KEY"] = draw_config.get("fal_key")
os.environ["OPENAI_API_KEY"] = draw_config.get("openai_api_key")
os.environ["OPENAI_BASE_URL"]= draw_config.get("openai_base_url")
model = draw_config.get("model")
draw_config = config.get("draw", {})
os.environ.update({
"FAL_KEY": draw_config.get("fal_key"),
"OPENAI_API_KEY": draw_config.get("openai_api_key"),
"OPENAI_BASE_URL": draw_config.get("openai_base_url")
})

def save_image(url: str) -> None:
"""Save image from URL or base64 data to temp directory"""
filename = f"image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
save_path = temp_server_dir / filename

try:
if url.startswith("data:image"):
image_data = base64.b64decode(url.split(",")[1])
save_path.write_bytes(image_data)
else:
response = requests.get(url)
response.raise_for_status()
save_path.write_bytes(response.content)
print(f"Image saved to {save_path}")
except Exception as e:
print(f"Error saving image: {e}")

def optimization_prompt(prompt: str) -> str:
def fal_draw(prompt: str, image_size: str = "square_hd", style: str = "any") -> str:
client = OpenAI()

system_prompt = prompt_all.get("draw")
completion = client.chat.completions.create(
model=model,
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
model=draw_config.get("model"),
messages=[
{"role": "system", "content": prompt_all.get("draw")},
{"role": "user", "content": prompt}
]
)
content = completion.choices[0].message.content.strip()
print("Prompt Optimization: [" + content +"]")
return content
optimized_prompt = completion.choices[0].message.content.strip()
print(f"Optimized prompt: [{optimized_prompt}]")

def fal_draw(prompt: str, image_size: str = "square_hd", style: str = "any"):
print("Conduct prompt optimization ..........")
optimized_prompt = optimization_prompt(prompt)
# 提交请求
result = fal_client.submit(
"fal-ai/recraft-v3",
arguments={
Expand All @@ -39,58 +58,27 @@ def fal_draw(prompt: str, image_size: str = "square_hd", style: str = "any"):
"sync_mode": True
}
)

request_id = result.request_id

# 获取结果
result = fal_client.result("fal-ai/recraft-v3", request_id)

# 只处理第一张图片

result = fal_client.result("fal-ai/recraft-v3", result.request_id)
if result and result.get('images'):
images = result['images']
if images: # 确保 images 列表不为空
image = images[0]

url = image.get('url')
if url:
print(f"Image URL: {url}")
if url.startswith("data:image"):
try:
base64_data = url.split(",")[1]
image_data = base64.b64decode(base64_data)
image_file = "image.jpeg"
with open(image_file, "wb") as f:
f.write(image_data)
print(f"Image saved as {image_file}")
return image_file # 返回本地文件路径
except Exception as e:
print(f"Error decoding or saving Base64 image: {e}")
return None # 如果解码失败,返回 None
else:
print("URL is not a Base64 data URI.")
return "图片: " + url # 返回URL
else:
print("Image URL not found in the response")
return None # 没有 URL,返回 None
else:
print("No images found in the response")
return None # images 为空,返回 None
else:
print("Invalid result or no images in the response")
return None # result 或 images 不存在,返回 None

url = result['images'][0].get('url', '')
if url:
save_image(url)
if url.startswith("data:image"):
return "图片: " + url
return "图片: " + url
return None

from langchain_core.tools import tool

@tool
def draw(prompt: str, image_size: str = "square_hd", style: str = "any"):
"""根据prompt要求进行绘画然后返回链接
Args:
prompt: 要画的内容
image_size: 图片尺寸,可选值为 "square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9". 默认为 "square_hd"。
style: 图片风格,可选值为 "any", "realistic_image", "digital_illustration", "vector_illustration". 默认为 "any"。
image_size: 图片尺寸,可选值为 "square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
style: 图片风格,可选值为 "any", "realistic_image", "digital_illustration", "vector_illustration"
"""

return fal_draw(prompt, image_size, style)


tools = [draw]
1 change: 0 additions & 1 deletion tools/get_time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import pytz
from langchain_core.tools import tool
from typing import Optional

@tool
def get_time(timezone: str, format: str = "%Y-%m-%d %H:%M:%S"):
Expand Down

0 comments on commit 04a1478

Please sign in to comment.