Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 27, 2024
1 parent db3c8f3 commit 86634fa
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from optimum_benchmark.generators.dataset_generator import DatasetGenerator
from optimum_benchmark.generators.input_generator import InputGenerator
from optimum_benchmark.import_utils import get_git_revision_hash
from optimum_benchmark.scenarios.inference.config import INPUT_SHAPES
from optimum_benchmark.scenarios.training.config import DATASET_SHAPES
from optimum_benchmark.system_utils import is_nvidia_system, is_rocm_system
from optimum_benchmark.trackers import LatencyTracker, MemoryTracker

Expand All @@ -40,16 +38,25 @@
("diffusers", "text-to-image", "CompVis/stable-diffusion-v1-4"),
]

INPUT_SHAPES = {
"batch_size": 2, # for all tasks
"sequence_length": 16, # for text processing tasks
"num_choices": 2, # for multiple-choice task
}

DATASET_SHAPES = {
"dataset_size": 2, # for all tasks
"sequence_length": 16, # for text processing tasks
"num_choices": 2, # for multiple-choice task
}


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("scenario", ["training", "inference"])
@pytest.mark.parametrize("library,task,model", LIBRARIES_TASKS_MODELS)
def test_api_launch(device, scenario, library, task, model):
benchmark_name = f"{device}_{scenario}_{library}_{task}_{model}"

if task == "multiple-choice":
INPUT_SHAPES["num_choices"] = 2

if device == "cuda":
device_isolation = True
if is_rocm_system():
Expand Down Expand Up @@ -173,9 +180,6 @@ def test_api_input_generator(library, task, model):
else:
raise ValueError(f"Unknown library {library}")

if task == "multiple-choice":
INPUT_SHAPES["num_choices"] = 2

input_generator = InputGenerator(
task=task,
input_shapes=INPUT_SHAPES,
Expand Down

0 comments on commit 86634fa

Please sign in to comment.