Skip to content

Commit

Permalink
add option to plot each tissue piece separately in plot_segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
afrendeiro committed Jun 28, 2024
1 parent e714b9f commit 71961a7
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions wsi/wsi.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations
import multiprocessing as mp

from pathlib import Path
import typing as tp
import math
import time
import typing as tp
from pathlib import Path
import multiprocessing as mp

from PIL import Image
import cv2
import numpy as np
import openslide
from PIL import Image
import h5py
import shapely
import torch
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(
WholeSlideImage
WholeSlideImage object.
"""
from .utils import is_url, download_file
from wsi.utils import is_url, download_file

if not isinstance(path, Path):
if is_url(path):
Expand Down Expand Up @@ -303,7 +304,7 @@ def vis_wsi(
number_contours: bool = False,
seg_display: bool = True,
annot_display: bool = True,
) -> None:
) -> Image:
"""
Visualize the whole slide image.
Expand Down Expand Up @@ -338,7 +339,7 @@ def vis_wsi(
Returns
-------
None
Image
"""
downsample = self.level_downsamples[vis_level]
scale = [1 / downsample[0], 1 / downsample[1]]
Expand Down Expand Up @@ -829,7 +830,13 @@ def segment(
self.save_segmentation()
self.plot_segmentation()

def plot_segmentation(self, output_file: tp.Optional[Path] = None, **kwargs) -> None:
def plot_segmentation(
self,
output_file: tp.Optional[Path] = None,
per_contour: bool = False,
level: int | None = None,
**kwargs,
) -> None:
"""
Plot the segmentation of the WSI.
Expand All @@ -852,8 +859,20 @@ def plot_segmentation(self, output_file: tp.Optional[Path] = None, **kwargs) ->
if output_file is None:
output_file = self.path.with_suffix(".segmentation.png")

level = self._get_best_level((2000, 2000))
self.vis_wsi(vis_level=level, **kwargs).save(output_file)
if level is None:
level = self._get_best_level((2000, 2000))
if not per_contour:
self.vis_wsi(vis_level=level, **kwargs).save(output_file)
return

for i, piece in enumerate(self.contours_tissue):
poly = shapely.geometry.Polygon(piece.squeeze())
top_left, bottom_right = np.asarray(poly.bounds).reshape(2, 2).astype(int)
img = self.vis_wsi(
vis_level=level, top_left=top_left, bot_right=bottom_right, **kwargs
)
n = str(i).zfill(3)
img.save(output_file.with_suffix(f".{n}.png"))

def tile(
self,
Expand Down Expand Up @@ -1716,7 +1735,7 @@ def as_tile_bag(self, **kwargs):
kwargs : dict
Additional keyword arguments to pass to wsi.utils.WholeSlideBag.
"""
from .utils import WholeSlideBag
from wsi.utils import WholeSlideBag

# dataset = Whole_Slide_Bag(self.hdf5_file, pretrained=True)
dataset = WholeSlideBag(
Expand Down Expand Up @@ -1748,7 +1767,7 @@ def as_data_loader(
torch.utils.data.DataLoader
"""
from functools import partial
from .utils import collate_features
from wsi.utils import collate_features
from torch.utils.data import DataLoader

collate = partial(collate_features, with_coords=with_coords)
Expand Down Expand Up @@ -1789,7 +1808,7 @@ def inference(
"""
from typing import cast
import torch
from tqdm import tqdm
from tqdm_loggable.auto import tqdm

if isinstance(model, torch.nn.Module):
assert model_name is None, "model_name must be None when model is provided"
Expand Down

0 comments on commit 71961a7

Please sign in to comment.