From db3c8f3dc7230e6de04cfbc8fed59be91723f12e Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 27 Nov 2024 12:12:10 +0100
Subject: [PATCH] fix

---
 .../scenarios/inference/scenario.py | 26 ++++++++++++-------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py
index e05cb7b9..c7faffed 100644
--- a/optimum_benchmark/scenarios/inference/scenario.py
+++ b/optimum_benchmark/scenarios/inference/scenario.py
@@ -66,15 +66,17 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             self.logger.info("\t+ Updating Text Generation kwargs with default values")
             self.config.generate_kwargs = {**TEXT_GENERATION_DEFAULT_KWARGS, **self.config.generate_kwargs}
             self.logger.info("\t+ Initializing Text Generation report")
-            self.report = BenchmarkReport.from_list(targets=["load", "prefill", "decode", "per_token"])
+            self.report = BenchmarkReport.from_list(targets=["load_model", "prefill", "decode", "per_token"])
         elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
             self.logger.info("\t+ Updating Image Diffusion kwargs with default values")
             self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}
             self.logger.info("\t+ Initializing Image Diffusion report")
-            self.report = BenchmarkReport.from_list(targets=["load", "call"])
+            self.report = BenchmarkReport.from_list(targets=["load_model", "call"])
         else:
             self.logger.info("\t+ Initializing Inference report")
-            self.report = BenchmarkReport.from_list(targets=["load", "forward"])
+            self.report = BenchmarkReport.from_list(targets=["load_model", "forward"])
+
+        self.run_model_loading_tracking(backend)
 
         self.logger.info("\t+ Creating input generator")
         self.input_generator = InputGenerator(
@@ -83,15 +85,11 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             input_shapes=self.config.input_shapes,
             model_type=backend.config.model_type,
         )
-
         self.logger.info("\t+ Generating inputs")
         self.inputs = self.input_generator()
-
-        self.logger.info("\t+ Preparing inputs for Inference")
+        self.logger.info("\t+ Preparing inputs for backend")
         self.inputs = backend.prepare_inputs(inputs=self.inputs)
 
-        self.run_model_loading_tracking(backend)
-
         if self.config.latency or self.config.energy:
             # latency and energy are metrics that require some warmup
             if self.config.warmup_runs > 0:
@@ -159,8 +157,14 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
             )
         if self.config.latency:
             latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
+        if self.config.energy:
+            energy_tracker = EnergyTracker(
+                backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
+            )
 
         with ExitStack() as context_stack:
+            if self.config.energy:
+                context_stack.enter_context(energy_tracker.track())
             if self.config.memory:
                 context_stack.enter_context(memory_tracker.track())
             if self.config.latency:
@@ -169,9 +173,11 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
             backend.load()
 
         if self.config.latency:
-            self.report.load.latency = latency_tracker.get_latency()
+            self.report.load_model.latency = latency_tracker.get_latency()
         if self.config.memory:
-            self.report.load.memory = memory_tracker.get_max_memory()
+            self.report.load_model.memory = memory_tracker.get_max_memory()
+        if self.config.energy:
+            self.report.load_model.energy = energy_tracker.get_energy()
 
     ## Memory tracking
     def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
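
The run_model_loading_tracking hunks above compose the optional trackers around backend.load() with contextlib.ExitStack: each tracker's track() context manager is entered only when the corresponding metric is enabled in the scenario config. Below is a minimal runnable sketch of that pattern; ElapsedTracker and load_with_tracking are illustrative stand-ins, not classes or functions from optimum-benchmark.

from contextlib import ExitStack, contextmanager
import time


class ElapsedTracker:
    # Illustrative stand-in for LatencyTracker/MemoryTracker/EnergyTracker:
    # it exposes a track() context manager plus a getter, which is the
    # interface the patch relies on.
    def __init__(self):
        self.elapsed = None

    @contextmanager
    def track(self):
        start = time.perf_counter()
        yield
        self.elapsed = time.perf_counter() - start

    def get_elapsed(self):
        return self.elapsed


def load_with_tracking(load_fn, track_elapsed=True):
    # Mirror the patched method: instantiate a tracker only for enabled
    # metrics, enter everything through a single ExitStack, then run the
    # load step inside whatever contexts were entered.
    tracker = ElapsedTracker() if track_elapsed else None
    with ExitStack() as stack:
        if track_elapsed:
            stack.enter_context(tracker.track())
        load_fn()
    return tracker.get_elapsed() if track_elapsed else None


if __name__ == "__main__":
    print(load_with_tracking(lambda: time.sleep(0.1)))  # roughly 0.1 seconds

ExitStack keeps backend.load() at a single indentation level no matter how many trackers are active, and it unwinds every entered context even if the load step raises.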