diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index d17606b4b..c668e6a8b 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -428,16 +428,15 @@ ConfigOptions CreateConfigOptions(std::optional &env_prefix, } // namespace NB_MODULE(lib, m) { -// Tragically, debug builds of Python do the right thing and don't immortalize -// many identifiers and such. This makes the last chance leak checking that -// nanobind does somewhat unreliable since the reports it prints may be -// to identifiers that are no longer live (at a time in process shutdown -// where it is expected that everything left just gets dropped on the floor). -// This causes segfaults or ASAN violations in the leak checker on exit in -// certain scenarios where we have spurious "leaks" of global objects. -#if defined(Py_DEBUG) + // Tragically, debug builds of Python do the right thing and don't immortalize + // many identifiers and such. This makes the last chance leak checking that + // nanobind does somewhat unreliable since the reports it prints may be + // to identifiers that are no longer live (at a time in process shutdown + // where it is expected that everything left just gets dropped on the floor). + // This causes segfaults or ASAN violations in the leak checker on exit in + // certain scenarios where we have spurious "leaks" of global objects. + py::set_leak_warnings(false); -#endif logging::InitializeFromEnv(); diff --git a/shortfin/python/shortfin_apps/sd/_deps.py b/shortfin/python/shortfin_apps/sd/_deps.py index 9965065ce..92bd089ec 100644 --- a/shortfin/python/shortfin_apps/sd/_deps.py +++ b/shortfin/python/shortfin_apps/sd/_deps.py @@ -9,7 +9,7 @@ try: import transformers except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "diffusers") from e + raise ShortfinDepNotFoundError(__name__, "transformers") from e try: import tokenizers diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index 2e4f5b688..98678c46d 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -165,6 +165,7 @@ def needs_update(ctx): def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + needed = True if os.path.exists(out_file): if url: needed = not is_valid_size(out_file, url) @@ -176,16 +177,13 @@ def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): def needs_compile(filename, target, ctx): - device = "amdgpu" if "gfx" in target else "llvmcpu" - vmfb_name = f"{filename}_{device}-{target}.vmfb" + vmfb_name = f"{filename}_{target}.vmfb" namespace = FileNamespace.BIN return needs_file(vmfb_name, ctx, namespace=namespace) def get_cached_vmfb(filename, target, ctx): - device = "amdgpu" if "gfx" in target else "llvmcpu" - vmfb_name = f"{filename}_{device}-{target}.vmfb" - namespace = FileNamespace.BIN + vmfb_name = f"{filename}_{target}.vmfb" return ctx.file(vmfb_name) diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py index cfc3192cc..432f08b4e 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py +++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py @@ -6,10 +6,7 @@ from iree.build import * from iree.build.executor import FileNamespace -import itertools import os -import shortfin.array as sfnp -import copy ARTIFACT_VERSION = "11182024" SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/" @@ -72,21 +69,18 @@ def sdxlconfig( model_config_filenames = [f"{model}_config_i8.json"] model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) for f, url in model_config_urls.items(): - out_file = os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) topology_config_filenames = [f"topology_config_{topology}.txt"] topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET) for f, url in topology_config_urls.items(): - out_file = os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) flagfile_filenames = [f"{model}_flagfile_{target}.txt"] flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) for f, url in flagfile_urls.items(): - out_file = os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) @@ -95,7 +89,6 @@ def sdxlconfig( ) tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) for f, url in tuning_urls.items(): - out_file = os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) filenames = [ diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py index 478d03ad8..2b954c18b 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -15,7 +15,6 @@ from dataclasses import dataclass from pathlib import Path -import dataclasses_json from dataclasses_json import dataclass_json, Undefined import shortfin.array as sfnp diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index 1afa73d5e..62ac5e855 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -5,18 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio -import io import logging import json import shortfin as sf -import shortfin.array as sfnp # TODO: Have a generic "Responder" interface vs just the concrete impl. from shortfin.interop.fastapi import FastAPIResponder from .io_struct import GenerateReqInput -from .messages import InferenceExecRequest, InferencePhase +from .messages import InferenceExecRequest from .service import GenerateService from .metrics import measure @@ -83,7 +81,6 @@ def __init__( self.batcher = service.batcher self.complete_infeed = self.system.create_queue() - @measure(type="throughput", num_items="num_output_images", freq=1, label="samples") async def run(self): logger.debug("Started ClientBatchGenerateProcess: %r", self) try: diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py index d1d9cf41a..73e77316f 100644 --- a/shortfin/python/shortfin_apps/sd/components/io_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py @@ -4,12 +4,10 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union from dataclasses import dataclass import uuid -import shortfin.array as sfnp - @dataclass class GenerateReqInput: diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py index ea29b69a4..e416592d0 100644 --- a/shortfin/python/shortfin_apps/sd/components/manager.py +++ b/shortfin/python/shortfin_apps/sd/components/manager.py @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True): sb.visible_devices = sb.available_devices sb.visible_devices = get_selected_devices(sb, device_ids) self.ls = sb.create_system() - logging.info(f"Created local system with {self.ls.device_names} devices") + logger.info(f"Created local system with {self.ls.device_names} devices") # TODO: Come up with an easier bootstrap thing than manually # running a thread. self.t = threading.Thread(target=lambda: self.ls.run(self.run())) @@ -39,9 +39,10 @@ def start(self): def shutdown(self): logger.info("Shutting down system manager") self.command_queue.close() + self.ls.shutdown() async def run(self): reader = self.command_queue.reader() while command := await reader(): ... - logging.info("System manager command processor stopped") + logger.info("System manager command processor stopped") diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py index a1811beea..62e855698 100644 --- a/shortfin/python/shortfin_apps/sd/components/metrics.py +++ b/shortfin/python/shortfin_apps/sd/components/metrics.py @@ -6,8 +6,7 @@ import logging import time -import asyncio -from typing import Callable, Any +from typing import Any import functools logger = logging.getLogger("shortfin-sd.metrics") diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index ad3fd9404..9b09632a6 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -6,12 +6,10 @@ import asyncio import logging -import math import numpy as np from tqdm.auto import tqdm from pathlib import Path from PIL import Image -import io import base64 import shortfin as sf @@ -23,9 +21,7 @@ from .tokenizer import Tokenizer from .metrics import measure - logger = logging.getLogger("shortfin-sd.service") -logger.setLevel(logging.DEBUG) prog_isolations = { "none": sf.ProgramIsolation.NONE, @@ -79,23 +75,32 @@ def __init__( self.workers = [] self.fibers = [] - self.fiber_status = [] + self.idle_fibers = set() for idx, device in enumerate(self.sysman.ls.devices): for i in range(self.workers_per_device): worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") self.workers.append(worker) + for idx, device in enumerate(self.sysman.ls.devices): for i in range(self.fibers_per_device): - fiber = sysman.ls.create_fiber( - self.workers[i % len(self.workers)], devices=[device] - ) + tgt_worker = self.workers[i % len(self.workers)] + fiber = sysman.ls.create_fiber(tgt_worker, devices=[device]) self.fibers.append(fiber) - self.fiber_status.append(0) + self.idle_fibers.add(fiber) for idx in range(len(self.workers)): self.inference_programs[idx] = {} self.inference_functions[idx] = {} # Scope dependent objects. self.batcher = BatcherProcess(self) + def get_worker_index(self, fiber): + if fiber not in self.fibers: + raise ValueError("A worker was requested from a rogue fiber.") + fiber_idx = self.fibers.index(fiber) + worker_idx = int( + (fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker + ) + return worker_idx + def load_inference_module(self, vmfb_path: Path, component: str = None): if not self.inference_modules.get(component): self.inference_modules[component] = [] @@ -112,7 +117,7 @@ def load_inference_parameters( ): p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope) for path in paths: - logging.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) + logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) p.load(path, format=format) if not self.inference_parameters.get(component): self.inference_parameters[component] = [] @@ -121,6 +126,7 @@ def load_inference_parameters( def start(self): # Initialize programs. for component in self.inference_modules: + logger.info(f"Loading component: {component}") component_modules = [ sf.ProgramModule.parameter_provider( self.sysman.ls, *self.inference_parameters.get(component, []) @@ -141,7 +147,6 @@ def start(self): isolation=self.prog_isolation, trace_execution=self.trace_execution, ) - logger.info("Program loaded.") for worker_idx, worker in enumerate(self.workers): self.inference_functions[worker_idx]["encode"] = {} @@ -270,14 +275,17 @@ def board_flights(self): return self.strobes = 0 batches = self.sort_batches() - for idx, batch in batches.items(): - for fidx, status in enumerate(self.service.fiber_status): - if ( - status == 0 - or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL - ): - self.board(batch["reqs"], index=fidx) - break + for batch in batches.values(): + # Assign the batch to the next idle fiber. + if len(self.service.idle_fibers) == 0: + return + fiber = self.service.idle_fibers.pop() + fiber_idx = self.service.fibers.index(fiber) + worker_idx = self.service.get_worker_index(fiber) + logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})") + self.board(batch["reqs"], fiber=fiber) + if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(fiber) def sort_batches(self): """Files pending requests into sorted batches suitable for program invocations.""" @@ -310,11 +318,11 @@ def sort_batches(self): } return batches - def board(self, request_bundle, index): + def board(self, request_bundle, fiber): pending = request_bundle if len(pending) == 0: return - exec_process = InferenceExecutorProcess(self.service, index) + exec_process = InferenceExecutorProcess(self.service, fiber) for req in pending: if len(exec_process.exec_requests) >= self.ideal_batch_size: break @@ -322,8 +330,6 @@ def board(self, request_bundle, index): if exec_process.exec_requests: for flighted_request in exec_process.exec_requests: self.pending_requests.remove(flighted_request) - if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL: - self.service.fiber_status[index] = 1 exec_process.launch() @@ -338,15 +344,11 @@ class InferenceExecutorProcess(sf.Process): def __init__( self, service: GenerateService, - index: int, + fiber, ): - super().__init__(fiber=service.fibers[index]) + super().__init__(fiber=fiber) self.service = service - self.fiber_index = index - self.worker_index = int( - (index - index % self.service.fibers_per_worker) - / self.service.fibers_per_worker - ) + self.worker_index = self.service.get_worker_index(fiber) self.exec_requests: list[InferenceExecRequest] = [] @measure(type="exec", task="inference process") @@ -360,7 +362,7 @@ async def run(self): phase = req.phase phases = self.exec_requests[0].phases req_count = len(self.exec_requests) - device0 = self.service.fibers[self.fiber_index].device(0) + device0 = self.fiber.device(0) if phases[InferencePhase.PREPARE]["required"]: await self._prepare(device=device0, requests=self.exec_requests) if phases[InferencePhase.ENCODE]["required"]: @@ -375,7 +377,8 @@ async def run(self): for i in range(req_count): req = self.exec_requests[i] req.done.set_success() - self.service.fiber_status[self.fiber_index] = 0 + if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(self.fiber) except Exception: logger.exception("Fatal error in image generation") @@ -574,7 +577,7 @@ async def _denoise(self, device, requests): for i, t in tqdm( enumerate(range(step_count)), disable=(not self.service.show_progress), - desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})", + desc=f"DENOISE (bs{req_bs})", ): step = sfnp.device_array.for_device(device, [1], sfnp.sint64) s_host = step.for_transfer() diff --git a/shortfin/python/shortfin_apps/sd/components/tokenizer.py b/shortfin/python/shortfin_apps/sd/components/tokenizer.py index 2bd3781d1..5903d89a5 100644 --- a/shortfin/python/shortfin_apps/sd/components/tokenizer.py +++ b/shortfin/python/shortfin_apps/sd/components/tokenizer.py @@ -4,12 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from pathlib import Path - from transformers import CLIPTokenizer, BatchEncoding -import numpy as np - import shortfin as sf import shortfin.array as sfnp diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 9cd624241..4e3835690 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -5,23 +5,21 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Any - import argparse import logging from pathlib import Path import sys import os -import io import copy import subprocess +from contextlib import asynccontextmanager +import uvicorn # Import first as it does dep checking and reporting. from shortfin.interop.fastapi import FastAPIResponder - -from contextlib import asynccontextmanager +from shortfin.support.logging_setup import native_handler from fastapi import FastAPI, Request, Response -import uvicorn from .components.generate import ClientGenerateBatchProcess from .components.config_struct import ModelParams @@ -29,25 +27,49 @@ from .components.manager import SystemManager from .components.service import GenerateService from .components.tokenizer import Tokenizer -from .components.builders import sdxl -from shortfin.support.logging_setup import native_handler, configure_main_logger logger = logging.getLogger("shortfin-sd") logger.addHandler(native_handler) -logger.setLevel(logging.INFO) logger.propagate = False THIS_DIR = Path(__file__).resolve().parent +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "format": "[{asctime}] {message}", + "datefmt": "%Y-%m-%d %H:%M:%S", + "style": "{", + "use_colors": True, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + @asynccontextmanager async def lifespan(app: FastAPI): sysman.start() try: for service_name, service in services.items(): - logging.info("Initializing service '%s':", service_name) - logging.info(str(service)) + logger.info("Initializing service '%s':", service_name) + logger.info(str(service)) service.start() except: sysman.shutdown() @@ -55,7 +77,7 @@ async def lifespan(app: FastAPI): yield try: for service_name, service in services.items(): - logging.info("Shutting down service '%s'", service_name) + logger.info("Shutting down service '%s'", service_name) service.shutdown() finally: sysman.shutdown() @@ -83,11 +105,14 @@ async def generate_request(gen_req: GenerateReqInput, request: Request): app.put("/generate")(generate_request) -def configure(args) -> SystemManager: +def configure_sys(args) -> SystemManager: # Setup system (configure devices, etc). model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) + return sysman, model_config, flagfile, tuning_spec + +def configure_service(args, sysman, model_config, flagfile, tuning_spec): # Setup each service we are hosting. tokenizers = [] for idx, tok_name in enumerate(args.tokenizers): @@ -163,13 +188,13 @@ def get_configs(args): try: val = int(val) except ValueError: - continue + val = val elif len(arglist) == 2: value = arglist[-1] try: value = int(value) except ValueError: - continue + value = value else: # It's a boolean arg. value = True @@ -178,7 +203,6 @@ def get_configs(args): # It's an env var. arglist = spec.split("=") os.environ[arglist[0]] = arglist[1] - return model_config, topology_config, flagfile, tuning_spec, args @@ -207,6 +231,7 @@ def get_modules(args, model_config, flagfile, td_spec): filenames = [] for modelname in vmfbs.keys(): ireec_args = model_flags["all"] + model_flags[modelname] + ireec_extra_args = " ".join(ireec_args) builder_args = [ sys.executable, "-m", @@ -220,8 +245,12 @@ def get_modules(args, model_config, flagfile, td_spec): f"--model={modelname}", f"--iree-hal-target-device={args.device}", f"--iree-hip-target={args.target}", - f"--iree-compile-extra-args={' '.join(ireec_args)}", + f"--iree-compile-extra-args={ireec_extra_args}", ] + logger.info(f"Preparing runtime artifacts for {modelname}...") + logger.debug( + "COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]) + ) output = subprocess.check_output(builder_args).decode() output_paths = output.splitlines() @@ -229,16 +258,14 @@ def get_modules(args, model_config, flagfile, td_spec): for name in filenames: for key in vmfbs.keys(): if key in name.lower(): - if any([x in name for x in [".irpa", ".safetensors", ".gguf"]]): + if any(x in name for x in [".irpa", ".safetensors", ".gguf"]): params[key].extend([name]) elif "vmfb" in name: vmfbs[key].extend([name]) return vmfbs, params -def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): - from pathlib import Path - +def main(argv, log_config=UVICORN_LOG_CONFIG): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) @@ -257,7 +284,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): type=str, required=False, default="gfx942", - choices=["gfx942", "gfx1100"], + choices=["gfx942", "gfx1100", "gfx90a"], help="Primary inferencing device LLVM target arch.", ) parser.add_argument( @@ -297,7 +324,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser.add_argument( "--isolation", type=str, - default="per_fiber", + default="per_call", choices=["per_fiber", "per_call", "none"], help="Concurrency control -- How to isolate programs.", ) @@ -365,15 +392,17 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): default=1, help="Use tunings for attention and matmul ops. 0 to disable.", ) - args = parser.parse_args(argv) if not args.artifacts_dir: home = Path.home() artdir = home / ".cache" / "shark" args.artifacts_dir = str(artdir) + else: + args.artifacts_dir = Path(args.artifacts_dir).resolve() global sysman - sysman = configure(args) + sysman, model_config, flagfile, tuning_spec = configure_sys(args) + configure_service(args, sysman, model_config, flagfile, tuning_spec) uvicorn.run( app, host=args.host, @@ -388,27 +417,5 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): main( sys.argv[1:], # Make logging defer to the default shortfin logging config. - log_config={ - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(asctime)s - %(levelname)s - %(message)s", - "datefmt": "%Y-%m-%d %H:%M:%S", - }, - }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "default", - }, - }, - "loggers": { - "uvicorn": { - "handlers": ["console"], - "level": "INFO", - "propagate": False, - }, - }, - }, + log_config=UVICORN_LOG_CONFIG, ) diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py index bc0f10655..0d88a59c7 100644 --- a/shortfin/python/shortfin_apps/sd/simple_client.py +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -4,17 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from datetime import datetime as dt +import os +import sys +import time import json -import requests import argparse import base64 -import time import asyncio import aiohttp -import sys -import os +import requests -from datetime import datetime as dt from PIL import Image sample_request = { @@ -32,10 +32,10 @@ } -def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024): +def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024): timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") image = Image.frombytes( - mode="RGB", size=(width, height), data=base64.b64decode(bytes) + mode="RGB", size=(width, height), data=base64.b64decode(in_bytes) ) if not os.path.isdir(outputdir): os.mkdir(outputdir) @@ -58,14 +58,13 @@ def get_batched(request, arg, idx): async def send_request(session, rep, args, data): print("Sending request batch #", rep) - url = f"http://0.0.0.0:{args.port}/generate" + url = f"{args.host}:{args.port}/generate" start = time.time() async with session.post(url, json=data) as response: end = time.time() # Check if the response was successful if response.status == 200: response.raise_for_status() # Raise an error for bad responses - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") res_json = await response.json(content_type=None) if args.save: for idx, item in enumerate(res_json["images"]): @@ -78,9 +77,8 @@ async def send_request(session, rep, args, data): latency = end - start print("Responses processed.") return latency, len(data["prompt"]) - else: - print(f"Error: Received {response.status} from server") - raise Exception + print(f"Error: Received {response.status} from server") + raise Exception async def static(args): @@ -116,7 +114,7 @@ async def static(args): latencies.append(latency) sample_counts.append(num_samples) end = time.time() - if not any([i is None for i in [latencies, sample_counts]]): + if not any(i is None for i in [latencies, sample_counts]): total_num_samples = sum(sample_counts) sps = str(total_num_samples / (end - start)) # Until we have better measurements, don't report the throughput that includes saving images. @@ -163,9 +161,9 @@ async def interactive(args): pending, return_when=asyncio.ALL_COMPLETED ) for task in done: - latency, num_samples = await task + _, _ = await task pending = [] - if any([i is None for i in [latencies, sample_counts]]): + if any(i is None for i in [latencies, sample_counts]): raise ValueError("Received error response from server.") @@ -175,11 +173,27 @@ async def ainput(prompt: str) -> str: async def async_range(count): for i in range(count): - yield (i) + yield i await asyncio.sleep(0.0) -def main(argv): +def check_health(url): + ready = False + print("Waiting for server.", end=None) + while not ready: + try: + if requests.get(f"{url}/health", timeout=20).status_code == 200: + print("Successfully connected to server.") + ready = True + return + time.sleep(2) + print(".", end=None) + except: + time.sleep(2) + print(".", end=None) + + +def main(): p = argparse.ArgumentParser() p.add_argument( "--file", @@ -205,6 +219,9 @@ def main(argv): default="gen_imgs", help="Directory to which images get saved.", ) + p.add_argument( + "--host", type=str, default="http://0.0.0.0", help="Server host address." + ) p.add_argument("--port", type=str, default="8000", help="Server port") p.add_argument( "--steps", @@ -218,6 +235,7 @@ def main(argv): help="Start as an example CLI client instead of sending static requests.", ) args = p.parse_args() + check_health(f"{args.host}:{args.port}") if args.interactive: asyncio.run(interactive(args)) else: @@ -225,4 +243,4 @@ def main(argv): if __name__ == "__main__": - main(sys.argv) + main()