From 316b17900df5c2e2acb56525c722d98ea9c8b89c Mon Sep 17 00:00:00 2001 From: James Stout Date: Sun, 1 Dec 2024 14:23:25 -0800 Subject: [PATCH] Clean up CLI. --- .gitignore | 3 ++- src/perfect_prompt/cli.py | 48 +++++++++++++++++++++++++++-------- src/perfect_prompt/flux.py | 20 ++++++++++++--- src/perfect_prompt/fluxapi.py | 6 +++-- src/perfect_prompt/refine.py | 8 +++--- 5 files changed, 65 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index fdee42f..e170f64 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,6 @@ venv .pytest_cache *.egg-info .DS_Store -prompts/ +/prompts/ +/images/ .env diff --git a/src/perfect_prompt/cli.py b/src/perfect_prompt/cli.py index 6d74c3f..1f103c4 100644 --- a/src/perfect_prompt/cli.py +++ b/src/perfect_prompt/cli.py @@ -1,3 +1,5 @@ +from pathlib import Path + import click from . import flux, fluxapi, refine @@ -5,14 +7,20 @@ @click.command() @click.version_option() -@click.argument("prompt_path", type=click.Path(exists=True)) -@click.option("--iterations", "-n", default=3, help="Number of refinement iterations") +@click.argument("prompt", required=False) @click.option( - "--comfy-output-dir", + "--prompt-path", + type=click.Path(exists=True, path_type=Path), + help="Path to the prompt file", +) +@click.option( + "-o", + "--output-dir", required=True, - type=click.Path(exists=True), - help="Directory where generated images are found", + type=click.Path(writable=True, path_type=Path), + help="Directory where final images will be saved", ) +@click.option("--iterations", "-n", default=3, help="Number of refinement iterations") @click.option( "--refine-model", default="local-pixtral", @@ -26,6 +34,11 @@ ), help="Model to use for generating images", ) +@click.option( + "--comfy-output-dir", + type=click.Path(exists=True, path_type=Path), + help="Directory where generated images are found (required if using local-flux)", +) @click.option( "--raw", is_flag=True, @@ -44,17 +57,28 @@ help="Temperature setting for the refine prompt", ) def cli( - prompt_path, + prompt: str, + prompt_path: Path, + output_dir: Path, iterations, - comfy_output_dir, refine_model, gen_model, + comfy_output_dir: Path, raw, review_temperature, refine_temperature, ): - with open(prompt_path) as file: - initial_prompt = file.read().strip() + if prompt_path and prompt: + raise click.UsageError("Cannot use both --prompt-path and --prompt options.") + if not prompt_path and not prompt: + raise click.UsageError("One of --prompt-path or --prompt must be set.") + + if gen_model == "local-flux" and not comfy_output_dir: + raise click.UsageError("--comfy-output-dir is required when using local-flux.") + + initial_prompt = prompt_path.read_text().strip() if prompt_path else prompt.strip() + + output_dir.mkdir(exist_ok=True, parents=True) current_prompt = initial_prompt previous_attempts = [] @@ -65,7 +89,11 @@ def cli( image_module = flux if gen_model == "local-flux" else fluxapi current_image_path = image_module.generate_image( - current_prompt, comfy_output_dir, model=gen_model, raw=raw + current_prompt, + output_dir, + comfy_output_dir=comfy_output_dir, + model=gen_model, + raw=raw, ) if refine_model.startswith("local"): # Free up memory for the local model to use diff --git a/src/perfect_prompt/flux.py b/src/perfect_prompt/flux.py index d672c59..b790e82 100644 --- a/src/perfect_prompt/flux.py +++ b/src/perfect_prompt/flux.py @@ -199,23 +199,35 @@ """ -def generate_image(prompt, comfy_output_dir, model=None, raw=False): +def generate_image( + prompt, + output_dir: Path, + *, + comfy_output_dir: Path, + **_, +): # Get the initial list of files - initial_files = set(Path(comfy_output_dir).glob("*.png")) + initial_files = set(comfy_output_dir.glob("*.png")) # Generate image with Flux queue_prompt(prompt) # Wait for the filesystem to update with the new image while True: - current_files = set(Path(comfy_output_dir).glob("*.png")) + current_files = set(comfy_output_dir.glob("*.png")) new_files = current_files - initial_files if new_files: latest_image = max(new_files, key=os.path.getctime) break time.sleep(5) - return latest_image + return move_image_to_output(latest_image, output_dir) + + +def move_image_to_output(image_path: Path, output_dir: Path): + output_path = output_dir / image_path.name + image_path.rename(output_path) + return output_path def queue_prompt(prompt): diff --git a/src/perfect_prompt/fluxapi.py b/src/perfect_prompt/fluxapi.py index 6b0a538..0fa6636 100644 --- a/src/perfect_prompt/fluxapi.py +++ b/src/perfect_prompt/fluxapi.py @@ -13,7 +13,9 @@ HEADERS = {"x-key": API_KEY} -def generate_image(prompt, output_dir, model, width=1216, height=832, raw=False): +def generate_image( + prompt, output_dir: Path, *, model, width=1216, height=832, raw=False, **_ +): # Submit generation request with httpx.Client() as client: payload = { @@ -52,7 +54,7 @@ def generate_image(prompt, output_dir, model, width=1216, height=832, raw=False) # Save the image timestamp = int(time.time() * 1000) - output_path = Path(output_dir) / f"{model}_{timestamp}.png" + output_path = output_dir / f"{model}_{timestamp}.png" output_path.write_bytes(image_response.content) # Embed metadata diff --git a/src/perfect_prompt/refine.py b/src/perfect_prompt/refine.py index 2494c58..7fcaed8 100644 --- a/src/perfect_prompt/refine.py +++ b/src/perfect_prompt/refine.py @@ -1,5 +1,6 @@ import base64 import textwrap +from pathlib import Path import llm from mistral_common.protocol.instruct.messages import ( @@ -15,7 +16,7 @@ def refine_prompt( original_prompt, current_prompt, - current_image_path, + current_image_path: Path, previous_attempt_pairs, refine_model, review_temperature=None, @@ -37,8 +38,9 @@ def refine_prompt( if refine_model == "local-pixtral": # Read the image file and encode it in base64 - with open(current_image_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + encoded_string = base64.b64encode(current_image_path.read_bytes()).decode( + "utf-8" + ) url = f"data:image/png;base64,{encoded_string}" review_request = ChatCompletionRequest( messages=[