Skip to content

Commit

Permalink
Fix feature extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed Jul 2, 2024
1 parent 51f8b0d commit df35623
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 23 deletions.
2 changes: 1 addition & 1 deletion lazyslide/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __getitem__(self, idx):
tile = self.reader.get_region(
x, y, self.spec.width, self.spec.height, level=self.spec.level
)
self.cn_func(tile)
tile = self.cn_func(tile)
if self.transform:
tile = self.transform(tile)
if self.targets is not None:
Expand Down
70 changes: 48 additions & 22 deletions lazyslide/tl/features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager
from pathlib import Path
Expand Down Expand Up @@ -34,14 +35,32 @@ def get_default_transform():
return Compose(transforms)


def load_models(model, repo=None, **kwargs):
"""Load a model with timm or torch.hub.load"""
try:
import timm
except ImportError:
# use torch.hub.load
import torch

if repo is None:
repo = "pytorch/vision"
model = torch.hub.load(repo, model, **kwargs)
return model

kwargs = {"pretrained": True, "scriptable": True, **kwargs}
model = timm.create_model(model, **kwargs)
return model


# TODO: Test if it's possible to load model files
def feature_extraction(
wsi: WSI,
model: str | Any,
repo: str = None,
create_opts: dict = None,
model_func: Callable = None,
transform: Callable = None,
scriptable: bool = True,
compile: bool = True,
compile_opts: dict = None,
device: str = "cpu",
Expand All @@ -54,13 +73,23 @@ def feature_extraction(
return_features: bool = False,
**kwargs,
):
"""
Extract features from WSI tiles using a pre-trained model.
Parameters
----------
wsi : WSI
The whole-slide image object.
model : str or Any
The path to the model file or the model object.
"""
try:
import torch
from torch.utils.data import DataLoader
except ImportError:
raise ImportError("Feature extraction requires pytorch and timm (optional).")

try:
if isinstance(model, (str, Path)):
model_path = Path(model)
feature_key = model_path.stem
if model_path.exists():
Expand All @@ -69,35 +98,33 @@ def feature_extraction(
except: # noqa: E722
model = torch.jit.load(model)
else:
try:
import timm
except ImportError:
raise ImportError("Using model from model market requires timm.")
try:
create_opts = {} if create_opts is None else create_opts
feature_key = model
model = timm.create_model(
model, pretrained=True, scriptable=scriptable, **create_opts
)
if transform is None:
# data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
# transform = timm.data.create_transform(**data_cfg)
transform = get_default_transform()
except Exception as _: # noqa: E722
raise ValueError(f"Model {model} not found.")
except: # noqa: E722
create_opts = {} if create_opts is None else create_opts
model = load_models(model, repo=repo, **create_opts)
elif isinstance(model, Callable):
model = model
else:
raise ValueError(
"Model must be a model name, "
"path to the model file, "
"or a model object."
)

if compile:
compile_opts = {} if compile_opts is None else compile_opts
torch.compile(model, **compile_opts)
try:
compile_opts = {} if compile_opts is None else compile_opts
torch.compile(model, **compile_opts)
except Exception as _: # noqa: E722
warnings.warn("Failed to compile the model.", RuntimeWarning)

try:
model = model.to(device)
model.eval()
except: # noqa: E722
pass

if transform is None:
transform = get_default_transform()

if feature_key is None:
if hasattr(model, "__class__"):
feature_key = model.__class__.__name__
Expand Down Expand Up @@ -162,7 +189,6 @@ def model_func(model, image):
with torch.inference_mode():
for batch in loader:
batch = batch.to(device)
print(batch.shape)
output = model_func(model, batch)
features.append(output.cpu().numpy())
pbar.update(task, advance=len(batch))
Expand Down

0 comments on commit df35623

Please sign in to comment.