Skip to content

Commit

Permalink
improve docs; more customization of kwargs in inference
Browse files Browse the repository at this point in the history
  • Loading branch information
afrendeiro committed Jun 5, 2024
1 parent c97d9b7 commit e714b9f
Showing 1 changed file with 63 additions and 16 deletions.
79 changes: 63 additions & 16 deletions wsi/wsi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import multiprocessing as mp
import math
import time
Expand All @@ -9,6 +10,8 @@
import openslide
from PIL import Image
import h5py
import shapely
import torch

from wsi.utils import (
isInContourV1,
Expand Down Expand Up @@ -302,10 +305,10 @@ def vis_wsi(
annot_display: bool = True,
) -> None:
"""
Visualize the whole slide image.
Visualize the whole slide image.
Parameters
----------
Parameters
----------
vis_level: int
The level to visualize.
color: tuple
Expand Down Expand Up @@ -333,9 +336,9 @@ def vis_wsi(
annot_display: bool
Whether to display the annotations.
Returns
-------
None
Returns
-------
None
"""
downsample = self.level_downsamples[vis_level]
scale = [1 / downsample[0], 1 / downsample[1]]
Expand Down Expand Up @@ -619,7 +622,6 @@ def _segment_tissue_manual(
"""
import skimage
import scipy.ndimage as ndi
import shapely

assert color_space in ["RGB", "HED"], "color_space must be RGB or HED."

Expand Down Expand Up @@ -1136,8 +1138,8 @@ def get_tile_images(
Returns
-------
np.ndarray
Array of tile images with shape (N, 3, H, W).
generator
Each element is an array with shape (3, H, W).
"""
if hdf5_file is None:
hdf5_file = self.hdf5_file # or self.tile_h5
Expand Down Expand Up @@ -1702,31 +1704,68 @@ def _get_seg_mask(self, region_size, scale, use_holes=False, offset=(0, 0)):
)
return tissue_mask

def as_tile_bag(self, **kwargs):
    """
    Return a torch dataset of tiles from the whole slide image.

    The dataset can be customized through keyword arguments, for example
    to choose which transform functions are applied.
    Check wsi.utils.WholeSlideBag for more details.

    Parameters
    ----------
    kwargs : dict
        Additional keyword arguments to pass to wsi.utils.WholeSlideBag.
        ``pretrained`` and ``target`` are already supplied by this method,
        so passing either of them here raises a TypeError (duplicate
        keyword argument).

    Returns
    -------
    wsi.utils.WholeSlideBag
        Dataset built from ``self.hdf5_file`` and ``self.wsi``.
    """
    # Imported lazily, as in the rest of this module's optional helpers.
    from .utils import WholeSlideBag

    dataset = WholeSlideBag(
        self.hdf5_file, self.wsi, pretrained=True, target=self.target, **kwargs
    )
    return dataset

def as_data_loader(
    self,
    batch_size: int = 32,
    with_coords: bool = False,
    tile_bag_kwargs: dict | None = None,
    data_loader_kwargs: dict | None = None,
) -> torch.utils.data.DataLoader:
    """
    Return a data loader for the whole slide image.

    Parameters
    ----------
    batch_size : int
        Number of images per batch in data loader. Default is 32.
    with_coords : bool
        Whether to include coordinates in data loader. Default is False.
    tile_bag_kwargs : dict, optional
        Additional keyword arguments forwarded to ``self.as_tile_bag``.
        Default is None, meaning no extra arguments.
    data_loader_kwargs : dict, optional
        Additional keyword arguments to pass to torch.utils.data.DataLoader.
        ``dataset``, ``batch_size`` and ``collate_fn`` are already supplied
        by this method, so passing any of them here raises a TypeError.
        Default is None, meaning no extra arguments.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    from functools import partial
    from .utils import collate_features
    from torch.utils.data import DataLoader

    # Use None as the default instead of a shared ``{}`` literal: a mutable
    # default is created once and would be shared across every call.
    tile_bag_kwargs = {} if tile_bag_kwargs is None else tile_bag_kwargs
    data_loader_kwargs = {} if data_loader_kwargs is None else data_loader_kwargs

    # Bind ``with_coords`` so DataLoader can call the collate function with
    # just the batch.
    collate = partial(collate_features, with_coords=with_coords)

    dataset = self.as_tile_bag(**tile_bag_kwargs)
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate,
        **data_loader_kwargs,
    )
    return loader

def inference(
self,
model_name: str,
model: torch.nn.Module | None = None,
model_name: str | None = None,
model_repo: str = "pytorch/vision",
device: str | None = None,
data_loader_kws: dict = {},
Expand All @@ -1748,15 +1787,23 @@ def inference(
Tuple[np.ndarray, np.ndarray]
Tuple of (features, coordinates).
"""
from typing import cast
import torch
from tqdm import tqdm

if isinstance(model, torch.nn.Module):
assert model_name is None, "model_name must be None when model is provided"
model = cast(torch.nn.Module, model)
elif model_name is not None:
assert model is None, "model must be None when model_name is provided"
model = torch.hub.load(model_repo, model_name, weights="DEFAULT")

if device is None:
device = device or "cuda" if torch.cuda.is_available() else "cpu"

data_loader = self.as_data_loader(**data_loader_kws, with_coords=True)
model = torch.hub.load(model_repo, model_name, weights="DEFAULT").to(device)
model.eval()
model = model.to(device)
coords = list()
feats = list()
for batch, coord in tqdm(data_loader):
Expand Down

0 comments on commit e714b9f

Please sign in to comment.