Improvements to logging/metrics reporting
eagarvey-amd committed Oct 30, 2024
1 parent 23dc337 commit ac001a9
Showing 3 changed files with 41 additions and 5 deletions.
16 changes: 14 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -46,6 +46,7 @@ class ModelParams:
# Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]]
dims: list[list[int]]

base_model_name: str = "SDXL"
# Name of the IREE module for each submodel.
clip_module_name: str = "compiled_clip"
unet_module_name: str = "compiled_unet"
@@ -78,10 +79,13 @@ def max_unet_batch_size(self) -> int:
def max_vae_batch_size(self) -> int:
return self.vae_batch_sizes[-1]

@property
def all_batch_sizes(self) -> list:
return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes]

@property
def max_batch_size(self):
# TODO: a little work on the batcher should loosen this up.
return max(self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes)
return max(self.all_batch_sizes)

@staticmethod
def load_json(path: Path | str):
@@ -91,3 +95,11 @@ def load_json(path: Path | str):
if isinstance(raw_params.unet_dtype, str):
raw_params.unet_dtype = str_to_dtype[raw_params.unet_dtype]
return raw_params

def __repr__(self):
return (
f"base model: {self.base_model_name} \n"
f" output size (H,W): {self.dims} \n"
f" max token sequence length : {self.max_seq_len} \n"
f" classifier free guidance : {self.cfg_mode} \n"
)
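
For reference, a minimal standalone sketch (not the shortfin ModelParams class itself; the field values below are assumed purely for illustration) of how the new base_model_name field, the all_batch_sizes helper, and the human-readable __repr__ fit together:

from dataclasses import dataclass, field


@dataclass
class ParamsSketch:
    # Illustrative stand-in mirroring only the fields touched by this commit.
    base_model_name: str = "SDXL"
    dims: list = field(default_factory=lambda: [[1024, 1024]])
    max_seq_len: int = 64
    cfg_mode: bool = True
    clip_batch_sizes: list = field(default_factory=lambda: [1])
    unet_batch_sizes: list = field(default_factory=lambda: [1, 2])
    vae_batch_sizes: list = field(default_factory=lambda: [1])

    @property
    def all_batch_sizes(self) -> list:
        # Collects the per-submodel batch size lists in one place.
        return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes]

    def __repr__(self):
        # Multi-line summary intended for startup logging.
        return (
            f"base model: {self.base_model_name} \n"
            f"    output size (H,W): {self.dims} \n"
            f"    max token sequence length : {self.max_seq_len} \n"
            f"    classifier free guidance : {self.cfg_mode} \n"
        )


print(repr(ParamsSketch()))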
20 changes: 17 additions & 3 deletions shortfin/python/shortfin_apps/sd/components/service.py
@@ -155,11 +155,25 @@ def shutdown(self):
self.batcher.shutdown()

def __repr__(self):
modules = [
f" {key} : {value}" for key, value in self.inference_modules.items()
]
params = [
f" {key} : {value}" for key, value in self.inference_parameters.items()
]
return (
f"ServiceManager(\n"
f" model_params={self.model_params}\n"
f" inference_modules={self.inference_modules}\n"
f" inference_parameters={self.inference_parameters}\n"
f"\n INFERENCE DEVICES : \n"
f" {self.sysman.ls.devices}\n"
f"\n MODEL PARAMS: \n"
f" {self.model_params}"
f"\n SERVICE PARAMS: \n"
f" fibers per device: {self.fibers_per_device}\n"
f" program isolation mode: {self.prog_isolation}\n"
f"\n INFERENCE MODULES : \n"
f"{'\n'.join(modules)}\n"
f"\n INFERENCE PARAMETERS : \n"
f"{'\n'.join(params)}\n"
f")"
)
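
The reworked ServiceManager.__repr__ builds one line per module and parameter entry with a list comprehension, then joins the lines with newlines. A small sketch, using made-up dictionary contents, of how those joined sections render:

# Hypothetical contents; the real dictionaries hold loaded inference artifacts.
inference_modules = {"clip": "<compiled_clip module>", "unet": "<compiled_unet module>"}
inference_parameters = {"unet": "<unet weights>"}

modules = [f"     {key} : {value}" for key, value in inference_modules.items()]
params = [f"     {key} : {value}" for key, value in inference_parameters.items()]

report = (
    "\n INFERENCE MODULES : \n"
    + "\n".join(modules)
    + "\n\n INFERENCE PARAMETERS : \n"
    + "\n".join(params)
)
print(report)

Note that the inline '\n'.join(...) inside an f-string, as written in the diff, relies on Python 3.12's relaxed f-string grammar (PEP 701); the concatenation form in the sketch also works on older interpreters.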

10 changes: 10 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/utils.py
@@ -31,6 +31,9 @@ async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any:
ret = await fn(*args, **kwargs)
duration_str = get_duration_str(start)
logger.info(f"Completed {fn.__qualname__} in {duration_str}")
if fn.__qualname__ == "ClientGenerateBatchProcess.run":
sps_str = get_samples_per_second(start, *args)
logger.info(f"SAMPLES PER SECOND = {sps_str}")
return ret

if asyncio.iscoroutinefunction(fn):
@@ -39,6 +42,13 @@ async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any:
return wrapped_fn


def get_samples_per_second(start, *args: Any) -> str:
duration = time.time() - start
bs = args[0].gen_req.num_output_images
sps = str(float(bs) / duration)
return sps


def get_duration_str(start: float) -> str:
"""Get human readable duration string from start time"""
duration = time.time() - start
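
The decorator change computes throughput only for the top-level generation entry point: when the wrapped function is ClientGenerateBatchProcess.run, it divides the request's num_output_images by the elapsed wall-clock time. A self-contained sketch of that pattern, with a hypothetical request type standing in for the shortfin objects:

import logging
import time
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class FakeGenerateRequest:
    # Hypothetical stand-in for the request object carried by the batch process.
    num_output_images: int


def measure(fn):
    # Log wall-clock duration and samples-per-second for the wrapped call.
    def wrapped(request, *args, **kwargs):
        start = time.time()
        result = fn(request, *args, **kwargs)
        duration = time.time() - start
        sps = request.num_output_images / duration
        logger.info(f"Completed {fn.__qualname__} in {duration:.3f}s")
        logger.info(f"SAMPLES PER SECOND = {sps:.2f}")
        return result

    return wrapped


@measure
def run_batch(request: FakeGenerateRequest) -> None:
    # Simulate a batch of image generations.
    time.sleep(0.05 * request.num_output_images)


if __name__ == "__main__":
    run_batch(FakeGenerateRequest(num_output_images=4))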
