avoid recreating trackers
IlyasMoutawwakil committed Dec 4, 2024
1 parent d36a9f2 commit 74d15d2
Showing 2 changed files with 111 additions and 107 deletions.
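In the diff below, the tracker resets move into the `track` context manager itself (each enabled tracker is reset just before its context is entered), the separate `reset_trackers()` method and its per-section call sites are dropped, and warmup/tracking calls read `generate_kwargs`, `call_kwargs`, and `forward_kwargs` directly from `self.config` instead of copying them onto the scenario. A minimal sketch of the resulting pattern, using a hypothetical `DummyTracker` stand-in rather than the real energy/memory/latency tracker classes, looks like this:

from contextlib import ExitStack, contextmanager


class DummyTracker:
    """Hypothetical stand-in for the real energy/memory/latency trackers."""

    def __init__(self):
        self.events = []

    def reset(self):
        self.events.clear()

    @contextmanager
    def track(self, task_name: str = ""):
        self.events.append(f"start:{task_name}")
        yield
        self.events.append(f"stop:{task_name}")


class Scenario:
    def __init__(self, energy=True, memory=True, latency=True):
        self.energy, self.memory, self.latency = energy, memory, latency
        # Created once and reused for every tracked section, instead of being
        # recreated (or reset through a separate reset_trackers() call) each time.
        self.energy_tracker = DummyTracker()
        self.memory_tracker = DummyTracker()
        self.latency_tracker = DummyTracker()

    @contextmanager
    def track(self, task_name: str):
        # Each enabled tracker is reset right before its context is entered,
        # so call sites only need to wrap their work in `with self.track(...)`.
        with ExitStack() as stack:
            if self.energy:
                self.energy_tracker.reset()
                stack.enter_context(self.energy_tracker.track(task_name=task_name))
            if self.memory:
                self.memory_tracker.reset()
                stack.enter_context(self.memory_tracker.track())
            if self.latency:
                self.latency_tracker.reset()
                stack.enter_context(self.latency_tracker.track())
            yield


# Call sites simply wrap the tracked work; no reset_trackers() step beforehand.
scenario = Scenario()
with scenario.track(task_name="load_model"):
    pass  # tracked work goes here

The `init_trackers` changes themselves are collapsed in the view below, so the sketch only covers the reset-inside-`track` pattern visible in the expanded hunks.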
optimum_benchmark/scenarios/energy_star/scenario.py (37 changes: 11 additions & 26 deletions)
@@ -71,14 +71,12 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
self.logger.info("\t+ Updating Image Diffusion kwargs with default values")
- self.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}
+ self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}
self.logger.info("\t+ Initializing Image Diffusion report")
self.report = BenchmarkReport.from_list(
targets=["load_dataset", "preprocess_dataset", "load_model", "call"]
)
else:
self.logger.info("\t+ Updating Inference kwargs with default values")
self.forward_kwargs = {**self.config.forward_kwargs}
self.logger.info("\t+ Initializing Inference report")
self.report = BenchmarkReport.from_list(
targets=["load_dataset", "preprocess_dataset", "load_model", "forward"]
@@ -133,26 +131,20 @@ def init_trackers(self, backend: Backend[BackendConfigT]):
def track(self, task_name: str):
with ExitStack() as context_stack:
if self.config.energy:
+ self.energy_tracker.reset()
context_stack.enter_context(self.energy_tracker.track(task_name=task_name))
if self.config.memory:
+ self.memory_tracker.reset()
context_stack.enter_context(self.memory_tracker.track())
if self.config.latency:
+ self.latency_tracker.reset()
context_stack.enter_context(self.latency_tracker.track())
yield

- def reset_trackers(self):
- if self.config.latency:
- self.latency_tracker.reset()
- if self.config.memory:
- self.memory_tracker.reset()
- if self.config.energy:
- self.energy_tracker.reset()
-
# Dataset loading tracking
def run_dataset_loading_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running dataset loading tracking")

- self.reset_trackers()
with self.track(task_name="load_dataset"):
self.dataset = load_dataset(
self.config.dataset_name, self.config.dataset_config, split=self.config.dataset_split
@@ -169,7 +161,6 @@ def run_dataset_loading_tracking(self, backend: Backend[BackendConfigT]):
def run_dataset_preprocessing_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running dataset preprocessing tracking")

- self.reset_trackers()
with self.track(task_name="preprocess_dataset"):
self.dataset = TASKS_TO_PREPROCESSORS[backend.config.task](
dataset=self.dataset,
@@ -199,7 +190,6 @@ def run_dataset_preprocessing_tracking(self, backend: Backend[BackendConfigT]):
def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running model loading energy tracking")

- self.reset_trackers()
with self.track(task_name="load_model"):
backend.load()

@@ -212,34 +202,32 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):

# Text Generation warmup
def warmup_text_generation(self, backend: Backend[BackendConfigT]):
+ warmup_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_WARMUP_OVERRIDES}
self.logger.info("\t+ Warming up backend for Text Generation")
backend.generate(self.sample_inputs, self.config.generate_kwargs)
- warmup_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_WARMUP_OVERRIDES}
for _ in range(self.config.warmup_runs):
backend.generate(self.sample_inputs, warmup_kwargs)

# Image Diffusion warmup
def warmup_image_diffusion(self, backend: Backend[BackendConfigT]):
+ warmup_kwargs = {**self.config.call_kwargs, **IMAGE_DIFFUSION_WARMUP_OVERRIDES}
self.logger.info("\t+ Warming up backend for Image Diffusion")
- backend.call(self.sample_inputs, self.call_kwargs)
- warmup_kwargs = {**self.call_kwargs, **IMAGE_DIFFUSION_WARMUP_OVERRIDES}
+ backend.call(self.sample_inputs, self.config.call_kwargs)
for _ in range(self.config.warmup_runs):
backend.call(self.sample_inputs, warmup_kwargs)

# Inference warmup
def warmup_inference(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Warming up backend for Inference")
- warmup_kwargs = {**self.forward_kwargs}
for _ in range(self.config.warmup_runs):
- backend.forward(self.sample_inputs, warmup_kwargs)
+ backend.forward(self.sample_inputs, self.config.forward_kwargs)

# Text Generation energy tracking
def run_text_generation_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running Text Generation tracking")

prefill_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_PREFILL_OVERRIDES}

- self.reset_trackers()
with self.track(task_name="prefill"):
for i in tqdm(range(0, self.config.num_samples, self.config.input_shapes["batch_size"])):
inputs = backend.prepare_inputs(self.dataset[i : i + self.config.input_shapes["batch_size"]])
@@ -262,7 +250,6 @@ def run_text_generation_tracking(self, backend: Backend[BackendConfigT]):
if self.config.memory:
self.report.prefill.memory = self.memory_tracker.get_max_memory()

- self.reset_trackers()
with self.track(task_name="generate"):
for i in tqdm(range(0, self.config.num_samples, self.config.input_shapes["batch_size"])):
inputs = backend.prepare_inputs(self.dataset[i : i + self.config.input_shapes["batch_size"]])
@@ -291,11 +278,10 @@ def run_text_generation_tracking(self, backend: Backend[BackendConfigT]):
def run_image_diffusion_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running Image Diffusion tracking")

- self.reset_trackers()
with self.track(task_name="call"):
for i in tqdm(range(0, self.config.num_samples, self.config.input_shapes["batch_size"])):
inputs = backend.prepare_inputs(self.dataset[i : i + self.config.input_shapes["batch_size"]])
- backend.call(inputs, self.call_kwargs)
+ backend.call(inputs, self.config.call_kwargs)

if self.config.energy:
call_energy = self.energy_tracker.get_energy()
@@ -314,11 +300,10 @@ def run_image_diffusion_tracking(self, backend: Backend[BackendConfigT]):
def run_inference_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running Inference tracking")

- self.reset_trackers()
with self.track(task_name="forward"):
for i in tqdm(range(0, self.config.num_samples, self.config.input_shapes["batch_size"])):
inputs = backend.prepare_inputs(self.dataset[i : i + self.config.input_shapes["batch_size"]])
- backend.forward(inputs, self.forward_kwargs)
+ backend.forward(inputs, self.config.forward_kwargs)

if self.config.energy:
forward_energy = self.energy_tracker.get_energy()
@@ -360,6 +345,6 @@ def dataset_decode_volume(self) -> int: # in terms of generated tokens
@property
def dataset_call_volume(self) -> int: # in terms of generated images
if self.task == "text-to-image":
- return self.config.num_samples * self.call_kwargs["num_images_per_prompt"]
+ return self.config.num_samples * self.config.call_kwargs["num_images_per_prompt"]
else:
return self.config.num_samples
