diff --git a/sdxl-turbo/entrypoint.py b/sdxl-turbo/entrypoint.py
index a2400a5..e2c43d0 100644
--- a/sdxl-turbo/entrypoint.py
+++ b/sdxl-turbo/entrypoint.py
@@ -4,7 +4,7 @@
 import os
 import sys
 import traceback
-
+
 import uvicorn
 from fastapi import FastAPI, HTTPException, Response, status
 
@@ -16,7 +16,7 @@
 
 if MODEL_NAME is None or CACHED_MODEL_PATH is None:
     logging.error("Environment variables MODEL_NAME and CACHED_MODEL_PATH must be set. See Dockerfile for values.")
-    sys.exit(1)
+    sys.exit(1)
 
 app = FastAPI()
 
@@ -40,6 +40,7 @@ async def generate_t2i(request: SdxlTurboRequest):
     try:
         image = pipe_t2i(
             prompt=request.prompt,
+            negative_prompt=request.negative_prompt,
             strength=request.strength,
             guidance_scale=request.guidance_scale,
             num_images_per_prompt=request.num_images_per_prompt,
@@ -57,6 +58,7 @@ async def generate_t2i(request: SdxlTurboRequest):
         logging.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
 
+
 # sdxl turbo image to image endpoint
 @app.post("/sdxl-turbo-i2i")
 async def generate_i2i(request: SdxlTurboRequest):
@@ -64,10 +66,11 @@ async def generate_i2i(request: SdxlTurboRequest):
         init_image = request.image
         base64_decoded = base64.b64decode(init_image)
         input_image = Image.frombytes("RGB", (512, 512), base64_decoded, "raw")
-
+
         image = pipe_i2i(
             image=input_image,
             prompt=request.prompt,
+            negative_prompt=request.negative_prompt,
             strength=request.strength,
             guidance_scale=request.guidance_scale,
             num_images_per_prompt=request.num_images_per_prompt,
@@ -87,4 +90,4 @@ async def generate_i2i(request: SdxlTurboRequest):
 
 if __name__ == "__main__":
     port = int(sys.argv[1]) if len(sys.argv) > 1 else 8080
-    uvicorn.run("entrypoint:app", host="0.0.0.0", port=port)
\ No newline at end of file
+    uvicorn.run("entrypoint:app", host="0.0.0.0", port=port)
diff --git a/sdxl-turbo/sdxl_turbo.py b/sdxl-turbo/sdxl_turbo.py
index 65c5f16..858be0c 100644
--- a/sdxl-turbo/sdxl_turbo.py
+++ b/sdxl-turbo/sdxl_turbo.py
@@ -2,16 +2,18 @@
 from pydantic import BaseModel, Field
 import torch
 from typing import Optional
-
+
 
 class SdxlTurboRequest(BaseModel):
     prompt: str
+    negative_prompt: Optional[str] = None
     num_inference_steps: int = Field(default=4)
     guidance_scale: float = Field(default=0.0)
     strength: float = Field(default=1.0)
     num_images_per_prompt: int = Field(default=1)
     image: Optional[str] = None
 
+
 def setup_pipeline(model_name: str, cached_model_path):
     pipe_t2i = AutoPipelineForText2Image.from_pretrained(
         model_name,