From 23dc3376daea950fa9acc1417f4935290e65d66a Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 30 Oct 2024 10:42:07 -0500 Subject: [PATCH 1/6] Add basic perf measurements --- .../shortfin_apps/sd/components/generate.py | 2 + .../shortfin_apps/sd/components/service.py | 6 ++- .../shortfin_apps/sd/components/utils.py | 46 +++++++++++++++++++ .../sd/examples/sdxl_request_bs8.json | 16 +------ shortfin/python/shortfin_apps/sd/server.py | 9 +++- shortfin/tests/apps/sd/e2e_test.py | 3 +- 6 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 shortfin/python/shortfin_apps/sd/components/utils.py diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index ca4f9799d..68b01e0d8 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -18,6 +18,7 @@ from .io_struct import GenerateReqInput from .messages import InferenceExecRequest, InferencePhase from .service import GenerateService +from .utils import measure logger = logging.getLogger(__name__) @@ -82,6 +83,7 @@ def __init__( self.batcher = service.batcher self.complete_infeed = self.system.create_queue() + @measure async def run(self): logger.debug("Started ClientBatchGenerateProcess: %r", self) try: diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 2deec49c0..df3fbeee0 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -21,6 +21,8 @@ from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage from .tokenizer import Tokenizer +from .utils import measure + logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ def __init__( fibers_per_device: int, prog_isolation: str = "per_fiber", show_progress: bool = False, + trace_execution: bool = False, ): 
self.name = name @@ -59,7 +62,7 @@ def __init__( self.inference_modules: dict[str, sf.ProgramModule] = {} self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {} self.inference_programs: dict[str, sf.Program] = {} - self.trace_execution = False + self.trace_execution = trace_execution self.show_progress = show_progress self.fibers_per_device = fibers_per_device self.prog_isolation = prog_isolations[prog_isolation] @@ -301,6 +304,7 @@ def __init__( self.worker_index = index self.exec_requests: list[InferenceExecRequest] = [] + @measure async def run(self): try: phase = None diff --git a/shortfin/python/shortfin_apps/sd/components/utils.py b/shortfin/python/shortfin_apps/sd/components/utils.py new file mode 100644 index 000000000..bf57a5f50 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/utils.py @@ -0,0 +1,46 @@ +import logging +import time +import asyncio +from typing import Callable, Any + +logger = logging.getLogger(__name__) + +# measurement helper +def measure(fn): + + """ + Decorator log test start and end time of a function + :param fn: Function to decorate + :return: Decorated function + Example: + >>> @timed + >>> def test_fn(): + >>> time.sleep(1) + >>> test_fn() + """ + + def wrapped_fn(*args: Any, **kwargs: Any) -> Any: + start = time.time() + ret = fn(*args, **kwargs) + duration_str = get_duration_str(start) + logger.info(f"Completed {fn.__qualname__} in {duration_str}") + return ret + + async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any: + start = time.time() + ret = await fn(*args, **kwargs) + duration_str = get_duration_str(start) + logger.info(f"Completed {fn.__qualname__} in {duration_str}") + return ret + + if asyncio.iscoroutinefunction(fn): + return wrapped_fn_async + else: + return wrapped_fn + + +def get_duration_str(start: float) -> str: + """Get human readable duration string from start time""" + duration = time.time() - start + duration_str = f"{round(duration * 1e3)}ms" + return duration_str diff 
--git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json index be94293ae..394e3568e 100644 --- a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json @@ -19,23 +19,9 @@ 1024 ], "steps": [ - 20, - 30, - 40, - 50, - 20, - 30, - 40, - 50 + 20 ], "guidance_scale": [ - 7.5, - 7.5, - 7.5, - 7.5, - 7.5, - 7.5, - 7.5, 7.5 ], "seed": [ diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 0327b0a9f..837afb152 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -41,7 +41,8 @@ async def lifespan(app: FastAPI): sysman.start() try: for service_name, service in services.items(): - logging.info("Initializing service '%s': %r", service_name, service) + logging.info("Initializing service '%s':", service_name) + logging.info(str(service)) service.start() except: sysman.shutdown() @@ -96,6 +97,7 @@ def configure(args) -> SystemManager: fibers_per_device=args.fibers_per_device, prog_isolation=args.isolation, show_progress=args.show_progress, + trace_execution=args.trace_execution, ) sm.load_inference_module(args.clip_vmfb, component="clip") sm.load_inference_module(args.unet_vmfb, component="unet") @@ -217,6 +219,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): action="store_true", help="enable tqdm progress for unet iterations.", ) + parser.add_argument( + "--trace_execution", + action="store_true", + help="Enable tracing of program modules.", + ) log_levels = { "info": logging.INFO, "debug": logging.DEBUG, diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py index 05b9ef69b..366fb9c2d 100644 --- a/shortfin/tests/apps/sd/e2e_test.py +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -198,7 +198,6 @@ def __init__(self, args): stdout=sys.stdout, stderr=sys.stderr, ) - 
print(self.process.args) self._wait_for_ready() def _wait_for_ready(self): @@ -210,7 +209,7 @@ def _wait_for_ready(self): return except Exception as e: if self.process.errors is not None: - raise RuntimeError("API server processs terminated") from e + raise RuntimeError("API server process terminated") from e time.sleep(1.0) if (time.time() - start) > 30: raise RuntimeError("Timeout waiting for server start") From ac001a91270a5598966d0f0cf4f9570b9efb4bf7 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 30 Oct 2024 14:12:06 -0500 Subject: [PATCH 2/6] Improvements to logging/metrics reporting --- .../sd/components/config_struct.py | 16 +++++++++++++-- .../shortfin_apps/sd/components/service.py | 20 ++++++++++++++++--- .../shortfin_apps/sd/components/utils.py | 10 ++++++++++ 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py index 96b9c32fd..8c2875be8 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -46,6 +46,7 @@ class ModelParams: # Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]] dims: list[list[int]] + base_model_name: str = "SDXL" # Name of the IREE module for each submodel. clip_module_name: str = "compiled_clip" unet_module_name: str = "compiled_unet" @@ -78,10 +79,13 @@ def max_unet_batch_size(self) -> int: def max_vae_batch_size(self) -> int: return self.vae_batch_sizes[-1] + @property + def all_batch_sizes(self) -> list: + return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes] + @property def max_batch_size(self): - # TODO: a little work on the batcher should loosen this up. 
- return max(self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes) + return max(self.all_batch_sizes) @staticmethod def load_json(path: Path | str): @@ -91,3 +95,11 @@ def load_json(path: Path | str): if isinstance(raw_params.unet_dtype, str): raw_params.unet_dtype = str_to_dtype[raw_params.unet_dtype] return raw_params + + def __repr__(self): + return ( + f"base model: {self.base_model_name} \n" + f" output size (H,W): {self.dims} \n" + f" max token sequence length : {self.max_seq_len} \n" + f" classifier free guidance : {self.cfg_mode} \n" + ) diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index df3fbeee0..2489d4e17 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -155,11 +155,25 @@ def shutdown(self): self.batcher.shutdown() def __repr__(self): + modules = [ + f" {key} : {value}" for key, value in self.inference_modules.items() + ] + params = [ + f" {key} : {value}" for key, value in self.inference_parameters.items() + ] return ( f"ServiceManager(\n" - f" model_params={self.model_params}\n" - f" inference_modules={self.inference_modules}\n" - f" inference_parameters={self.inference_parameters}\n" + f"\n INFERENCE DEVICES : \n" + f" {self.sysman.ls.devices}\n" + f"\n MODEL PARAMS: \n" + f" {self.model_params}" + f"\n SERVICE PARAMS: \n" + f" fibers per device: {self.fibers_per_device}\n" + f" program isolation mode: {self.prog_isolation}\n" + f"\n INFERENCE MODULES : \n" + f"{'\n'.join(modules)}\n" + f"\n INFERENCE PARAMETERS : \n" + f"{'\n'.join(params)}\n" f")" ) diff --git a/shortfin/python/shortfin_apps/sd/components/utils.py b/shortfin/python/shortfin_apps/sd/components/utils.py index bf57a5f50..f7b853139 100644 --- a/shortfin/python/shortfin_apps/sd/components/utils.py +++ b/shortfin/python/shortfin_apps/sd/components/utils.py @@ -31,6 +31,9 @@ async def wrapped_fn_async(*args: Any, 
**kwargs: Any) -> Any: ret = await fn(*args, **kwargs) duration_str = get_duration_str(start) logger.info(f"Completed {fn.__qualname__} in {duration_str}") + if fn.__qualname__ == "ClientGenerateBatchProcess.run": + sps_str = get_samples_per_second(start, *args) + logger.info(f"SAMPLES PER SECOND = {sps_str}") return ret if asyncio.iscoroutinefunction(fn): @@ -39,6 +42,13 @@ async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any: return wrapped_fn +def get_samples_per_second(start, *args: Any) -> str: + duration = time.time() - start + bs = args[0].gen_req.num_output_images + sps = str(float(bs) / duration) + return sps + + def get_duration_str(start: float) -> str: """Get human readable duration string from start time""" duration = time.time() - start From 8a3476bc98b34ff24fd4fbbdbe4b9074ec42fbb3 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 30 Oct 2024 14:17:51 -0500 Subject: [PATCH 3/6] Remove extra newline --- shortfin/python/shortfin_apps/sd/components/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 2489d4e17..6196db424 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -162,7 +162,7 @@ def __repr__(self): f" {key} : {value}" for key, value in self.inference_parameters.items() ] return ( - f"ServiceManager(\n" + f"ServiceManager(" f"\n INFERENCE DEVICES : \n" f" {self.sysman.ls.devices}\n" f"\n MODEL PARAMS: \n" From fdb50d178dc3540a373f6f43d768e96748fe0f40 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 30 Oct 2024 14:20:19 -0500 Subject: [PATCH 4/6] Rename utils -> metrics --- shortfin/python/shortfin_apps/sd/components/generate.py | 2 +- .../python/shortfin_apps/sd/components/{utils.py => metrics.py} | 0 shortfin/python/shortfin_apps/sd/components/service.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) 
rename shortfin/python/shortfin_apps/sd/components/{utils.py => metrics.py} (100%) diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index 68b01e0d8..a5eb6ad38 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -18,7 +18,7 @@ from .io_struct import GenerateReqInput from .messages import InferenceExecRequest, InferencePhase from .service import GenerateService -from .utils import measure +from .metrics import measure logger = logging.getLogger(__name__) diff --git a/shortfin/python/shortfin_apps/sd/components/utils.py b/shortfin/python/shortfin_apps/sd/components/metrics.py similarity index 100% rename from shortfin/python/shortfin_apps/sd/components/utils.py rename to shortfin/python/shortfin_apps/sd/components/metrics.py diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 6196db424..a20aabc42 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -21,7 +21,7 @@ from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage from .tokenizer import Tokenizer -from .utils import measure +from .metrics import measure logger = logging.getLogger(__name__) From d8ff71ac0b83d837b08a2625b9b6b41b5065d18d Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 30 Oct 2024 20:13:41 -0500 Subject: [PATCH 5/6] fixup indents in prints --- .../shortfin_apps/sd/components/config_struct.py | 4 ++-- shortfin/python/shortfin_apps/sd/components/service.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py index 8c2875be8..0d68aad8e 100644 --- 
a/shortfin/python/shortfin_apps/sd/components/config_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -98,8 +98,8 @@ def load_json(path: Path | str): def __repr__(self): return ( - f"base model: {self.base_model_name} \n" - f" output size (H,W): {self.dims} \n" + f" base model : {self.base_model_name} \n" + f" output size (H,W) : {self.dims} \n" f" max token sequence length : {self.max_seq_len} \n" f" classifier free guidance : {self.cfg_mode} \n" ) diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index a20aabc42..81f4d8648 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -165,11 +165,11 @@ def __repr__(self): f"ServiceManager(" f"\n INFERENCE DEVICES : \n" f" {self.sysman.ls.devices}\n" - f"\n MODEL PARAMS: \n" - f" {self.model_params}" - f"\n SERVICE PARAMS: \n" - f" fibers per device: {self.fibers_per_device}\n" - f" program isolation mode: {self.prog_isolation}\n" + f"\n MODEL PARAMS : \n" + f"{self.model_params}" + f"\n SERVICE PARAMS : \n" + f" fibers per device : {self.fibers_per_device}\n" + f" program isolation mode : {self.prog_isolation}\n" f"\n INFERENCE MODULES : \n" f"{'\n'.join(modules)}\n" f"\n INFERENCE PARAMETERS : \n" From e49ee523c31515b0fd8fdac1ed01bbd47bee1649 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 31 Oct 2024 10:03:46 -0500 Subject: [PATCH 6/6] Update decorator implementation. 
--- .../shortfin_apps/sd/components/generate.py | 2 +- .../shortfin_apps/sd/components/metrics.py | 74 ++++++++----------- .../shortfin_apps/sd/components/service.py | 2 +- 3 files changed, 34 insertions(+), 44 deletions(-) diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index a5eb6ad38..ebb5ea08a 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -83,7 +83,7 @@ def __init__( self.batcher = service.batcher self.complete_infeed = self.system.create_queue() - @measure + @measure(type="throughput", num_items="num_output_images", freq=1, label="samples") async def run(self): logger.debug("Started ClientBatchGenerateProcess: %r", self) try: diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py index f7b853139..6d3c1aa8b 100644 --- a/shortfin/python/shortfin_apps/sd/components/metrics.py +++ b/shortfin/python/shortfin_apps/sd/components/metrics.py @@ -2,55 +2,45 @@ import time import asyncio from typing import Callable, Any +import functools logger = logging.getLogger(__name__) -# measurement helper -def measure(fn): - - """ - Decorator log test start and end time of a function - :param fn: Function to decorate - :return: Decorated function - Example: - >>> @timed - >>> def test_fn(): - >>> time.sleep(1) - >>> test_fn() - """ - - def wrapped_fn(*args: Any, **kwargs: Any) -> Any: - start = time.time() - ret = fn(*args, **kwargs) - duration_str = get_duration_str(start) - logger.info(f"Completed {fn.__qualname__} in {duration_str}") - return ret - - async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any: - start = time.time() - ret = await fn(*args, **kwargs) - duration_str = get_duration_str(start) - logger.info(f"Completed {fn.__qualname__} in {duration_str}") - if fn.__qualname__ == "ClientGenerateBatchProcess.run": - sps_str = 
get_samples_per_second(start, *args) - logger.info(f"SAMPLES PER SECOND = {sps_str}") - return ret - - if asyncio.iscoroutinefunction(fn): + +def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"): + assert callable(fn) or fn is None + + def _decorator(func): + @functools.wraps(func) + async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any: + start = time.time() + ret = await func(*args, **kwargs) + duration = time.time() - start + if type == "exec": + batch_size = len(getattr(args[0], "exec_requests", [])) + log_duration_str(duration, task=task, batch_size=batch_size) + if type == "throughput": + if isinstance(num_items, str): + items = getattr(args[0].gen_req, num_items) + else: + items = str(num_items) + log_throughput(duration, items, freq, label) + return ret + return wrapped_fn_async - else: - return wrapped_fn + + return _decorator(fn) if callable(fn) else _decorator -def get_samples_per_second(start, *args: Any) -> str: - duration = time.time() - start - bs = args[0].gen_req.num_output_images - sps = str(float(bs) / duration) - return sps +def log_throughput(duration, num_items, freq, label) -> None: + sps = str(float(num_items) / duration * freq) + freq_str = "second" if freq == 1 else f"{freq} seconds" + logger.info(f"THROUGHPUT: {sps} {label} per {freq_str}") -def get_duration_str(start: float) -> str: +def log_duration_str(duration: float, task, batch_size=0) -> None: """Get human readable duration string from start time""" - duration = time.time() - start + if batch_size > 0: + task = f"{task} (batch size {batch_size})" duration_str = f"{round(duration * 1e3)}ms" - return duration_str + logger.info(f"Completed {task} in {duration_str}") diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 81f4d8648..f19591bd9 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ 
-318,7 +318,7 @@ def __init__( self.worker_index = index self.exec_requests: list[InferenceExecRequest] = [] - @measure + @measure(type="exec", task="inference process") async def run(self): try: phase = None