diff --git a/wsi/wsi.py b/wsi/wsi.py index 880460c..d81d8ac 100755 --- a/wsi/wsi.py +++ b/wsi/wsi.py @@ -1,14 +1,15 @@ from __future__ import annotations -import multiprocessing as mp + +from pathlib import Path +import typing as tp import math import time -import typing as tp -from pathlib import Path +import multiprocessing as mp +from PIL import Image import cv2 import numpy as np import openslide -from PIL import Image import h5py import shapely import torch @@ -84,7 +85,7 @@ def __init__( WholeSlideImage WholeSlideImage object. """ - from .utils import is_url, download_file + from wsi.utils import is_url, download_file if not isinstance(path, Path): if is_url(path): @@ -303,7 +304,7 @@ def vis_wsi( number_contours: bool = False, seg_display: bool = True, annot_display: bool = True, - ) -> None: + ) -> Image: """ Visualize the whole slide image. @@ -338,7 +339,7 @@ def vis_wsi( Returns ------- - None + Image """ downsample = self.level_downsamples[vis_level] scale = [1 / downsample[0], 1 / downsample[1]] @@ -829,7 +830,13 @@ def segment( self.save_segmentation() self.plot_segmentation() - def plot_segmentation(self, output_file: tp.Optional[Path] = None, **kwargs) -> None: + def plot_segmentation( + self, + output_file: tp.Optional[Path] = None, + per_contour: bool = False, + level: int | None = None, + **kwargs, + ) -> None: """ Plot the segmentation of the WSI. @@ -852,8 +859,20 @@ def plot_segmentation(self, output_file: tp.Optional[Path] = None, **kwargs) -> if output_file is None: output_file = self.path.with_suffix(".segmentation.png") - level = self._get_best_level((2000, 2000)) - self.vis_wsi(vis_level=level, **kwargs).save(output_file) + if level is None: + level = self._get_best_level((2000, 2000)) + if not per_contour: + self.vis_wsi(vis_level=level, **kwargs).save(output_file) + return + + for i, piece in enumerate(self.contours_tissue): + poly = shapely.geometry.Polygon(piece.squeeze()) + top_left, bottom_right = np.asarray(poly.bounds).reshape(2, 2).astype(int) + img = self.vis_wsi( + vis_level=level, top_left=top_left, bot_right=bottom_right, **kwargs + ) + n = str(i).zfill(3) + img.save(output_file.with_suffix(f".{n}.png")) def tile( self, @@ -1716,7 +1735,7 @@ def as_tile_bag(self, **kwargs): kwargs : dict Additional keyword arguments to pass to wsi.utils.WholeSlideBag. """ - from .utils import WholeSlideBag + from wsi.utils import WholeSlideBag # dataset = Whole_Slide_Bag(self.hdf5_file, pretrained=True) dataset = WholeSlideBag( @@ -1748,7 +1767,7 @@ def as_data_loader( torch.utils.data.DataLoader """ from functools import partial - from .utils import collate_features + from wsi.utils import collate_features from torch.utils.data import DataLoader collate = partial(collate_features, with_coords=with_coords) @@ -1789,7 +1808,7 @@ def inference( """ from typing import cast import torch - from tqdm import tqdm + from tqdm_loggable.auto import tqdm if isinstance(model, torch.nn.Module): assert model_name is None, "model_name must be None when model is provided"