Skip to content

Commit

Permalink
Clean up CLI.
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfmanstout committed Dec 1, 2024
1 parent fb55d57 commit 316b179
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 20 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ venv
.pytest_cache
*.egg-info
.DS_Store
prompts/
/prompts/
/images/
.env
48 changes: 38 additions & 10 deletions src/perfect_prompt/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
from pathlib import Path

import click

from . import flux, fluxapi, refine


@click.command()
@click.version_option()
@click.argument("prompt_path", type=click.Path(exists=True))
@click.option("--iterations", "-n", default=3, help="Number of refinement iterations")
@click.argument("prompt", required=False)
@click.option(
"--comfy-output-dir",
"--prompt-path",
type=click.Path(exists=True, path_type=Path),
help="Path to the prompt file",
)
@click.option(
"-o",
"--output-dir",
required=True,
type=click.Path(exists=True),
help="Directory where generated images are found",
type=click.Path(writable=True, path_type=Path),
help="Directory where final images will be saved",
)
@click.option("--iterations", "-n", default=3, help="Number of refinement iterations")
@click.option(
"--refine-model",
default="local-pixtral",
Expand All @@ -26,6 +34,11 @@
),
help="Model to use for generating images",
)
@click.option(
"--comfy-output-dir",
type=click.Path(exists=True, path_type=Path),
help="Directory where generated images are found (required if using local-flux)",
)
@click.option(
"--raw",
is_flag=True,
Expand All @@ -44,17 +57,28 @@
help="Temperature setting for the refine prompt",
)
def cli(
prompt_path,
prompt: str,
prompt_path: Path,
output_dir: Path,
iterations,
comfy_output_dir,
refine_model,
gen_model,
comfy_output_dir: Path,
raw,
review_temperature,
refine_temperature,
):
with open(prompt_path) as file:
initial_prompt = file.read().strip()
if prompt_path and prompt:
raise click.UsageError("Cannot use both --prompt-path and --prompt options.")
if not prompt_path and not prompt:
raise click.UsageError("One of --prompt-path or --prompt must be set.")

if gen_model == "local-flux" and not comfy_output_dir:
raise click.UsageError("--comfy-output-dir is required when using local-flux.")

initial_prompt = prompt_path.read_text().strip() if prompt_path else prompt.strip()

output_dir.mkdir(exist_ok=True, parents=True)

current_prompt = initial_prompt
previous_attempts = []
Expand All @@ -65,7 +89,11 @@ def cli(
image_module = flux if gen_model == "local-flux" else fluxapi

current_image_path = image_module.generate_image(
current_prompt, comfy_output_dir, model=gen_model, raw=raw
current_prompt,
output_dir,
comfy_output_dir=comfy_output_dir,
model=gen_model,
raw=raw,
)
if refine_model.startswith("local"):
# Free up memory for the local model to use
Expand Down
20 changes: 16 additions & 4 deletions src/perfect_prompt/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,35 @@
"""


def generate_image(prompt, comfy_output_dir, model=None, raw=False):
def generate_image(
prompt,
output_dir: Path,
*,
comfy_output_dir: Path,
**_,
):
# Get the initial list of files
initial_files = set(Path(comfy_output_dir).glob("*.png"))
initial_files = set(comfy_output_dir.glob("*.png"))

# Generate image with Flux
queue_prompt(prompt)

# Wait for the filesystem to update with the new image
while True:
current_files = set(Path(comfy_output_dir).glob("*.png"))
current_files = set(comfy_output_dir.glob("*.png"))
new_files = current_files - initial_files
if new_files:
latest_image = max(new_files, key=os.path.getctime)
break
time.sleep(5)

return latest_image
return move_image_to_output(latest_image, output_dir)


def move_image_to_output(image_path: Path, output_dir: Path):
output_path = output_dir / image_path.name
image_path.rename(output_path)
return output_path


def queue_prompt(prompt):
Expand Down
6 changes: 4 additions & 2 deletions src/perfect_prompt/fluxapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
HEADERS = {"x-key": API_KEY}


def generate_image(prompt, output_dir, model, width=1216, height=832, raw=False):
def generate_image(
prompt, output_dir: Path, *, model, width=1216, height=832, raw=False, **_
):
# Submit generation request
with httpx.Client() as client:
payload = {
Expand Down Expand Up @@ -52,7 +54,7 @@ def generate_image(prompt, output_dir, model, width=1216, height=832, raw=False)

# Save the image
timestamp = int(time.time() * 1000)
output_path = Path(output_dir) / f"{model}_{timestamp}.png"
output_path = output_dir / f"{model}_{timestamp}.png"
output_path.write_bytes(image_response.content)

# Embed metadata
Expand Down
8 changes: 5 additions & 3 deletions src/perfect_prompt/refine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import textwrap
from pathlib import Path

import llm
from mistral_common.protocol.instruct.messages import (
Expand All @@ -15,7 +16,7 @@
def refine_prompt(
original_prompt,
current_prompt,
current_image_path,
current_image_path: Path,
previous_attempt_pairs,
refine_model,
review_temperature=None,
Expand All @@ -37,8 +38,9 @@ def refine_prompt(

if refine_model == "local-pixtral":
# Read the image file and encode it in base64
with open(current_image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
encoded_string = base64.b64encode(current_image_path.read_bytes()).decode(
"utf-8"
)
url = f"data:image/png;base64,{encoded_string}"
review_request = ChatCompletionRequest(
messages=[
Expand Down

0 comments on commit 316b179

Please sign in to comment.