Skip to content

Commit

Permalink
(shortfin-sd) Cleanup fiber distribution, logging, error handling. (#555
Browse files Browse the repository at this point in the history
)

- Switches fiber distribution to a much simpler process where idle
fibers are kept in a set and pop/replaced when used. For dense batches
this should be near optimal, though it cycles through fibers without
regard of each fiber's current load (this only matters in per_call/none
isolation where we allow in-fiber concurrency -- in per_call/none
isolation the fibers are put back in the idle_fibers set after boarding)

- Suppresses leak messages for shortfin nanobind objects with keep_awake
(maybe fix and rollback after V3.0)
- Makes port selection/error return much more friendly
- Segments configuration and suppresses some builder outputs for now.
- Makes logging look a little better, removes some extraneous outputs.
  • Loading branch information
monorimet authored Nov 18, 2024
1 parent abe6467 commit a968fc4
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 136 deletions.
17 changes: 8 additions & 9 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,15 @@ ConfigOptions CreateConfigOptions(std::optional<std::string> &env_prefix,
} // namespace

NB_MODULE(lib, m) {
// Tragically, debug builds of Python do the right thing and don't immortalize
// many identifiers and such. This makes the last chance leak checking that
// nanobind does somewhat unreliable since the reports it prints may be
// to identifiers that are no longer live (at a time in process shutdown
// where it is expected that everything left just gets dropped on the floor).
// This causes segfaults or ASAN violations in the leak checker on exit in
// certain scenarios where we have spurious "leaks" of global objects.
#if defined(Py_DEBUG)
// Tragically, debug builds of Python do the right thing and don't immortalize
// many identifiers and such. This makes the last chance leak checking that
// nanobind does somewhat unreliable since the reports it prints may be
// to identifiers that are no longer live (at a time in process shutdown
// where it is expected that everything left just gets dropped on the floor).
// This causes segfaults or ASAN violations in the leak checker on exit in
// certain scenarios where we have spurious "leaks" of global objects.

py::set_leak_warnings(false);
#endif

logging::InitializeFromEnv();

Expand Down
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
try:
import transformers
except ModuleNotFoundError as e:
raise ShortfinDepNotFoundError(__name__, "diffusers") from e
raise ShortfinDepNotFoundError(__name__, "transformers") from e

try:
import tokenizers
Expand Down
8 changes: 3 additions & 5 deletions shortfin/python/shortfin_apps/sd/components/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def needs_update(ctx):

def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN):
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
needed = True
if os.path.exists(out_file):
if url:
needed = not is_valid_size(out_file, url)
Expand All @@ -176,16 +177,13 @@ def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN):


def needs_compile(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
vmfb_name = f"{filename}_{target}.vmfb"
namespace = FileNamespace.BIN
return needs_file(vmfb_name, ctx, namespace=namespace)


def get_cached_vmfb(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
namespace = FileNamespace.BIN
vmfb_name = f"{filename}_{target}.vmfb"
return ctx.file(vmfb_name)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@

from iree.build import *
from iree.build.executor import FileNamespace
import itertools
import os
import shortfin.array as sfnp
import copy

ARTIFACT_VERSION = "11182024"
SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"
Expand Down Expand Up @@ -72,21 +69,18 @@ def sdxlconfig(
model_config_filenames = [f"{model}_config_i8.json"]
model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET)
for f, url in model_config_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

topology_config_filenames = [f"topology_config_{topology}.txt"]
topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET)
for f, url in topology_config_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

flagfile_filenames = [f"{model}_flagfile_{target}.txt"]
flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET)
for f, url in flagfile_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

Expand All @@ -95,7 +89,6 @@ def sdxlconfig(
)
tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET)
for f, url in tuning_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from dataclasses import dataclass
from pathlib import Path

import dataclasses_json
from dataclasses_json import dataclass_json, Undefined

import shortfin.array as sfnp
Expand Down
5 changes: 1 addition & 4 deletions shortfin/python/shortfin_apps/sd/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio
import io
import logging
import json

import shortfin as sf
import shortfin.array as sfnp

# TODO: Have a generic "Responder" interface vs just the concrete impl.
from shortfin.interop.fastapi import FastAPIResponder

from .io_struct import GenerateReqInput
from .messages import InferenceExecRequest, InferencePhase
from .messages import InferenceExecRequest
from .service import GenerateService
from .metrics import measure

