Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Feb 19, 2024
1 parent 8029a71 commit ada6525
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,23 @@ def test_api_dataset_generator(library, task, model):
_ = generator()


@pytest.mark.parametrize("launcher_config", LAUNCHER_CONFIGS)
@pytest.mark.parametrize("device", DEVICES)
def test_api_launch(launcher_config, device):
@pytest.mark.parametrize("launcher_config", LAUNCHER_CONFIGS)
def test_api_launch(device, launcher_config):
benchmark_config = InferenceConfig(latency=True, memory=True)
device_ids = ",".join(str(i) for i in range(torch.cuda.device_count())) if device == "cuda" else None
backend_config = PyTorchConfig(model="bert-base-uncased", device_ids=device_ids, no_weights=True, device=device)
backend_config = PyTorchConfig(
model="bert-base-uncased",
device_ids="0,1" if device == "cuda" else None,
no_weights=True,
device=device,
)
experiment_config = ExperimentConfig(
experiment_name="api-experiment", benchmark=benchmark_config, launcher=launcher_config, backend=backend_config
experiment_name="api-experiment",
benchmark=benchmark_config,
launcher=launcher_config,
backend=backend_config,
)

benchmark_report = launch(experiment_config)

with TemporaryDirectory() as tempdir:
Expand Down

0 comments on commit ada6525

Please sign in to comment.