(shortfin-sd) Adds tooling for performance measurement. #380

Merged 8 commits on Oct 31, 2024
Changes from 6 commits
16 changes: 14 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -46,6 +46,7 @@ class ModelParams:
# Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]]
dims: list[list[int]]

base_model_name: str = "SDXL"
# Name of the IREE module for each submodel.
clip_module_name: str = "compiled_clip"
unet_module_name: str = "compiled_unet"
@@ -78,10 +79,13 @@ def max_unet_batch_size(self) -> int:
def max_vae_batch_size(self) -> int:
return self.vae_batch_sizes[-1]

@property
def all_batch_sizes(self) -> list:
return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes]

@property
def max_batch_size(self):
# TODO: a little work on the batcher should loosen this up.
return max(self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes)
return max(self.all_batch_sizes)

@staticmethod
def load_json(path: Path | str):
@@ -91,3 +95,11 @@ def load_json(path: Path | str):
if isinstance(raw_params.unet_dtype, str):
raw_params.unet_dtype = str_to_dtype[raw_params.unet_dtype]
return raw_params

def __repr__(self):
return (
f" base model : {self.base_model_name} \n"
f" output size (H,W) : {self.dims} \n"
f" max token sequence length : {self.max_seq_len} \n"
f" classifier free guidance : {self.cfg_mode} \n"
)
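
For orientation, a minimal usage sketch (not part of this diff) showing how the new pieces surface: ModelParams.load_json reads a config and the new __repr__ renders the summary. The config filename below is hypothetical.

    # Sketch only -- assumes a ModelParams JSON config on disk; the
    # filename is hypothetical, not taken from this PR.
    import logging

    from shortfin_apps.sd.components.config_struct import ModelParams

    logger = logging.getLogger(__name__)

    params = ModelParams.load_json("sdxl_config_fp16.json")  # hypothetical path
    logger.info("Loaded model params:\n%s", params)  # formatted by the new __repr__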
2 changes: 2 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/generate.py
@@ -18,6 +18,7 @@
from .io_struct import GenerateReqInput
from .messages import InferenceExecRequest, InferencePhase
from .service import GenerateService
from .metrics import measure

logger = logging.getLogger(__name__)

@@ -82,6 +83,7 @@ def __init__(
self.batcher = service.batcher
self.complete_infeed = self.system.create_queue()

@measure
async def run(self):
logger.debug("Started ClientBatchGenerateProcess: %r", self)
try:
56 changes: 56 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/metrics.py
@@ -0,0 +1,56 @@
import logging
import time
import asyncio
from typing import Callable, Any

logger = logging.getLogger(__name__)

# Measurement helper.
def measure(fn):
"""
Decorator that logs the wall-clock duration of a (sync or async) function call.
:param fn: Function to decorate
:return: Decorated function
Example:
>>> @measure
... def test_fn():
...     time.sleep(1)
>>> test_fn()
"""

def wrapped_fn(*args: Any, **kwargs: Any) -> Any:
start = time.time()
ret = fn(*args, **kwargs)
duration_str = get_duration_str(start)
logger.info(f"Completed {fn.__qualname__} in {duration_str}")
return ret

async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any:
start = time.time()
ret = await fn(*args, **kwargs)
duration_str = get_duration_str(start)
logger.info(f"Completed {fn.__qualname__} in {duration_str}")
if fn.__qualname__ == "ClientGenerateBatchProcess.run":
Reviewer (Contributor):
What's going on with this? Seems like you should just have a kwarg on the decorator to tell it what to do. There's a standard (but tricky) idiom for that using functools.partial...

Author (@monorimet, Oct 31, 2024):
Thanks for the pointer -- I didn't end up using partial here but adapted to take a few kwargs. There's still some yuck there where we ping attributes of the decorated method's class to figure out batch size, but still more flexible than before. I may revisit later.

sps_str = get_samples_per_second(start, *args)
logger.info(f"SAMPLES PER SECOND = {sps_str}")
return ret

if asyncio.iscoroutinefunction(fn):
return wrapped_fn_async
else:
return wrapped_fn


def get_samples_per_second(start, *args: Any) -> str:
"""Compute image samples generated per second since `start`."""
duration = time.time() - start
# args[0] is the decorated process instance; its gen_req carries the
# number of images produced in this batch.
bs = args[0].gen_req.num_output_images
sps = str(float(bs) / duration)
return sps


def get_duration_str(start: float) -> str:
"""Get human readable duration string from start time"""
duration = time.time() - start
duration_str = f"{round(duration * 1e3)}ms"
return duration_str
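
Regarding the review thread above: a minimal sketch of the decorator-with-kwargs idiom it mentions, using functools.partial so the decorator works both bare and with arguments. The label parameter here is hypothetical, not the kwargs this PR eventually adopted.

    # Sketch only -- the functools.partial idiom from the review thread.
    # The `label` kwarg is hypothetical.
    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)


    def measure(fn=None, *, label=None):
        # Argument-call path: @measure(label="...") invokes measure with
        # fn=None, so return a decorator with the kwargs bound via partial.
        if fn is None:
            return functools.partial(measure, label=label)

        def wrapped(*args, **kwargs):
            start = time.time()
            ret = fn(*args, **kwargs)
            duration_ms = round((time.time() - start) * 1e3)
            logger.info("Completed %s in %dms", label or fn.__qualname__, duration_ms)
            return ret

        return wrapped

Both @measure and @measure(label="batch run") work with this shape; an async variant would mirror wrapped with async def/await and be selected via asyncio.iscoroutinefunction, as the version in this diff does. With the code as merged here, a successful run logs lines of the form "Completed ClientGenerateBatchProcess.run in 1234ms" followed by "SAMPLES PER SECOND = ..." (timings illustrative).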
28 changes: 23 additions & 5 deletions shortfin/python/shortfin_apps/sd/components/service.py
@@ -21,6 +21,8 @@
from .manager import SystemManager
from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
from .tokenizer import Tokenizer
from .metrics import measure


logger = logging.getLogger(__name__)

@@ -48,6 +50,7 @@ def __init__(
fibers_per_device: int,
prog_isolation: str = "per_fiber",
show_progress: bool = False,
trace_execution: bool = False,
):
self.name = name

@@ -59,7 +62,7 @@ def __init__(
self.inference_modules: dict[str, sf.ProgramModule] = {}
self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
self.inference_programs: dict[str, sf.Program] = {}
self.trace_execution = False
self.trace_execution = trace_execution
self.show_progress = show_progress
self.fibers_per_device = fibers_per_device
self.prog_isolation = prog_isolations[prog_isolation]
@@ -152,11 +155,25 @@ def shutdown(self):
self.batcher.shutdown()

def __repr__(self):
modules = "\n".join(
f" {key} : {value}" for key, value in self.inference_modules.items()
)
params = "\n".join(
f" {key} : {value}" for key, value in self.inference_parameters.items()
)
return (
f"ServiceManager(\n"
f" model_params={self.model_params}\n"
f" inference_modules={self.inference_modules}\n"
f" inference_parameters={self.inference_parameters}\n"
f"ServiceManager("
f"\n INFERENCE DEVICES : \n"
f" {self.sysman.ls.devices}\n"
f"\n MODEL PARAMS : \n"
f"{self.model_params}"
f"\n SERVICE PARAMS : \n"
f" fibers per device : {self.fibers_per_device}\n"
f" program isolation mode : {self.prog_isolation}\n"
f"\n INFERENCE MODULES : \n"
f"{'\n'.join(modules)}\n"
f"\n INFERENCE PARAMETERS : \n"
f"{'\n'.join(params)}\n"
f")"
)

@@ -301,6 +318,7 @@ def __init__(
self.worker_index = index
self.exec_requests: list[InferenceExecRequest] = []

@measure
async def run(self):
try:
phase = None
16 changes: 1 addition & 15 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json
@@ -19,23 +19,9 @@
1024
],
"steps": [
20,
30,
40,
50,
20,
30,
40,
50
20
],
"guidance_scale": [
7.5,
7.5,
7.5,
7.5,
7.5,
7.5,
7.5,
7.5
],
"seed": [
9 changes: 8 additions & 1 deletion shortfin/python/shortfin_apps/sd/server.py
@@ -41,7 +41,8 @@ async def lifespan(app: FastAPI):
sysman.start()
try:
for service_name, service in services.items():
logging.info("Initializing service '%s': %r", service_name, service)
logging.info("Initializing service '%s':", service_name)
logging.info(str(service))
service.start()
except:
sysman.shutdown()
@@ -96,6 +97,7 @@ def configure(args) -> SystemManager:
fibers_per_device=args.fibers_per_device,
prog_isolation=args.isolation,
show_progress=args.show_progress,
trace_execution=args.trace_execution,
)
sm.load_inference_module(args.clip_vmfb, component="clip")
sm.load_inference_module(args.unet_vmfb, component="unet")
@@ -217,6 +219,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
action="store_true",
help="enable tqdm progress for unet iterations.",
)
parser.add_argument(
"--trace_execution",
action="store_true",
help="Enable tracing of program modules.",
)
log_levels = {
"info": logging.INFO,
"debug": logging.DEBUG,
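
A hypothetical invocation exercising the new flag (the module path is assumed, not shown in this diff; --trace_execution is literal):

    python -m shortfin_apps.sd.server --trace_execution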
3 changes: 1 addition & 2 deletions shortfin/tests/apps/sd/e2e_test.py
@@ -198,7 +198,6 @@ def __init__(self, args):
stdout=sys.stdout,
stderr=sys.stderr,
)
print(self.process.args)
self._wait_for_ready()

def _wait_for_ready(self):
@@ -210,7 +209,7 @@ def _wait_for_ready(self):
return
except Exception as e:
if self.process.errors is not None:
raise RuntimeError("API server processs terminated") from e
raise RuntimeError("API server process terminated") from e
time.sleep(1.0)
if (time.time() - start) > 30:
raise RuntimeError("Timeout waiting for server start")