Expand Down Expand Up @@ -83,7 +81,6 @@ def __init__(
self.batcher = service.batcher
self.complete_infeed = self.system.create_queue()

@measure(type="throughput", num_items="num_output_images", freq=1, label="samples")
async def run(self):
logger.debug("Started ClientBatchGenerateProcess: %r", self)
try:
Expand Down
4 changes: 1 addition & 3 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict, List, Optional, Union
from typing import List, Optional, Union
from dataclasses import dataclass
import uuid

import shortfin.array as sfnp


@dataclass
class GenerateReqInput:
Expand Down
5 changes: 3 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logging.info(f"Created local system with {self.ls.device_names} devices")
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
Expand All @@ -39,9 +39,10 @@ def start(self):
def shutdown(self):
logger.info("Shutting down system manager")
self.command_queue.close()
self.ls.shutdown()

async def run(self):
reader = self.command_queue.reader()
while command := await reader():
...
logging.info("System manager command processor stopped")
logger.info("System manager command processor stopped")
3 changes: 1 addition & 2 deletions shortfin/python/shortfin_apps/sd/components/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

import logging
import time
import asyncio
from typing import Callable, Any
from typing import Any
import functools

logger = logging.getLogger("shortfin-sd.metrics")
Expand Down
69 changes: 36 additions & 33 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

import asyncio
import logging
import math
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
from PIL import Image
import io
import base64

import shortfin as sf
Expand All @@ -23,9 +21,7 @@
from .tokenizer import Tokenizer
from .metrics import measure


logger = logging.getLogger("shortfin-sd.service")
logger.setLevel(logging.DEBUG)

prog_isolations = {
"none": sf.ProgramIsolation.NONE,
Expand Down Expand Up @@ -79,23 +75,32 @@ def __init__(

self.workers = []
self.fibers = []
self.fiber_status = []
self.idle_fibers = set()
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.workers_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
self.workers.append(worker)
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.fibers_per_device):
fiber = sysman.ls.create_fiber(
self.workers[i % len(self.workers)], devices=[device]
)
tgt_worker = self.workers[i % len(self.workers)]
fiber = sysman.ls.create_fiber(tgt_worker, devices=[device])
self.fibers.append(fiber)
self.fiber_status.append(0)
self.idle_fibers.add(fiber)
for idx in range(len(self.workers)):
self.inference_programs[idx] = {}
self.inference_functions[idx] = {}
# Scope dependent objects.
self.batcher = BatcherProcess(self)

def get_worker_index(self, fiber):
if fiber not in self.fibers:
raise ValueError("A worker was requested from a rogue fiber.")
fiber_idx = self.fibers.index(fiber)
worker_idx = int(
(fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker
)
return worker_idx

def load_inference_module(self, vmfb_path: Path, component: str = None):
if not self.inference_modules.get(component):
self.inference_modules[component] = []
Expand All @@ -112,7 +117,7 @@ def load_inference_parameters(
):
p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope)
for path in paths:
logging.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path)
p.load(path, format=format)
if not self.inference_parameters.get(component):
self.inference_parameters[component] = []
Expand All @@ -121,6 +126,7 @@ def load_inference_parameters(
def start(self):
# Initialize programs.
for component in self.inference_modules:
logger.info(f"Loading component: {component}")
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
Expand All @@ -141,7 +147,6 @@ def start(self):
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)
logger.info("Program loaded.")

for worker_idx, worker in enumerate(self.workers):
self.inference_functions[worker_idx]["encode"] = {}
Expand Down Expand Up @@ -270,14 +275,17 @@ def board_flights(self):
return
self.strobes = 0
batches = self.sort_batches()
for idx, batch in batches.items():
for fidx, status in enumerate(self.service.fiber_status):
if (
status == 0
or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
):
self.board(batch["reqs"], index=fidx)
break
for batch in batches.values():
# Assign the batch to the next idle fiber.
if len(self.service.idle_fibers) == 0:
return
fiber = self.service.idle_fibers.pop()
fiber_idx = self.service.fibers.index(fiber)
worker_idx = self.service.get_worker_index(fiber)
logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})")
self.board(batch["reqs"], fiber=fiber)
if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.add(fiber)

def sort_batches(self):
"""Files pending requests into sorted batches suitable for program invocations."""
Expand Down Expand Up @@ -310,20 +318,18 @@ def sort_batches(self):
}
return batches

def board(self, request_bundle, index):
def board(self, request_bundle, fiber):
pending = request_bundle
if len(pending) == 0:
return
exec_process = InferenceExecutorProcess(self.service, index)
exec_process = InferenceExecutorProcess(self.service, fiber)
for req in pending:
if len(exec_process.exec_requests) >= self.ideal_batch_size:
break
exec_process.exec_requests.append(req)
if exec_process.exec_requests:
for flighted_request in exec_process.exec_requests:
self.pending_requests.remove(flighted_request)
if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
self.service.fiber_status[index] = 1
exec_process.launch()


Expand All @@ -338,15 +344,11 @@ class InferenceExecutorProcess(sf.Process):
def __init__(
self,
service: GenerateService,
index: int,
fiber,
):
super().__init__(fiber=service.fibers[index])
super().__init__(fiber=fiber)
self.service = service
self.fiber_index = index
self.worker_index = int(
(index - index % self.service.fibers_per_worker)
/ self.service.fibers_per_worker
)
self.worker_index = self.service.get_worker_index(fiber)
self.exec_requests: list[InferenceExecRequest] = []

@measure(type="exec", task="inference process")
Expand All @@ -360,7 +362,7 @@ async def run(self):
phase = req.phase
phases = self.exec_requests[0].phases
req_count = len(self.exec_requests)
device0 = self.service.fibers[self.fiber_index].device(0)
device0 = self.fiber.device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
Expand All @@ -375,7 +377,8 @@ async def run(self):
for i in range(req_count):
req = self.exec_requests[i]
req.done.set_success()
self.service.fiber_status[self.fiber_index] = 0
if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER:
self.service.idle_fibers.add(self.fiber)

except Exception:
logger.exception("Fatal error in image generation")
Expand Down Expand Up @@ -574,7 +577,7 @@ async def _denoise(self, device, requests):
for i, t in tqdm(
enumerate(range(step_count)),
disable=(not self.service.show_progress),
desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
desc=f"DENOISE (bs{req_bs})",
):
step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
s_host = step.for_transfer()
Expand Down
4 changes: 0 additions & 4 deletions shortfin/python/shortfin_apps/sd/components/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path

from transformers import CLIPTokenizer, BatchEncoding

import numpy as np

import shortfin as sf
import shortfin.array as sfnp

Expand Down
Loading

0 comments on commit a968fc4

Please sign in to comment.