From 501c42df853a19e8a1d4d6ec7c2ccca5da4309c8 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Mon, 9 Dec 2024 04:29:28 +0100
Subject: [PATCH] fix

---
 optimum_benchmark/backends/ipex/backend.py |  1 -
 tests/test_api.py                          | 18 ++++++++----------
 2 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py
index b584ff6c..7e4983a9 100644
--- a/optimum_benchmark/backends/ipex/backend.py
+++ b/optimum_benchmark/backends/ipex/backend.py
@@ -63,7 +63,6 @@ def _load_ipexmodel_from_pretrained(self) -> None:
         self.pretrained_model = self.ipexmodel_class.from_pretrained(
             self.config.model,
             export=self.config.export,
-            device=self.config.device,
             **self.config.model_kwargs,
             **self.automodel_kwargs,
         )
diff --git a/tests/test_api.py b/tests/test_api.py
index 34598a02..9122eb49 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -229,10 +229,10 @@ def test_api_dataset_generator(library, task, model):
 def test_api_latency_tracker(device, backend):
     tracker = LatencyTracker(device=device, backend=backend)
 
-    tracker.reset()
-    while tracker.elapsed() < 2:
-        with tracker.track():
-            time.sleep(1)
+    with tracker.session():
+        while tracker.elapsed() < 2:
+            with tracker.track():
+                time.sleep(1)
 
     latency = tracker.get_latency()
     latency.log()
@@ -241,10 +241,10 @@
     assert latency.mean > 0.9
     assert len(latency.values) == 2
 
-    tracker.reset()
-    while tracker.count() < 2:
-        with tracker.track():
-            time.sleep(1)
+    with tracker.session():
+        while tracker.count() < 2:
+            with tracker.track():
+                time.sleep(1)
 
     latency = tracker.get_latency()
     latency.log()
@@ -273,7 +273,6 @@ def test_api_memory_tracker(device, backend):
 
     tracker = MemoryTracker(device=device, backend=backend, device_ids=device_ids)
 
-    tracker.reset()
     with tracker.track():
         time.sleep(1)
         pass
@@ -281,7 +280,6 @@
     initial_memory = tracker.get_max_memory()
     initial_memory.log()
 
-    tracker.reset()
     with tracker.track():
         array = torch.randn((10000, 10000), dtype=torch.float64, device=device)
         expected_memory = array.nbytes / 1e6
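
The test changes above replace the trackers' reset()-based flow with a
session() context manager that scopes one measurement run. A minimal
standalone sketch of the new latency-tracker usage, assuming the import
path below (the patch does not show it) and placeholder device/backend
values ("cpu" and "pytorch"; the tests receive these as fixtures):

    import time

    from optimum_benchmark.trackers.latency import LatencyTracker  # assumed location

    # Placeholder settings; any device/backend pair the tracker supports works.
    tracker = LatencyTracker(device="cpu", backend="pytorch")

    # session() plays the role the old reset() call did: count() and
    # elapsed() are evaluated against the currently active session.
    with tracker.session():
        while tracker.count() < 2:
            with tracker.track():
                time.sleep(1)  # each track() block records one latency value

    latency = tracker.get_latency()
    latency.log()

    assert len(latency.values) == 2  # one value per track() block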