
Commit

simplify signature of .inference()
afrendeiro committed Jul 8, 2024
1 parent 9e3b866 commit 0a220aa
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions wsi/wsi.py
@@ -1819,8 +1819,7 @@ def as_data_loader(
 
     def inference(
         self,
-        model: torch.nn.Module | None = None,
-        model_name: str | None = None,
+        model: torch.nn.Module | str | None = None,
         model_repo: str = "pytorch/vision",
         device: str | None = None,
         data_loader_kws: dict = {},
@@ -1847,11 +1846,13 @@ def inference(
         from tqdm_loggable.auto import tqdm
 
         if isinstance(model, torch.nn.Module):
-            assert model_name is None, "model_name must be None when model is provided"
             model = cast(torch.nn.Module, model)
-        elif model_name is not None:
-            assert model is None, "model must be None when model_name is provided"
-            model = torch.hub.load(model_repo, model_name, weights="DEFAULT")
+        elif isinstance(model, str):
+            model = torch.hub.load(model_repo, model, weights="DEFAULT")
+        else:
+            raise ValueError(
+                f"model must be a string or a torch.nn.Module, not {type(model)}"
+            )
 
         if device is None:
             device = device or "cuda" if torch.cuda.is_available() else "cpu"
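For context, the simplified signature means callers pass either an already-built torch.nn.Module or a torch.hub model name through the single model argument; any other value (including the default None) now raises ValueError. A minimal usage sketch, assuming a slide object that exposes this method (the wrapping class in wsi/wsi.py and its constructor are not shown in this diff, so the names below are hypothetical):

    import torch
    import torchvision

    # Hypothetical: whatever object in wsi/wsi.py provides .inference()
    slide = ...

    # Option 1: pass a model name as a string; it is resolved internally via
    # torch.hub.load(model_repo, model, weights="DEFAULT")
    output = slide.inference(model="resnet18", model_repo="pytorch/vision")

    # Option 2: pass an instantiated torch.nn.Module directly
    net = torchvision.models.resnet18(weights="DEFAULT")
    output = slide.inference(model=net)

    # Anything else for `model` (including omitting it) raises ValueError,
    # replacing the previous model/model_name pair of arguments.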
