diff --git a/.github/workflows/pytest_workflow.yml b/.github/workflows/pytest_workflow.yml
new file mode 100644
index 0000000..6d1de60
--- /dev/null
+++ b/.github/workflows/pytest_workflow.yml
@@ -0,0 +1,27 @@
+name: Pytest testing
+
+on: [push]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install system dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y openslide-tools
+
+      - name: Install Python dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .
+      - name: Test with pytest
+        run: |
+          pip install pytest pytest-cov
+          pytest wsi --doctest-modules --junitxml=junit/test-results.xml --cov=wsi --cov-report=xml --cov-report=html
diff --git a/.gitignore b/.gitignore
index 775c4b7..02f9b90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,13 @@ __pycache__/
 *.egg-info
 build/
+dist
+.coverage
+cache
+junit
+joblib
+__pycache__
+.mypy_cache
+coverage.xml
+_version.py
+*.sublime-*
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..d83a878
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,24 @@
+clean:
+	-rm -rf build
+	-rm -rf dist
+	-rm -rf *.egg-info
+	-rm -rf .coverage
+	-rm -rf cache
+	-rm -rf junit
+	-rm -rf joblib
+	-rm -rf __pycache__
+	-rm -rf .mypy_cache
+	-rm -rf htmlcov
+	-rm coverage.xml
+	# -rm -rf .pytest_cache
+
+test: clean
+	pytest wsi \
+		--doctest-modules \
+		--junitxml=junit/test-results.xml \
+		--cov=wsi \
+		--cov-report=xml \
+		--cov-report=html
+
+install: clean
+	python -m pip install -e .
diff --git a/README.md b/README.md
index 61e5a95..d386e63 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-CLAM
+WSI
 ====
 
 This is a fork of the repository from [Mahmood lab's CLAM repository](https://github.com/mahmoodlab/CLAM).
 It is made available under the GPLv3 License and is available for non-commercial academic purposes.
@@ -8,7 +8,7 @@ It is made available under the GPLv3 License and is available for non-commercial
 
 The purpose of the fork is to compartmentalize the features related to the processing of whole-slide images (WSI) from the CLAM model.
 
-The package has been renamed to `wsi_core` as that was the name of the module related with whole slide image processing.
+The package has been renamed to `wsi`.
 
 ## Installation
 
 While the repository is private, make sure you [exchange SSH keys of the machine
 
 Then simply install with `pip`:
 
 ```bash
-git clone git@github.com:rendeirolab/CLAM.git
-cd CLAM
+# pip install git+ssh://git@github.com/rendeirolab/wsi.git
+git clone git@github.com:rendeirolab/wsi.git
+cd wsi
 pip install .
 ```
 
@@ -26,6 +27,21 @@ Note that the package uses setuptools-scm for version control and therefore the i
 
 ## Usage
 
+The only exposed class, `WholeSlideImage`, enables all the functionality of the package.
+
+### Quick start - segmentation, tiling and feature extraction
+```python
+from wsi import WholeSlideImage
+
+url = "https://brd.nci.nih.gov/brd/imagedownload/GTEX-O5YU-1426"
+slide = WholeSlideImage(url)
+slide.segment()
+slide.tile()
+feats, coords = slide.inference("resnet18")
+```
+
+### Full example
+
 This package is meant for both interactive use and for use in a pipeline at scale.
 By default actions do not return anything, but instead save the results to disk in files relative to the slide file.
 
 All major functions have sensible defaults but allow for customization.
 Please check the docstring of each function for more information.
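+As a minimal sketch of that behavior (the file name below is a placeholder), the default on-disk artifacts look like this:
+
+```python
+from wsi import WholeSlideImage
+
+slide = WholeSlideImage("slide.svs")
+slide.segment()  # writes slide.h5 (contours) and slide.segmentation.png
+slide.tile()     # appends tile coordinates to slide.h5
+
+# A fresh object pointing at the same file picks the saved results up from disk:
+slide = WholeSlideImage("slide.svs")
+slide.load_segmentation()
+```
+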
 ```python
-from wsi_core import WholeSlideImage
-from wsi_core.utils import Path
+from wsi import WholeSlideImage
+from wsi.utils import Path
 
 # Get example slide image
 slide_file = Path("GTEX-12ZZW-2726.svs")
@@ -48,7 +64,7 @@ if not slide_file.exists():
 
 # Instantiate slide object
 slide = WholeSlideImage(slide_file)
 
-# Instantiate slide object
+# Instantiation can be done with custom attributes
 slide = WholeSlideImage(slide_file, attributes=dict(donor="GTEX-12ZZW"))
 
 # Segment tissue (segmentation mask is stored as polygons in slide.contours_tissue)
@@ -75,15 +91,28 @@ for img in images:
 slide.save_tile_images(output_dir=slide_file.parent / (slide_file.stem + "_tiles"))
 
 # Use in a torch dataloader
-loader = slide.as_data_loader()
+loader = slide.as_data_loader(with_coords=True)
 
-# Extract features
+# Extract features "manually"
+import numpy as np
 import torch
 from tqdm import tqdm
-model = torch.hub.load("pytorch/vision", "resnet50", pretrained=True)
-for count, (batch, coords) in tqdm(enumerate(loader), total=len(loader)):
+model = torch.hub.load("pytorch/vision", "resnet18", weights="DEFAULT")
+feats = list()
+coords = list()
+for count, (batch, yx) in tqdm(enumerate(loader), total=len(loader)):
     with torch.no_grad():
-        features = model(batch).numpy()
+        f = model(batch).numpy()
+        feats.append(f)
+        coords.append(yx)
+
+feats = np.concatenate(feats, axis=0)
+coords = np.concatenate(coords, axis=0)
+
+# Extract features "automatically"
+feats, coords = slide.inference("resnet18")
+
+# Additional parameters can also be specified
+feats, coords = slide.inference("resnet18", device="cuda", data_loader_kws=dict(batch_size=512))
 ```
 
## Reference
diff --git a/pyproject.toml b/pyproject.toml
index a88a2fb..fec58ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 # PIP, using PEP621
 [project]
-name = "wsi_core"
+name = "wsi"
 authors = [
     {name = "Andre Rendeiro", email = "arendeiro@cemm.at"},
 ]
@@ -11,8 +11,8 @@ keywords = [
 ]
 classifiers = [
     "Programming Language :: Python :: 3 :: Only",
-    "Programming Language :: Python :: 3.7",
-    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
     "Development Status :: 3 - Alpha",
     "Typing :: Typed",
     "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
@@ -21,14 +21,21 @@
 #license = "gpl3"
 requires-python = ">=3.10"
 dependencies = [
-    "opencv-python",
     "h5py",
     "matplotlib",
     "numpy",
+    "opencv-python",
     "openslide-python",
+    "pandas",
     "Pillow",
+    "requests",
+    "scikit-image",
+    "scikit-learn",
+    "scipy",
+    "shapely",
     "torch",
     "torchvision",
+    "tqdm",
 ]
 dynamic = ['version']
@@ -51,9 +58,9 @@ doc = [
 ]
 
 [project.urls]
-homepage = "https://github.com/rendeirolab/CLAM"
-documentation = "https://github.com/rendeirolab/CLAM/blob/main/README.md"
-repository = "https://github.com/rendeirolab/CLAM"
+homepage = "https://github.com/rendeirolab/wsi"
+documentation = "https://github.com/rendeirolab/wsi/blob/main/README.md"
+repository = "https://github.com/rendeirolab/wsi"
 
 [build-system]
 # requires = ["poetry>=0.12", "setuptools>=45", "wheel", "poetry-dynamic-versioning"]
 requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.0"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools_scm]
-write_to = "wsi_core/_version.py"
+write_to = "wsi/_version.py"
 write_to_template = 'version = __version__ = "{version}"'
 
 [tool.black]
@@ -104,7 +111,7 @@ module = [
     'matplotlib.*',
     'networkx.*',
     #
-    'wsi_core.*'
+    'wsi.*'
 ]
 ignore_missing_imports = true
@@ -117,5 +124,5 @@ testpaths = [ ] markers = [ 'slow', # 'marks tests as slow (deselect with "-m 'not slow'")', - 'serial' -] \ No newline at end of file + "wsi" +] diff --git a/requirements.txt b/requirements.txt index 39cbd2d..dc28c46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,15 @@ -opencv-python h5py matplotlib numpy +opencv-python openslide-python +pandas Pillow +requests +scikit-image +scikit-learn +scipy +shapely torch torchvision +tqdm diff --git a/wsi/__init__.py b/wsi/__init__.py new file mode 100644 index 0000000..c5fcaf1 --- /dev/null +++ b/wsi/__init__.py @@ -0,0 +1,2 @@ +from .wsi import WholeSlideImage +from ._version import version, __version__ diff --git a/wsi/tests/test_wsi.py b/wsi/tests/test_wsi.py new file mode 100644 index 0000000..5dd6f5f --- /dev/null +++ b/wsi/tests/test_wsi.py @@ -0,0 +1,44 @@ +from pathlib import Path +import tempfile +import joblib + +import requests +import pytest +from wsi import WholeSlideImage +import numpy as np + + +mem = joblib.Memory("cache", verbose=0) + + +@pytest.fixture(scope="session") +@mem.cache +def get_test_slide(): + slide_file = Path("GTEX-O5YU-1426.svs") + if not slide_file.exists(): + url = f"https://brd.nci.nih.gov/brd/imagedownload/{slide_file.stem}" + slide_file = Path(tempfile.NamedTemporaryFile(suffix=".svs").name) + + with open(slide_file, "wb") as file: + for chunk in requests.get(url, stream=True).iter_content(chunk_size=1024 * 4): + file.write(chunk) + else: + for f in sorted(Path().glob(slide_file.stem + "*")): + if f != slide_file: + f.unlink() + return slide_file + + +@pytest.mark.wsi +@pytest.mark.slow +def test_whole_slide_image_inference(get_test_slide): + slide = WholeSlideImage(get_test_slide) + slide.segment() + assert len(slide.contours_tissue) == len(slide.holes_tissue) + slide.tile() + feats, coords = slide.inference("resnet18") + + # Assert conditions + assert coords.shape == (654, 2), "Coords shape mismatch" + print(feats.sum()) + assert np.allclose(feats.sum(), 14.555267, atol=1e-3), "Features sum mismatch" diff --git a/wsi/utils.py b/wsi/utils.py new file mode 100755 index 0000000..963001e --- /dev/null +++ b/wsi/utils.py @@ -0,0 +1,285 @@ +import typing as tp +from pathlib import Path +import tempfile + +import requests +import h5py +import numpy as np +import cv2 +import torch +from torch.utils.data import Dataset +from torchvision import transforms + + +class WholeSlideBag(Dataset): + def __init__( + self, + file_path, + wsi, + pretrained=False, + custom_transforms=None, + custom_downsample=1, + target_patch_size=-1, + target=None, + ): + """ + Args: + file_path (string): Path to the .h5 file containing patched data. 
+ wsi (openslide object): OpenSlide object + pretrained (bool): Use ImageNet transforms + custom_transforms (callable, optional): Optional transform to be applied on a sample + custom_downsample (int): Custom defined downscale factor (overruled by target_patch_size) + target_patch_size (int): Custom defined image size before embedding + """ + self.target = target + + self.pretrained = pretrained + self.wsi = wsi + if not custom_transforms: + self.roi_transforms = default_transforms(pretrained=pretrained) + else: + self.roi_transforms = custom_transforms + + self.file_path = file_path + + with h5py.File(self.file_path, "r") as f: + dset = f["coords"] + self.patch_level = f["coords"].attrs["patch_level"] + self.patch_size = f["coords"].attrs["patch_size"] + self.length = len(dset) + if target_patch_size > 0: + self.target_patch_size = (target_patch_size,) * 2 + elif custom_downsample > 1: + self.target_patch_size = (self.patch_size // custom_downsample,) * 2 + else: + self.target_patch_size = None + # self.summary() + + def __len__(self): + return self.length + + def summary(self): + hdf5_file = h5py.File(self.file_path, "r") + dset = hdf5_file["coords"] + for name, value in dset.attrs.items(): + print(name, value) + + # print("\nfeature extraction settings") + # print("target patch size: ", self.target_patch_size) + # print("pretrained: ", self.pretrained) + # print("transformations: ", self.roi_transforms) + + def __getitem__(self, idx): + with h5py.File(self.file_path, "r") as hdf5_file: + coord = hdf5_file["coords"][idx] + img = self.wsi.read_region( + coord, self.patch_level, (self.patch_size, self.patch_size) + ).convert("RGB") + + if self.target_patch_size is not None: + img = img.resize(self.target_patch_size) + img = self.roi_transforms(img).unsqueeze(0) + if self.target is None: + return img, coord + return img, self.target + + +class ContourCheckingFn(object): + # Defining __call__ method + def __call__(self, pt): + raise NotImplementedError + + +class isInContourV1(ContourCheckingFn): + def __init__(self, contour): + self.cont = contour + + def __call__(self, pt): + return ( + 1 + if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) + >= 0 + else 0 + ) + + +class isInContourV2(ContourCheckingFn): + def __init__(self, contour, patch_size): + self.cont = contour + self.patch_size = patch_size + + def __call__(self, pt): + pt = np.array( + (pt[0] + self.patch_size // 2, pt[1] + self.patch_size // 2) + ).astype(float) + return ( + 1 + if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) + >= 0 + else 0 + ) + + +# Easy version of 4pt contour checking function - 1 of 4 points need to be in the contour for test to pass +class isInContourV3_Easy(ContourCheckingFn): + def __init__(self, contour, patch_size, center_shift=0.5): + self.cont = contour + self.patch_size = patch_size + self.shift = int(patch_size // 2 * center_shift) + + def __call__(self, pt): + center = (pt[0] + self.patch_size // 2, pt[1] + self.patch_size // 2) + if self.shift > 0: + all_points = [ + (center[0] - self.shift, center[1] - self.shift), + (center[0] + self.shift, center[1] + self.shift), + (center[0] + self.shift, center[1] - self.shift), + (center[0] - self.shift, center[1] + self.shift), + ] + else: + all_points = [center] + + for points in all_points: + if ( + cv2.pointPolygonTest( + self.cont, tuple(np.array(points).astype(float)), False + ) + >= 0 + ): + return 1 + return 0 + + +# Hard version of 4pt contour checking function - all 4 points need to be in the 
contour for test to pass +class isInContourV3_Hard(ContourCheckingFn): + def __init__(self, contour, patch_size, center_shift=0.5): + self.cont = contour + self.patch_size = patch_size + self.shift = int(patch_size // 2 * center_shift) + + def __call__(self, pt): + center = (pt[0] + self.patch_size // 2, pt[1] + self.patch_size // 2) + if self.shift > 0: + all_points = [ + (center[0] - self.shift, center[1] - self.shift), + (center[0] + self.shift, center[1] + self.shift), + (center[0] + self.shift, center[1] - self.shift), + (center[0] - self.shift, center[1] + self.shift), + ] + else: + all_points = [center] + + for points in all_points: + if ( + cv2.pointPolygonTest( + self.cont, tuple(np.array(points).astype(float)), False + ) + < 0 + ): + return 0 + return 1 + + +def filter_kwargs_by_callable( + kwargs: tp.Dict[str, tp.Any], + callabl: tp.Callable, + exclude: tp.List[str] | None = None, +) -> tp.Dict[str, tp.Any]: + """Filter a dictionary keeping only the keys which are part of a function signature.""" + from inspect import signature + + args = signature(callabl).parameters.keys() + return {k: v for k, v in kwargs.items() if (k in args) and k not in (exclude or [])} + + +def screen_coords(scores, coords, top_left, bot_right): + bot_right = np.array(bot_right) + top_left = np.array(top_left) + mask = np.logical_and( + np.all(coords >= top_left, axis=1), np.all(coords <= bot_right, axis=1) + ) + scores = scores[mask] + coords = coords[mask] + return scores, coords + + +def to_percentiles(scores): + from scipy.stats import rankdata + + scores = rankdata(scores, "average") / len(scores) * 100 + return scores + + +def collate_features(batch, with_coords: bool = False): + img = torch.cat([item[0] for item in batch], dim=0) + if not with_coords: + return img + coords = np.vstack([item[1] for item in batch]) + return [img, coords] + + +def is_url(url: str | Path) -> bool: + if isinstance(url, Path): + url = url.as_posix() + return url.startswith("http") + + +def download_file( + url: str, dest: Path | str | None = None, overwrite: bool = False +) -> Path: + if dest is None: + dest = Path(tempfile.NamedTemporaryFile().name) + + if Path(dest).exists() and not overwrite: + return Path(dest) + + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(dest, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + return Path(dest) + + +def default_transforms(pretrained=False): + if pretrained: + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + else: + mean = (0.5, 0.5, 0.5) + std = (0.5, 0.5, 0.5) + + trnsfrms_val = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)] + ) + + return trnsfrms_val + + +def save_hdf5(output_path, asset_dict, attr_dict=None, mode="a"): + file = h5py.File(output_path, mode) + for key, val in asset_dict.items(): + data_shape = val.shape + if key not in file: + data_type = val.dtype + chunk_shape = (1,) + data_shape[1:] + maxshape = (None,) + data_shape[1:] + dset = file.create_dataset( + key, + shape=data_shape, + maxshape=maxshape, + chunks=chunk_shape, + dtype=data_type, + ) + dset[:] = val + if attr_dict is not None: + if key in attr_dict.keys(): + for attr_key, attr_val in attr_dict[key].items(): + dset.attrs[attr_key] = attr_val + else: + dset = file[key] + dset.resize(len(dset) + data_shape[0], axis=0) + dset[-data_shape[0] :] = val + file.close() + return output_path diff --git a/wsi_core/WholeSlideImage.py b/wsi/wsi.py similarity index 62% rename from 
wsi_core/WholeSlideImage.py rename to wsi/wsi.py index 340f294..b2508b3 100755 --- a/wsi_core/WholeSlideImage.py +++ b/wsi/wsi.py @@ -1,47 +1,45 @@ -import multiprocessing as mp +from __future__ import annotations + +from pathlib import Path +import typing as tp import math -import os import time -from xml.dom import minidom -import typing as tp -from pathlib import Path as _Path +import multiprocessing as mp +from PIL import Image import cv2 -import matplotlib.pyplot as plt import numpy as np import openslide -from PIL import Image import h5py +import shapely +import torch -from wsi_core.wsi_utils import ( - savePatchIter_bag_hdf5, - initialize_hdf5_bag, - save_hdf5, - screen_coords, - isBlackPatch, - isWhitePatch, - to_percentiles, -) -from wsi_core.util_classes import ( +from wsi.utils import ( isInContourV1, isInContourV2, isInContourV3_Easy, isInContourV3_Hard, - Contour_Checking_fn, + ContourCheckingFn, + save_hdf5, + screen_coords, + to_percentiles, + filter_kwargs_by_callable, ) -from wsi_core.file_utils import load_pkl, save_pkl -from wsi_core.utils import Path, filter_kwargs_by_callable Image.MAX_IMAGE_PIXELS = 933120000 +# TODO: replace contours_tumor with a generic label field +# TODO: make function to plot contours (colored by label field) +# TODO: write segmentations to geojson + + class WholeSlideImage(object): def __init__( self, - path: Path | _Path | str, + path: Path | str, *, attributes: tp.Optional[dict[str, tp.Any]] = None, - mask_file: Path | None = None, hdf5_file: Path | None = None, ): """ @@ -50,30 +48,60 @@ def __init__( Parameters ---------- path: Path - Path to WSI file. + Path to WSI file or URL. + If URL is given, the file will be downloaded to a temporary directory in the filesystem. attributes: dict[str, tp.Any] Optional dictionary with attributes to store in the object. - mask_file: Path - Path to file used to save segmentation. Default is `path.with_suffix(".segmentation.pickle")`. hdf5_file: Path Path to file used to save tile coordinates (and images). Default is `path.with_suffix(".h5")`. + + Attributes + ---------- + path: Path + Path to WSI file. + attributes: dict[str, tp.Any] + Dictionary with attributes to store in the object. + name: str + Name of the WSI file. + wsi: openslide.OpenSlide + A handle to the low-level OpenSlide object. + hdf5_file: Path + Path to file used to save tile coordinates (and images). + level_downsamples: list[tuple[float, float]] + List of tuples with downsample factors for each level. + level_dim: list[tuple[int, int]] + List of tuples with dimensions for each level. + contours_tissue: list[np.ndarray] + List of tissue contours. + contours_tumor: list[np.ndarray] + List of tumor contours. + holes_tissue: list[np.ndarray] + List of holes in tissue contours. + target: None + Placeholder for target (e.g. label) for the WSI. + + Returns + ------- + WholeSlideImage + WholeSlideImage object. 
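+
+        Examples
+        --------
+        Illustrative sketch only (the local file name is a placeholder; the URL is the README example):
+
+        >>> slide = WholeSlideImage("slide.svs")  # doctest: +SKIP
+        >>> slide = WholeSlideImage("https://brd.nci.nih.gov/brd/imagedownload/GTEX-O5YU-1426")  # doctest: +SKIP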
""" + from wsi.utils import is_url, download_file + if not isinstance(path, Path): + if is_url(path): + path = download_file(str(path)) path = Path(path) self.path = path self.attributes = attributes self.name = path.stem self.wsi = openslide.open_slide(path) - self.level_downsamples = self._assertLevelDownsamples() + self.level_downsamples = self._assert_level_downsamples() self.level_dim = self.wsi.level_dimensions self.contours_tissue: list[np.ndarray] | None = None self.contours_tumor: list[np.ndarray] | None = None self.holes_tissue: list[np.ndarray] | None = None - self.holes_tumor: list[np.ndarray] | None = None - self.mask_file: Path = ( - path.with_suffix(".segmentation.pickle") if mask_file is None else mask_file - ) + # UNUSED: self.holes_tumor: list[np.ndarray] | None = None self.hdf5_file: Path = path.with_suffix(".h5") if hdf5_file is None else hdf5_file self.target = None @@ -81,81 +109,126 @@ def __init__( def __repr__(self): return f"WholeSlideImage('{self.path}')" - def getOpenSlide(self): - return self.wsi + def _assert_level_downsamples(self): + level_downsamples = [] + dim_0 = self.wsi.level_dimensions[0] - def initXML(self, xml_path): - def _createContour(coord_list): - return np.array( - [ - [ - [ - int(float(coord.attributes["X"].value)), - int(float(coord.attributes["Y"].value)), - ] - ] - for coord in coord_list - ], - dtype="int32", + for downsample, dim in zip(self.wsi.level_downsamples, self.wsi.level_dimensions): + estimated_downsample = (dim_0[0] / float(dim[0]), dim_0[1] / float(dim[1])) + ( + level_downsamples.append(estimated_downsample) + if estimated_downsample + != ( + downsample, + downsample, + ) + else level_downsamples.append((downsample, downsample)) ) - xmldoc = minidom.parse(xml_path) - annotations = [ - anno.getElementsByTagName("Coordinate") - for anno in xmldoc.getElementsByTagName("Annotation") - ] - self.contours_tumor = [_createContour(coord_list) for coord_list in annotations] - self.contours_tumor = sorted( - self.contours_tumor, key=cv2.contourArea, reverse=True - ) + return level_downsamples - def initTxt(self, annot_path): - def _create_contours_from_dict(annot): - all_cnts = [] - for idx, annot_group in enumerate(annot): - contour_group = annot_group["coordinates"] - if annot_group["type"] == "Polygon": - for idx, contour in enumerate(contour_group): - contour = np.array(contour).astype(np.int32).reshape(-1, 1, 2) - all_cnts.append(contour) + def _load_segmentation_legacy(self, pickle_file: Path | None = None) -> None: + import warnings + import pickle - else: - for idx, sgmt_group in enumerate(contour_group): - contour = [] - for sgmt in sgmt_group: - contour.extend(sgmt) - contour = np.array(contour).astype(np.int32).reshape(-1, 1, 2) - all_cnts.append(contour) - - return all_cnts - - with open(annot_path, "r") as f: - annot = f.read() - annot = eval(annot) - self.contours_tumor = _create_contours_from_dict(annot) - self.contours_tumor = sorted( - self.contours_tumor, key=cv2.contourArea, reverse=True + warnings.warn( + "Loading segmentation results from a pickle file is deprecated. 
" + "Save segmentation results to an HDF5 file instead.", ) - def initSegmentation(self, mask_file: Path | str | None = None): - if mask_file is None: - mask_file = self.mask_file - # load segmentation results from pickle file - asset_dict = load_pkl(mask_file) - self.holes_tissue = asset_dict["holes"] - self.contours_tissue = asset_dict["tissue"] - - def load_segmentation(self, mask_file: Path | str | None = None): - self.initSegmentation(mask_file) - - def saveSegmentation(self, mask_file: Path | str | None = None): - if mask_file is None: - mask_file = self.mask_file - # save segmentation results using pickle - asset_dict = {"holes": self.holes_tissue, "tissue": self.contours_tissue} - save_pkl(mask_file, asset_dict) - - def segmentTissue( + if pickle_file is None: + pickle_file = self.path.with_suffix(".segmentation.pickle") + + data = pickle.load(pickle_file.open("rb")) + self.contours_tissue = data["tissue"] + self.holes_tissue = data["holes"] + + def load_segmentation(self, hdf5_file: Path | None = None) -> None: + """ + Load slide segmentation results from pickle file. + + Parameters + ---------- + hdf5_file: Path + Path to file used to save segmentation. + If None, the segmentation results will be loaded from `self.hdf5_file`. + + Returns + ------- + None + """ + if hdf5_file is None: + hdf5_file = self.hdf5_file + + legacy_file = self.path.with_suffix(".segmentation.pickle") + if not hdf5_file.exists(): + if legacy_file.exists(): + self._load_segmentation_legacy(legacy_file) + return + + req = [ + "contours_tissue_breakpoints", + "contours_tissue", + "holes_tissue_breakpoints", + "holes_tissue", + ] + + with h5py.File(hdf5_file, "r") as f: + for r in req: + if r not in f: + print(f"H5 file {hdf5_file} did not have the required keys!") + if legacy_file.exists(): + self._load_segmentation_legacy(legacy_file) + return + raise ValueError(f"Required dataset {r} not found in {hdf5_file}") + bpt = f["contours_tissue_breakpoints"][()] + ct = f["contours_tissue"][()] + self.contours_tissue = [ + ct[bpt[i] : bpt[i + 1]] for i in range(bpt.shape[0] - 1) + ] + + bph = f["holes_tissue_breakpoints"][()] + ht = f["holes_tissue"][()] + holes_tissue = list() + for b in bph: + res = [] + for i in range(b.shape[0] - 1): + if b[i + 1] != 0: + res.append(ht[b[i] : b[i + 1]]) + holes_tissue.append(res) + self.holes_tissue = holes_tissue + + def save_segmentation(self, hdf5_file: Path | None = None, mode: str = "a") -> None: + """ + Save slide segmentation results to pickle file. + + Parameters + ---------- + hdf5_file: Path + Path to file used to save segmentation. + If None, the segmentation results will be loaded from `self.hdf5_file`. 
+
+        Returns
+        -------
+        None
+        """
+        if hdf5_file is None:
+            hdf5_file = self.hdf5_file
+        with h5py.File(hdf5_file, mode) as f:
+            data = np.concatenate(self.contours_tissue)
+            f.create_dataset("contours_tissue", data=data)
+            bpt = [0] + np.cumsum([c.shape[0] for c in self.contours_tissue]).tolist()
+            f.create_dataset("contours_tissue_breakpoints", data=bpt)
+
+            holes = [h if h else [np.empty((0, 1, 2))] for h in self.holes_tissue]
+            bph = [[0] + np.cumsum([h.shape[0] for h in c]).tolist() for c in holes]
+            n = max([len(_h) for _h in bph])
+            bph = np.asarray([_h + [0] * (n - len(_h)) for _h in bph]).reshape(-1, n)
+            holesc = np.concatenate([np.concatenate(c) for c in holes])
+            f.create_dataset("holes_tissue", data=holesc)
+            f.create_dataset("holes_tissue_breakpoints", data=bph)
+
+    def _segment_tissue(
         self,
         seg_level=0,
         sthresh=20,
         sthresh_up=255,
         mthresh=7,
         close=0,
         use_otsu=False,
         filter_params={"a_t": 100},
         ref_patch_size=512,
         exclude_ids=[],
         keep_ids=[],
-    ):
+    ) -> None:
         """
         Segment the tissue via HSV -> Median thresholding -> Binary threshold
         """
@@ -259,10 +332,9 @@ def _filter_contours(contours, hierarchy, filter_params):
             contours, hierarchy, filter_params
         )  # Necessary for filtering out artifacts
 
-        self.contours_tissue = self.scaleContourDim(foreground_contours, scale)
-        self.holes_tissue = self.scaleHolesDim(hole_contours, scale)
+        self.contours_tissue = self._scale_contour_dim(foreground_contours, scale)
+        self.holes_tissue = self._scale_holes_dim(hole_contours, scale)
 
-        # exclude_ids = [0,7,9]
         if len(keep_ids) > 0:
             contour_ids = set(keep_ids) - set(exclude_ids)
         else:
@@ -271,22 +343,58 @@ def _filter_contours(contours, hierarchy, filter_params):
         self.contours_tissue = [self.contours_tissue[i] for i in contour_ids]
         self.holes_tissue = [self.holes_tissue[i] for i in contour_ids]
 
-    def visWSI(
+    def vis_wsi(
         self,
-        vis_level=0,
-        color=(0, 255, 0),
-        hole_color=(0, 0, 255),
-        annot_color=(255, 0, 0),
-        line_thickness=250,
-        max_size=None,
-        top_left=None,
-        bot_right=None,
-        custom_downsample=1,
-        view_slide_only=False,
-        number_contours=False,
-        seg_display=True,
-        annot_display=True,
-    ):
+        vis_level: int = 0,
+        color: tuple[int, int, int] = (0, 255, 0),
+        hole_color: tuple[int, int, int] = (0, 0, 255),
+        annot_color: tuple[int, int, int] = (255, 0, 0),
+        line_thickness: float = 250.0,
+        max_size: int | None = None,
+        top_left: tuple[int, int] | None = None,
+        bot_right: tuple[int, int] | None = None,
+        custom_downsample: float = 1.0,
+        view_slide_only: bool = False,
+        number_contours: bool = False,
+        seg_display: bool = True,
+        annot_display: bool = True,
+    ) -> Image.Image:
+        """
+        Visualize the whole slide image.
+
+        Parameters
+        ----------
+        vis_level: int
+            The level to visualize.
+        color: tuple
+            The color of the tissue.
+        hole_color: tuple
+            The color of the holes.
+        annot_color: tuple
+            The color of the annotations.
+        line_thickness: float
+            The thickness of the contour lines.
+        max_size: int
+            The maximum size of the image.
+        top_left: tuple
+            The top left corner of the region to visualize.
+        bot_right: tuple
+            The bottom right corner of the region to visualize.
+        custom_downsample: float
+            The custom downsample factor.
+        view_slide_only: bool
+            Whether to only visualize the slide.
+        number_contours: bool
+            Whether to number the contours.
+        seg_display: bool
+            Whether to display the segmentation.
+        annot_display: bool
+            Whether to display the annotations.
+ + Returns + ------- + Image + """ downsample = self.level_downsamples[vis_level] scale = [1 / downsample[0], 1 / downsample[1]] @@ -313,7 +421,7 @@ def visWSI( if not number_contours: cv2.drawContours( img, - self.scaleContourDim(self.contours_tissue, scale), + self._scale_contour_dim(self.contours_tissue, scale), -1, color, line_thickness, @@ -323,7 +431,7 @@ def visWSI( else: # add numbering to each contour for idx, cont in enumerate(self.contours_tissue): - contour = np.array(self.scaleContourDim(cont, scale)) + contour = np.array(self._scale_contour_dim(cont, scale)) M = cv2.moments(contour) cX = int(M["m10"] / (M["m00"] + 1e-9)) cY = int(M["m01"] / (M["m00"] + 1e-9)) @@ -350,7 +458,7 @@ def visWSI( for holes in self.holes_tissue: cv2.drawContours( img, - self.scaleContourDim(holes, scale), + self._scale_contour_dim(holes, scale), -1, hole_color, line_thickness, @@ -360,7 +468,7 @@ def visWSI( if self.contours_tumor is not None and annot_display: cv2.drawContours( img, - self.scaleContourDim(self.contours_tumor, scale), + self._scale_contour_dim(self.contours_tumor, scale), -1, annot_color, line_thickness, @@ -380,226 +488,8 @@ def visWSI( return img - def createPatches_bag_hdf5( - self, - save_path: Path | str | None = None, - patch_level=0, - patch_size=256, - step_size=256, - save_coord=True, - **kwargs, - ): - if save_path is None: - save_path = self.hdf5_file.parent - contours = self.contours_tissue - contour_holes = self.holes_tissue - - print(f"Creating patches for: {self.name}") - elapsed = time.time() - for idx, cont in enumerate(contours): - patch_gen = self._getPatchGenerator( - cont, idx, patch_level, save_path, patch_size, step_size, **kwargs - ) - - if not self.hdf5_file.exists(): - try: - first_patch = next(patch_gen) - - # empty contour, continue - except StopIteration: - continue - - file_path = initialize_hdf5_bag(first_patch, save_coord=save_coord) - self.hdf5_file = Path(file_path) - - for patch in patch_gen: - savePatchIter_bag_hdf5(patch) - - return self.hdf5_file - - def as_tile_bag(self): - # from wsi_core.dataset_h5 import Whole_Slide_Bag - from wsi_core.dataset_h5 import Whole_Slide_Bag_FP - - # dataset = Whole_Slide_Bag(self.hdf5_file, pretrained=True) - dataset = Whole_Slide_Bag_FP( - self.hdf5_file, self.wsi, pretrained=True, target=self.target - ) - return dataset - - def as_data_loader(self, batch_size: int = 32, with_coords: bool = False, **kwargs): - from functools import partial - from wsi_core.utils import collate_features - from torch.utils.data import DataLoader - - collate = partial(collate_features, with_coords=with_coords) - - dataset = self.as_tile_bag() - loader = DataLoader( - dataset=dataset, batch_size=batch_size, collate_fn=collate, **kwargs - ) - return loader - - def inference( - self, - model_name: str, - model_repo: str = "pytorch/vision", - device: str = "cpu", - data_loader_kws: dict = {}, - ) -> tp.Tuple[np.ndarray, np.ndarray]: - """ - Inference on the WSI using a pretrained model. - - Parameters - ---------- - model_name: str - Name of the model to use for inference. - model_repo: str - Repository to load the model from. Default is "torch/vision". - data_loader_kws: dict - Keyword arguments to pass to the data loader. - - Returns - ------- - Tuple[np.ndarray, np.ndarray] - Tuple of (features, coordinates). 
- """ - import torch - from tqdm import tqdm - - data_loader = self.as_data_loader(**data_loader_kws, with_coords=True) - model = torch.hub.load(model_repo, model_name, pretrained=True).to(device) - model.eval() - coords = list() - feats = list() - for batch, coord in tqdm(data_loader): - with torch.no_grad(): - feats.append(model(batch.to(device)).cpu().numpy()) - coords.append(coord) - return np.concatenate(feats, axis=0), np.concatenate(coords, axis=0) - - def _getPatchGenerator( - self, - cont, - cont_idx, - patch_level, - save_path, - patch_size=256, - step_size=256, - custom_downsample=1, - white_black=True, - white_thresh=15, - black_thresh=50, - contour_fn="four_pt", - use_padding=True, - ): - start_x, start_y, w, h = ( - cv2.boundingRect(cont) - if cont is not None - else (0, 0, self.level_dim[patch_level][0], self.level_dim[patch_level][1]) - ) - # print("Bounding Box:", start_x, start_y, w, h) - # print("Contour Area:", cv2.contourArea(cont)) - - if custom_downsample > 1: - assert custom_downsample == 2 - target_patch_size = patch_size - patch_size = target_patch_size * 2 - step_size = step_size * 2 - print( - "Custom Downsample: {}, Patching at {} x {}, But Final Patch Size is {} x {}".format( - custom_downsample, - patch_size, - patch_size, - target_patch_size, - target_patch_size, - ) - ) - - patch_downsample = ( - int(self.level_downsamples[patch_level][0]), - int(self.level_downsamples[patch_level][1]), - ) - ref_patch_size = ( - patch_size * patch_downsample[0], - patch_size * patch_downsample[1], - ) - - step_size_x = step_size * patch_downsample[0] - step_size_y = step_size * patch_downsample[1] - - if isinstance(contour_fn, str): - if contour_fn == "four_pt": - cont_check_fn = isInContourV3_Easy( - contour=cont, patch_size=ref_patch_size[0], center_shift=0.5 - ) - elif contour_fn == "four_pt_hard": - cont_check_fn = isInContourV3_Hard( - contour=cont, patch_size=ref_patch_size[0], center_shift=0.5 - ) - elif contour_fn == "center": - cont_check_fn = isInContourV2(contour=cont, patch_size=ref_patch_size[0]) - elif contour_fn == "basic": - cont_check_fn = isInContourV1(contour=cont) - else: - raise NotImplementedError - else: - assert isinstance(contour_fn, Contour_Checking_fn) - cont_check_fn = contour_fn - - img_w, img_h = self.level_dim[0] - if use_padding: - stop_y = start_y + h - stop_x = start_x + w - else: - stop_y = min(start_y + h, img_h - ref_patch_size[1]) - stop_x = min(start_x + w, img_w - ref_patch_size[0]) - - count = 0 - for y in range(start_y, stop_y, step_size_y): - for x in range(start_x, stop_x, step_size_x): - if not self.isInContours( - cont_check_fn, - (x, y), - self.holes_tissue[cont_idx], - ref_patch_size[0], - ): # point not inside contour and its associated holes - continue - - count += 1 - patch_PIL = self.wsi.read_region( - (x, y), patch_level, (patch_size, patch_size) - ).convert("RGB") - if custom_downsample > 1: - patch_PIL = patch_PIL.resize((target_patch_size, target_patch_size)) - - if white_black: - if isBlackPatch( - np.array(patch_PIL), rgbThresh=black_thresh - ) or isWhitePatch(np.array(patch_PIL), satThresh=white_thresh): - continue - - patch_info = { - "x": x // (patch_downsample[0] * custom_downsample), - "y": y // (patch_downsample[1] * custom_downsample), - "cont_idx": cont_idx, - "patch_level": patch_level, - "downsample": self.level_downsamples[patch_level], - "downsampled_level_dim": tuple( - np.array(self.level_dim[patch_level]) // custom_downsample - ), - "level_dim": self.level_dim[patch_level], - "patch_PIL": patch_PIL, - 
"name": self.name, - "save_path": save_path, - } - - yield patch_info - - print("patches extracted: {}".format(count)) - @staticmethod - def isInHoles(holes, pt, patch_size): + def _is_in_holes(holes, pt, patch_size): for hole in holes: if ( cv2.pointPolygonTest( @@ -612,46 +502,33 @@ def isInHoles(holes, pt, patch_size): return 0 @staticmethod - def isInContours(cont_check_fn, pt, holes=None, patch_size=256): + def _is_in_contours(cont_check_fn, pt, holes=None, patch_size=256): if cont_check_fn(pt): if holes is not None: - return not WholeSlideImage.isInHoles(holes, pt, patch_size) + return not WholeSlideImage._is_in_holes(holes, pt, patch_size) else: return 1 return 0 @staticmethod - def scaleContourDim(contours, scale): + def _scale_contour_dim(contours, scale): return [np.array(cont * scale, dtype="int32") for cont in contours] @staticmethod - def scaleHolesDim(contours, scale): + def _scale_holes_dim(contours, scale): return [ [np.array(hole * scale, dtype="int32") for hole in holes] for holes in contours ] - def _assertLevelDownsamples(self): - level_downsamples = [] - dim_0 = self.wsi.level_dimensions[0] - - for downsample, dim in zip(self.wsi.level_downsamples, self.wsi.level_dimensions): - estimated_downsample = (dim_0[0] / float(dim[0]), dim_0[1] / float(dim[1])) - level_downsamples.append(estimated_downsample) if estimated_downsample != ( - downsample, - downsample, - ) else level_downsamples.append((downsample, downsample)) - - return level_downsamples - - def process_contours( + def _process_contours( self, save_path: tp.Optional[Path] = None, patch_level=0, patch_size=256, step_size=256, **kwargs, - ): + ) -> Path: if save_path is None: save_path = self.hdf5_file # print("Creating patches for: ", self.name, "...") @@ -664,7 +541,7 @@ def process_contours( if (idx + 1) % fp_chunk_size == fp_chunk_size: print("Processing contour {}/{}".format(idx, n_contours)) - asset_dict, attr_dict = self.process_contour( + asset_dict, attr_dict = self._process_contour( cont, self.holes_tissue[idx], patch_level, @@ -690,15 +567,16 @@ def process_contours( return self.hdf5_file - def decompose_color(self, output_file: Path | None = None): + def _decompose_color(self, output_file: Path | None = None) -> None: from skimage.color import rgb2hsv, hsv2rgb, rgb_from_hed from sklearn.decomposition import PCA # , FastICA + import matplotlib.pyplot as plt - def minmax_scale(x): + def _minmax_scale(x): return (x - np.min(x)) / (np.max(x) - np.min(x)) if output_file is None: - output_file = self.mask_file.with_suffix(".pca.png") + output_file = self.hdf5_file.with_suffix(".pca.png") # Work with thubnail by default level = self.wsi.level_count - 1 @@ -726,13 +604,59 @@ def minmax_scale(x): axes[0].set(title=f"PC {sel}, argmax: {sel.argmax()}, sign: {sign}") axes[0].imshow(thumbnail) for i in range(3): - axes[i + 1].imshow(minmax_scale(thumbnail_pca[..., i])) + axes[i + 1].imshow(_minmax_scale(thumbnail_pca[..., i])) for ax in axes: ax.axis("off") fig.savefig(output_file, bbox_inches="tight", dpi=200, pad_inches=0.0) plt.close(fig) - def segment_tissue_manual(self, level: int | None = None, color_space: str = "RGB"): + def _get_best_level(self, target_dimensions: tuple[int, int] = (2000, 2000)) -> int: + g = np.absolute( + (np.asarray(self.wsi.level_dimensions) - np.asarray(target_dimensions)) + ).sum(1) + return np.argmin(g) + + def get_thumbnail(self, level: int | None = None) -> np.ndarray: + """ + Get array representing a low resolution image of the whole slide image. 
+
+        Parameters
+        ----------
+        level: int
+            Which pyramid level to retrieve image at.
+        """
+        if level is None:
+            level = self._get_best_level((2000, 2000))
+        thumbnail = np.array(
+            self.wsi.read_region((0, 0), level, self.level_dim[level]).convert("RGB")
+        )
+        return thumbnail
+
+    def _segment_tissue_manual(
+        self,
+        level: int | None = None,
+        color_space: str = "RGB",
+        otsu_threshold_relaxation: float = 0,
+        dilation_diameter: float = 2.0,
+        small_object_threshold: int = 200,
+        fill_holes_threshold: int = 20,
+        hole_object_threshold: int = 5000,
+    ) -> None:
+        """
+        Segment the tissue using manually optimized parameters.
+
+        Parameters
+        ----------
+        level: int
+            WSI level to segment tissue from.
+            Default is None, which will find the level closest to a thumbnail with 2000x2000 pixels.
+        color_space: str
+            Color space to work in. Either "RGB" or "HED".
+        otsu_threshold_relaxation: float
+            Fraction by which the Otsu threshold is relaxed (0 means no relaxation).
+        dilation_diameter: float
+            Size of the disk structuring element used to dilate the foreground mask.
+        small_object_threshold: int
+            Objects smaller than `mask.size // small_object_threshold` pixels are removed.
+        fill_holes_threshold: int
+            Background holes smaller than `mask.size // fill_holes_threshold` pixels are filled.
+        hole_object_threshold: int
+            Holes smaller than `mask.size // hole_object_threshold` pixels are discarded.
+
+        Returns
+        -------
+        None
+        """
         import skimage
         import scipy.ndimage as ndi
 
         # Work with thumbnail by default
         if level is None:
-            # level = self.wsi.level_count - 1
-            # Find level with dimension closest to 2000x2000
-            level = (
-                np.absolute(
-                    np.asarray(self.wsi.level_dimensions) - np.asarray([(2000, 2000)])
-                )
-                .mean(1)
-                .argmin()
-            )
+            level = self._get_best_level((2000, 2000))
         thumbnail = np.array(
             self.wsi.read_region((0, 0), level, self.level_dim[level]).convert("RGB")
         )
@@ -758,32 +674,65 @@ def segment_tissue_manual(self, level: int | None = None, color_space: str = "RG
 
             from skimage.color import rgb2hed
 
             hed = rgb2hed(thumbnail)
-            thumbnail = hed[..., :-1].min(-1)
+            thumbnailm = hed[..., :-1].min(-1)
 
             # Threshold for bright
-            m = thumbnail > skimage.filters.threshold_otsu(thumbnail)
+            t = skimage.filters.threshold_otsu(thumbnailm)
+            m = thumbnailm > (t - t * otsu_threshold_relaxation)
         elif color_space == "RGB":
             # Work in mean RGB space
-            thumbnail = thumbnail.mean(-1)
+            thumbnailm = thumbnail.mean(-1)
 
             # Threshold for dark
-            m = thumbnail < skimage.filters.threshold_otsu(thumbnail)
+            t = skimage.filters.threshold_otsu(thumbnailm)
+            m = thumbnailm < (t + t * otsu_threshold_relaxation)
 
         # Dilate mask
-        m = skimage.morphology.dilation(m, skimage.morphology.disk(2))
+        m = skimage.morphology.dilation(m, skimage.morphology.disk(dilation_diameter))
+
+        # Remove foreground overlapping the edges
+        m[0, :] = False
+        m[-1, :] = False
+        m[:, 0] = False
+        m[:, -1] = False
 
         # Remove small objects
-        m = skimage.morphology.remove_small_objects(m, 500, connectivity=1)
+        m = skimage.morphology.remove_small_objects(
+            m, m.size // small_object_threshold, connectivity=1
+        )
 
         # Fill holes (for contour)
-        mask = ~skimage.morphology.remove_small_objects(~m, m.size // 2, connectivity=1)
+        mask = ~skimage.morphology.remove_small_objects(
+            ~m, m.size // fill_holes_threshold, connectivity=1
+        )
 
         # Get polygon contours from binary mask
         contours_tissue = skimage.measure.find_contours(mask, 0.5, fully_connected="high")
+        # blobs_tissue = skimage.measure.label(mask, background=0)
+        # tprops = skimage.measure.regionprops(blobs_tissue)
+        # contours_tissue = [
+        #     np.concatenate(
+        #         skimage.measure.find_contours(p.image, 0.5, fully_connected="high")
+        #     )
+        #     + p.bbox[:2]
+        #     for p in tprops
+        # ]
 
         # Get holes
         holes, _ = ndi.label(~m)
         # # remove largest one (which should be the background)
         holes[holes == 1] = 0
-        holes = skimage.morphology.remove_small_objects(holes, 50, connectivity=1)
-        holes_tissue = skimage.measure.find_contours(holes, 0.5, 
fully_connected="high") + holes = skimage.morphology.remove_small_objects( + holes, m.size // hole_object_threshold, connectivity=1 + ) + holes_tissue = skimage.measure.find_contours( + holes > 1, 0.5, fully_connected="high" + ) + # hprops = skimage.measure.regionprops(holes) + # holes_tissue = [ + # np.concatenate( + # skimage.measure.find_contours(p.image, 0.5, fully_connected="high") + # ) + # + p.bbox[:2] + # for p in hprops + # ] # Scale up to size of original image # # Reverse axis order @@ -796,19 +745,63 @@ def segment_tissue_manual(self, level: int | None = None, color_space: str = "RG for cont in holes_tissue ] - # TODO: Important! Pair holes and contours by checking which holes are in which tissue pieces + # Important! Pair holes and contours by checking which holes are in which tissue pieces # shape of holes_tissue must match contours_tissue, even if there are no holes self.contours_tissue = [x[:, np.newaxis, :] for x in contours_tissue] self.holes_tissue = [x[:, np.newaxis, :] for x in holes_tissue] + conts = { + i: shapely.Polygon(cont.squeeze()) + for i, cont in enumerate(self.contours_tissue) + } + new_holes = [[] for _ in range(len(self.contours_tissue))] + for hole in self.holes_tissue: + h = shapely.Polygon(hole.squeeze()) + for i, cont in conts.items(): + if h.intersects(cont): + new_holes[i].append(hole) + self.holes_tissue = new_holes + assert len(self.contours_tissue) > 0, "Segmentation could not find tissue!" - self.saveSegmentation() + self.save_segmentation() + return None + + # # Viz during development: + # import matplotlib.pyplot as plt + + # fig, axes = plt.subplots(1, 6, figsize=(30, 5)) + # axes[0].imshow(thumbnail, rasterized=True) + # axes[0].axis("off") + # axes[0].set_title("Original") + # axes[1].imshow(thumbnailm, rasterized=True) + # axes[1].axis("off") + # axes[1].set_title("Mean") + # axes[2].imshow(m, rasterized=True) + # axes[2].axis("off") + # axes[2].set_title("pre-Mask") + # axes[3].imshow(mask, rasterized=True) + # axes[3].axis("off") + # axes[3].set_title("Mask") + # axes[4].imshow(holes > 0, rasterized=True) + # axes[4].axis("off") + # axes[4].set_title("Holes") + # axes[5].imshow(thumbnail, rasterized=True) + # colors = ["green", "orange", "purple"] + # for col, cont in zip(colors, contours_tissue): + # axes[5].plot(*cont.squeeze().T, color=col) + # for hole in holes_tissue: + # axes[5].plot(*hole.squeeze().T, color="black") + # axes[5].axis("off") + # axes[5].set_title("Trace") + # fig.tight_layout() + # # fig.savefig("test.png") + # return fig def segment( self, - params: tp.Optional[dict[str, tp.Any]] = None, method: str = "manual", + params: tp.Optional[dict[str, tp.Any]] = None, ) -> None: """ Segment the WSI for tissue and background. 
@@ -839,7 +832,7 @@ def segment( """ assert method in ["manual", "CLAM"], f"Unknown segmentation method: {method}" if method == "manual": - self.segment_tissue_manual(**(params or {})) + self._segment_tissue_manual(**(params or {})) else: # import pandas as pd if params is None: @@ -864,61 +857,22 @@ def segment( } if "seg_level" not in params: - g = np.absolute( - (np.asarray(self.wsi.level_dimensions) - np.asarray([1000, 1000])) - ).sum(1) - params["seg_level"] = np.argmin(g) + params["seg_level"] = self._get_best_level((1000, 1000)) - kwargs = filter_kwargs_by_callable(params, self.segmentTissue) + kwargs = filter_kwargs_by_callable(params, self._segment_tissue) fkwargs = {k: v for k, v in params.items() if k not in kwargs} - self.segmentTissue(**kwargs, filter_params=fkwargs) + self._segment_tissue(**kwargs, filter_params=fkwargs) assert len(self.contours_tissue) > 0, "Segmentation could not find tissue!" - self.saveSegmentation() + self.save_segmentation() self.plot_segmentation() - # def plot_segmentation(self, output_file: tp.Optional[Path] = None) -> None: - # from shapely.geometry import Polygon - - # if output_file is None: - # output_file = self.mask_file.with_suffix(".png") - - # level = self.wsi.level_count - 1 - # thumbnail = np.array( - # self.wsi.read_region((0, 0), level, self.level_dim[level]).convert("RGB") - # ) - - # fig, ax = plt.subplots(1, 1, figsize=(10, 10)) - # ax.imshow(thumbnail) - # tissue: np.ndarray - # hole: np.ndarray - # for i, tissue in enumerate(self.contours_tissue or [], 1): - # # resize to thumbnail size - # tissue = np.array( - # tissue.squeeze() / self.wsi.level_downsamples[level], dtype="int32" - # ) - # poly = Polygon(tissue) - # ax.plot(*tissue.T) - # ax.text( - # *poly.centroid.coords[0], - # str(i), - # color="black", - # ha="center", - # va="center", - # fontsize=10, - # ) - # for i, hole in enumerate(self.holes_tissue or [], 1): - # # resize to thumbnail size - # hole = np.array( - # hole.squeeze() / self.wsi.level_downsamples[level], dtype="int32" - # ) - # poly = Polygon(hole) - # ax.plot(*hole.T, color="black", linestyle="-", linewidth=0.2) - # ax.axis("off") - # fig.savefig(output_file, bbox_inches="tight", dpi=200, pad_inches=0.0) - # plt.close(fig) - # return fig - - def plot_segmentation(self, output_file: tp.Optional[Path] = None, **kwargs) -> None: + def plot_segmentation( + self, + output_file: tp.Optional[Path] = None, + per_contour: bool = False, + level: int | None = None, + **kwargs, + ) -> None: """ Plot the segmentation of the WSI. @@ -932,7 +886,7 @@ def plot_segmentation(self, output_file: tp.Optional[Path] = None, **kwargs) -> `self.path.with_suffix(".segmentation.png")`. kwargs: dict - Additional keyword arguments to pass to `visWSI`. + Additional keyword arguments to pass to `vis_wsi`. 
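+        per_contour: bool
+            Whether to save one cropped image per tissue contour instead of a
+            single overview image (files suffixed `.000.png`, `.001.png`, ...).
+        level: int
+            Pyramid level to plot. If None, the level closest to 2000x2000 pixels is used.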
 Returns
         -------
         """
         if output_file is None:
             output_file = self.path.with_suffix(".segmentation.png")
 
-        level = self.wsi.level_count - 1
-        self.visWSI(vis_level=level, **kwargs).save(output_file)
+        if level is None:
+            level = self._get_best_level((2000, 2000))
+        if not per_contour:
+            self.vis_wsi(vis_level=level, **kwargs).save(output_file)
+            return
+
+        for i, piece in enumerate(self.contours_tissue):
+            poly = shapely.geometry.Polygon(piece.squeeze())
+            top_left, bottom_right = np.asarray(poly.bounds).reshape(2, 2).astype(int)
+            img = self.vis_wsi(
+                vis_level=level, top_left=top_left, bot_right=bottom_right, **kwargs
+            )
+            n = str(i).zfill(3)
+            img.save(output_file.with_suffix(f".{n}.png"))
 
     def tile(
         self,
@@ -964,7 +930,7 @@ def tile(
         step_size: int
             Step size between patches in pixels.
         contour_subset: list[int]
-            1-based index of which contours to use. If None, use all contours.
+            Index of which contours to use (0-based). If None, use all contours.
 
         Returns
         -------
@@ -974,22 +940,60 @@
 
         if contour_subset is not None:
             original_contours = copy(self.contours_tissue)
-            self.contours_tissue = [self.contours_tissue[i - 1] for i in contour_subset]
+            self.contours_tissue = [self.contours_tissue[i] for i in contour_subset]
+            original_holes = copy(self.holes_tissue)
+            self.holes_tissue = [self.holes_tissue[i] for i in contour_subset]
 
-        self.process_contours(
+        self._process_contours(
             patch_level=patch_level, patch_size=patch_size, step_size=step_size
         )
 
         if contour_subset is not None:
             self.contours_tissue = original_contours
+            self.holes_tissue = original_holes
+
+    def has_tissue_contours(self) -> bool:
+        """
+        Check if the WSI has tissue contours saved.
+
+        Returns
+        -------
+        bool
+            True if tissue contours are available.
+        """
+        if not self.hdf5_file.exists():
+            return False
+        if self.contours_tissue is not None:
+            if isinstance(self.contours_tissue, list):
+                return isinstance(self.contours_tissue[0], np.ndarray)
+        with h5py.File(self.hdf5_file, "r") as h5:
+            return "contours_tissue" in h5
+
+    def has_tile_coords(self) -> bool:
+        """
+        Check if the WSI has tile coordinates saved in its HDF5 file.
 
-    def has_tile_coords(self):
+        Returns
+        -------
+        bool
+            True if tile coordinates exist.
+        """
         if not self.hdf5_file.exists():
             return False
         with h5py.File(self.hdf5_file, "r") as h5:
             return "coords" in h5
 
-    def has_tile_images(self):
+    def has_tile_images(self) -> bool:
+        """
+        Check if the WSI has tile images in its HDF5 file.
+
+        Returns
+        -------
+        bool
+            True if tile images exist.
+        """
         if not self.hdf5_file.exists():
             return False
         with h5py.File(self.hdf5_file, "r") as h5:
@@ -1019,12 +1023,158 @@ def get_tile_coordinates(self, hdf5_file: Path | None = None) -> np.ndarray:
     def get_tile_coordinate_level_size(
         self, hdf5_file: Path | None = None
     ) -> tuple[int, int]:
+        """
+        Retrieve level and size of tiles from HDF5 file.
+
+        By default uses the `self.hdf5_file` attribute, but can be overridden.
+
+        Parameters
+        ----------
+        hdf5_file: Path
+            Path to HDF5 file containing tile coordinates.
+
+        Returns
+        -------
+        tuple[int, int]
+            Level and size of tiles.
+        """
         if hdf5_file is None:
             hdf5_file = self.hdf5_file  # or self.tile_h5
         with h5py.File(hdf5_file, "r") as h5:
             attrs = h5["coords"].attrs
             return attrs["patch_level"], attrs["patch_size"]
 
+    def get_tile_polygons(self) -> list[shapely.Polygon]:
+        """
+        Retrieve polygons of tile bounds.
+
+        Returns
+        -------
+        List of tile shapely.Polygon objects.
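+
+        Examples
+        --------
+        Illustrative only:
+
+        >>> tiles = slide.get_tile_polygons()  # doctest: +SKIP
+        >>> tiles[0].bounds  # doctest: +SKIP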
+ """ + if (not self.has_tissue_contours()) or (not self.has_tile_coords()): + raise ValueError("Tissue contours or tile coordinates not found.") + + level, size = self.get_tile_coordinate_level_size() + scale = self.wsi.level_downsamples[level] + coords = self.get_tile_coordinates() + + # Make a spatial index tree of the tiles + tiles = [ + shapely.Polygon( + [ + (xy[0], xy[1]), + ((xy[0] + (size * scale)), xy[1]), + ((xy[0] + (size * scale)), (xy[1] + (size * scale))), + (xy[0], (xy[1] + (size * scale))), + (xy[0], xy[1]), + ] + ) + for idx, xy in enumerate(coords) + ] + return tiles + + def get_tile_tissue_piece(self) -> np.ndarray: + """ + Retrieve which tile overlaps which tissue contour. + + Returns + ------- + np.ndarray + Array of shape (N, M) where N is the number of tiles and M is the number of tissue contours. + """ + + # Make a spatial index tree of the tiles + tiles = self.get_tile_polygons() + tree = shapely.strtree.STRtree(tiles) + + # Query the tree for each tissue contour + pieces = np.zeros((len(tiles), len(self.contours_tissue)), dtype=bool) + for i, cont in enumerate(self.contours_tissue): + poly = shapely.Polygon(cont.squeeze()) + result = tree.query(poly) + pieces[:, i] = np.isin(range(len(tiles)), result) + + return pieces + + def get_tile_graph( + self, query_type="distance", max_dist: float | None = None + ) -> np.ndarray: + """ + Retrieve a graph of tile spatial proximity. + + Parameters + ---------- + query_type: str + Type of query. Either "distance" or "knn". + max_dist: float + Maximum distance for distance-based queries. If None, use the tile size centered on tile centroids. + + Returns + ------- + np.ndarray + Array with edges of shape (2, N) where N is the number of edges. + """ + from scipy.spatial import KDTree + + assert query_type in ["distance", "knn"], "query_type must be 'distance' or 'knn'" + + tiles = self.get_tile_polygons() + data = np.asarray([t.centroid.xy for t in tiles]).squeeze() + if max_dist is None: + level, size = self.get_tile_coordinate_level_size() + max_dist = size * self.wsi.level_downsamples[level] + 1 + + tree = KDTree(data) + if query_type == "distance": + edge_index = tree.query_pairs(max_dist, output_type="ndarray").T + elif query_type == "knn": + raise NotImplementedError("knn not implemented yet") + # dist, edge_index = tree.query(data, k=k) + + return edge_index + + def plot_tile_graph(self, output_file: Path | None = None) -> None: + """ + Plot a graph of tile spatial proximity. + + Parameters + ---------- + output_file: Path + Path to output file. If None, save to `self.path.with_suffix(".tile_graph.png")`. 
+ + Returns + ------- + None + """ + import matplotlib.pyplot as plt + + if output_file is None: + output_file = self.path.with_suffix(".tile_graph.png") + + level = self._get_best_level((2000, 2000)) + thumbnail = np.array( + self.wsi.read_region((0, 0), level, self.level_dim[level]).convert("RGB") + ) + scale = self.wsi.level_downsamples[level] + coords = self.get_tile_coordinates() / scale + edge_index = self.get_tile_graph() + ar = thumbnail.shape[0] / thumbnail.shape[1] + + fig, ax = plt.subplots(figsize=(10, 10 * ar)) + ax.imshow(thumbnail) + ax.scatter(coords[:, 0], coords[:, 1], s=1, rasterized=True) + for edge_index in edge_index.T: + ax.plot( + coords[edge_index, 0], + coords[edge_index, 1], + color="black", + rasterized=True, + ) + ax.axis("off") + fig.tight_layout() + fig.savefig(output_file, bbox_inches="tight", dpi=200) + def get_tile_images( self, hdf5_file: Path | None = None, @@ -1043,8 +1193,8 @@ def get_tile_images( Returns ------- - np.ndarray - Array of tile images with shape (N, 3, H, W). + generator + Each element is an array of with shape (3, H, W). """ if hdf5_file is None: hdf5_file = self.hdf5_file # or self.tile_h5 @@ -1080,7 +1230,7 @@ def get_tile_images( def save_tile_images( self, output_dir: Path, - format: str = "jpg", + output_format: str = "jpg", attributes: bool = True, n: int | None = None, frac: float = 1.0, @@ -1092,7 +1242,7 @@ def save_tile_images( ---------- output_dir: Path Directory to save tile images to. - format: str + output_format: str File format to save images as. attributes: bool Whether to include attributes in filename. @@ -1117,11 +1267,9 @@ def save_tile_images( _attributes = {} if attributes: _attributes = self.attributes if self.attributes is not None else {} - output_prefix = output_dir / ( - self.name + ("." + ".".join(_attributes.values())) - ) + output_prefix = self.name + ("." 
+ ".".join(_attributes.values())) else: - output_prefix = output_dir / self.name + output_prefix = self.name hdf5_file = self.hdf5_file # or self.tile_h5 level, size = self.get_tile_coordinate_level_size(hdf5_file) @@ -1135,12 +1283,12 @@ def save_tile_images( sel = pd.Series(range(nc)).sample(frac=frac, n=n).values for coord in coords[sel]: - # Output in the form of: slide_name.attr[0].attr[1].attr[n].x.y.format - fp = output_prefix + f".{coord[0]}.{coord[1]}.{format}" + # Output in the form of: slide_name.attr[0].attr[1].attr[n].x.y.output_format + fp = output_dir / (output_prefix + f".{coord[0]}.{coord[1]}.{output_format}") img = self.wsi.read_region(coord, level=level, size=(size, size)) img.convert("RGB").save(fp) - def process_contour( + def _process_contour( self, cont, contour_holes, @@ -1210,7 +1358,7 @@ def process_contour( else: raise NotImplementedError else: - assert isinstance(contour_fn, Contour_Checking_fn) + assert isinstance(contour_fn, ContourCheckingFn) cont_check_fn = contour_fn step_size_x = step_size * patch_downsample[0] @@ -1230,7 +1378,7 @@ def process_contour( (coord, contour_holes, ref_patch_size[0], cont_check_fn) for coord in coord_candidates ] - results = pool.starmap(WholeSlideImage.process_coord_candidate, iterable) + results = pool.starmap(WholeSlideImage._process_coord_candidate, iterable) pool.close() results = np.array([result for result in results if result is not None]) @@ -1256,14 +1404,15 @@ def process_contour( return {}, {} @staticmethod - def process_coord_candidate(coord, contour_holes, ref_patch_size, cont_check_fn): - if WholeSlideImage.isInContours( + def _process_coord_candidate(coord, contour_holes, ref_patch_size, cont_check_fn): + if WholeSlideImage._is_in_contours( cont_check_fn, coord, contour_holes, ref_patch_size ): return coord else: return None + # TODO: adapt and illustrate usage def visHeatmap( self, scores, @@ -1273,7 +1422,7 @@ def visHeatmap( bot_right=None, patch_size=(256, 256), blank_canvas=False, - canvas_color=(220, 20, 50), + # UNUSED: canvas_color=(220, 20, 50), alpha=0.4, blur=False, overlap=0.0, @@ -1297,7 +1446,7 @@ def visHeatmap( alpha (float [0, 1]): blending coefficient for overlaying heatmap onto original slide blur (bool): apply gaussian blurring overlap (float [0 1]): percentage of overlap between neighboring patches (only affect radius of blurring) - segment (bool): whether to use tissue segmentation contour (must have already called self.segmentTissue such that + segment (bool): whether to use tissue segmentation contour (must have already called self._segment_tissue such that self.contours_tissue and self.holes_tissue are not None use_holes (bool): whether to also clip out detected tissue cavities (only in effect when segment == True) convert_to_percentiles (bool): whether to convert attention scores to percentiles @@ -1307,6 +1456,7 @@ def visHeatmap( custom_downsample (int): additionally downscale the heatmap by specified factor cmap (str): name of matplotlib colormap to use """ + import matplotlib.pyplot as plt if vis_level < 0: vis_level = self.wsi.get_best_level_for_downsample(32) @@ -1406,7 +1556,7 @@ def visHeatmap( ) if segment: - tissue_mask = self.get_seg_mask( + tissue_mask = self._get_seg_mask( region_size, scale, use_holes=use_holes, offset=tuple(top_left) ) # return Image.fromarray(tissue_mask) # tissue mask @@ -1479,7 +1629,7 @@ def visHeatmap( ) if alpha < 1.0: - img = self.block_blending( + img = self._block_blending( img, vis_level, top_left, @@ -1501,7 +1651,7 @@ def visHeatmap( return img 
-    def block_blending(
+    def _block_blending(
         self,
         img,
         vis_level,
@@ -1566,13 +1716,13 @@ def block_blending(
         )
         return img

-    def get_seg_mask(self, region_size, scale, use_holes=False, offset=(0, 0)):
+    def _get_seg_mask(self, region_size, scale, use_holes=False, offset=(0, 0)):
         print("\ncomputing foreground tissue mask")
         tissue_mask = np.full(np.flip(region_size), 0).astype(np.uint8)
-        contours_tissue = self.scaleContourDim(self.contours_tissue, scale)
+        contours_tissue = self._scale_contour_dim(self.contours_tissue, scale)
         offset = tuple((np.array(offset) * np.array(scale) * -1).astype(np.int32))

-        contours_holes = self.scaleHolesDim(self.holes_tissue, scale)
+        contours_holes = self._scale_holes_dim(self.holes_tissue, scale)
         contours_tissue, contours_holes = zip(
             *sorted(
                 zip(contours_tissue, contours_holes),
@@ -1599,7 +1749,7 @@ def get_seg_mask(self, region_size, scale, use_holes=False, offset=(0, 0)):
                     offset=offset,
                     thickness=-1,
                 )
-                # contours_holes = self._scaleContourDim(self.holes_tissue, scale, holes=True, area_thresh=area_thresh)
+                # contours_holes = self._scale_contour_dim(self.holes_tissue, scale, holes=True, area_thresh=area_thresh)

         tissue_mask = tissue_mask.astype(bool)
         print(
@@ -1608,3 +1758,154 @@ def get_seg_mask(self, region_size, scale, use_holes=False, offset=(0, 0)):
             )
         )
         return tissue_mask
+
+    def as_tile_bag(self, **kwargs):
+        """
+        Return a torch Dataset of tiles from the whole slide image.
+
+        It can be customized, for example in which transform functions are
+        used, through the keyword arguments.
+        Check wsi.utils.WholeSlideBag for more details.
+
+        Parameters
+        ----------
+        kwargs : dict
+            Additional keyword arguments to pass to wsi.utils.WholeSlideBag.
+        """
+        from wsi.utils import WholeSlideBag
+
+        # dataset = Whole_Slide_Bag(self.hdf5_file, pretrained=True)
+        dataset = WholeSlideBag(
+            self.hdf5_file, self.wsi, pretrained=True, target=self.target, **kwargs
+        )
+        return dataset
+
+    def as_data_loader(
+        self,
+        batch_size: int = 32,
+        with_coords: bool = False,
+        tile_bag_kwargs: dict = {},
+        data_loader_kwargs: dict = {},
+    ) -> torch.utils.data.DataLoader:
+        """
+        Return a data loader for the whole slide image.
+
+        Parameters
+        ----------
+        batch_size : int
+            Number of images per batch in the data loader. Default is 32.
+        with_coords : bool
+            Whether to include coordinates in the data loader. Default is False.
+        tile_bag_kwargs : dict
+            Additional keyword arguments to pass to self.as_tile_bag.
+        data_loader_kwargs : dict
+            Additional keyword arguments to pass to torch.utils.data.DataLoader.
+
+        Returns
+        -------
+        torch.utils.data.DataLoader
+        """
+        from functools import partial
+        from wsi.utils import collate_features
+        from torch.utils.data import DataLoader
+
+        collate = partial(collate_features, with_coords=with_coords)
+
+        dataset = self.as_tile_bag(**tile_bag_kwargs)
+        loader = DataLoader(
+            dataset=dataset,
+            batch_size=batch_size,
+            collate_fn=collate,
+            **data_loader_kwargs,
+        )
+        return loader
+
+    def inference(
+        self,
+        model: torch.nn.Module | str | None = None,
+        model_repo: str = "pytorch/vision",
+        device: str | None = None,
+        data_loader_kws: dict = {},
+    ) -> tp.Tuple[np.ndarray, np.ndarray]:
+        """
+        Inference on the WSI using a pretrained model.
+
+        Parameters
+        ----------
+        model : torch.nn.Module | str
+            Model instance to run inference with, or the name of a model to
+            load from `model_repo`.
+        model_repo : str
+            Repository to load the model from when `model` is a string.
+            Default is "pytorch/vision".
+        device : str, optional
+            Device to run inference on.
+            Defaults to "cuda" if available, otherwise "cpu".
+        data_loader_kws : dict
+            Keyword arguments to pass to the data loader.
+
+        Returns
+        -------
+        Tuple[np.ndarray, np.ndarray]
+            Tuple of (features, coordinates).
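+
+        Examples
+        --------
+        A minimal sketch, assuming the slide has already been tiled and that
+        "resnet18" is a model name resolvable from the default
+        "pytorch/vision" hub repository::
+
+            feats, coords = slide.inference("resnet18", device="cpu")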
+ """ + from typing import cast + import torch + from tqdm_loggable.auto import tqdm + + if isinstance(model, torch.nn.Module): + model = cast(torch.nn.Module, model) + elif isinstance(model, str): + model = torch.hub.load(model_repo, model, weights="DEFAULT") + else: + raise ValueError( + f"model must be a string or a torch.nn.Module, not {type(model)}" + ) + + if device is None: + device = device or "cuda" if torch.cuda.is_available() else "cpu" + + data_loader = self.as_data_loader(**data_loader_kws, with_coords=True) + model.eval() + model = model.to(device) + coords = list() + feats = list() + for batch, coord in tqdm(data_loader): + with torch.no_grad(): + feats.append(model(batch.to(device)).cpu().numpy()) + coords.append(coord) + return np.concatenate(feats, axis=0), np.concatenate(coords, axis=0) + + def as_torch_geometric_data( + self, + feats: np.ndarray | None = None, + coords: np.ndarray | None = None, + model_name: str | None = None, + data_loader_kws: dict = {}, + ) -> torch_geometric.data.Data: + """ + Return a torch_geometric.data.Data object for the whole slide image. + + Parameters + ---------- + feats : np.ndarray + Array of features. + By default, features extracted for tiles using self.inference() with `model_name` are used. + coords : np.ndarray + Array of coordinates. + By default, coordinates for tiles present as output of `slide.tile()` are used. + model_name : str + Name of the model to use for inference. + data_loader_kws : dict + Additional keyword arguments to pass to torch.utils.data.DataLoader. + + Returns + ------- + torch_geometric.data.Data + """ + from torch_geometric.data import Data + + if (feats is None) or (coords is None): + assert ( + model_name is not None + ), "model_name must be provided when feats and coords are None" + feats, coords = self.inference(model_name, data_loader_kws=data_loader_kws) + else: + assert feats is not None, "feats must be provided when coords is not None" + assert coords is not None, "coords must be provided when feats is not None" + + edge_index = self.get_tile_graph() + data = Data(x=feats, edge_index=edge_index, pos=coords) + return data diff --git a/wsi_core/__init__.py b/wsi_core/__init__.py deleted file mode 100644 index 98f65ed..0000000 --- a/wsi_core/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .WholeSlideImage import WholeSlideImage -from ._version import version, __version__ diff --git a/wsi_core/core_utils.py b/wsi_core/core_utils.py deleted file mode 100755 index 6b2cb7d..0000000 --- a/wsi_core/core_utils.py +++ /dev/null @@ -1,656 +0,0 @@ -import os - -import numpy as np -import torch -from sklearn.preprocessing import label_binarize -from sklearn.metrics import roc_auc_score, roc_curve -from sklearn.metrics import auc as calc_auc - -from .utils import * - - -class Accuracy_Logger(object): - """Accuracy logger""" - - def __init__(self, n_classes): - super(Accuracy_Logger, self).__init__() - self.n_classes = n_classes - self.initialize() - - def initialize(self): - self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] - - def log(self, Y_hat, Y): - Y_hat = int(Y_hat) - Y = int(Y) - self.data[Y]["count"] += 1 - self.data[Y]["correct"] += Y_hat == Y - - def log_batch(self, Y_hat, Y): - Y_hat = np.array(Y_hat).astype(int) - Y = np.array(Y).astype(int) - for label_class in np.unique(Y): - cls_mask = Y == label_class - self.data[label_class]["count"] += cls_mask.sum() - self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() - - def get_summary(self, c): - count = 
self.data[c]["count"] - correct = self.data[c]["correct"] - - if count == 0: - acc = None - else: - acc = float(correct) / count - - return acc, correct, count - - -class EarlyStopping: - """Early stops the training if validation loss doesn't improve after a given patience.""" - - def __init__(self, patience=20, stop_epoch=50, verbose=False): - """ - Args: - patience (int): How long to wait after last time validation loss improved. - Default: 20 - stop_epoch (int): Earliest epoch possible for stopping - verbose (bool): If True, prints a message for each validation loss improvement. - Default: False - """ - self.patience = patience - self.stop_epoch = stop_epoch - self.verbose = verbose - self.counter = 0 - self.best_score = None - self.early_stop = False - self.val_loss_min = np.Inf - - def __call__(self, epoch, val_loss, model, ckpt_name="checkpoint.pt"): - - score = -val_loss - - if self.best_score is None: - self.best_score = score - self.save_checkpoint(val_loss, model, ckpt_name) - elif score < self.best_score: - self.counter += 1 - print(f"EarlyStopping counter: {self.counter} out of {self.patience}") - if self.counter >= self.patience and epoch > self.stop_epoch: - self.early_stop = True - else: - self.best_score = score - self.save_checkpoint(val_loss, model, ckpt_name) - self.counter = 0 - - def save_checkpoint(self, val_loss, model, ckpt_name): - """Saves model when validation loss decrease.""" - if self.verbose: - print( - f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..." - ) - torch.save(model.state_dict(), ckpt_name) - self.val_loss_min = val_loss - - -def train(datasets, cur, args): - """ - train for a single fold - """ - print("\nTraining Fold {}!".format(cur)) - writer_dir = os.path.join(args.results_dir, str(cur)) - if not os.path.isdir(writer_dir): - os.mkdir(writer_dir) - - if args.log_data: - from tensorboardX import SummaryWriter - - writer = SummaryWriter(writer_dir, flush_secs=15) - - else: - writer = None - - print("\nInit train/val/test splits...", end=" ") - train_split, val_split, test_split = datasets - save_splits( - datasets, - ["train", "val", "test"], - os.path.join(args.results_dir, "splits_{}.csv".format(cur)), - ) - print("Done!") - print("Training on {} samples".format(len(train_split))) - print("Validating on {} samples".format(len(val_split))) - print("Testing on {} samples".format(len(test_split))) - - print("\nInit loss function...", end=" ") - if args.bag_loss == "svm": - from topk.svm import SmoothTop1SVM - - loss_fn = SmoothTop1SVM(n_classes=args.n_classes) - if device.type == "cuda": - loss_fn = loss_fn.cuda() - else: - loss_fn = nn.CrossEntropyLoss() - print("Done!") - - print("\nInit Model...", end=" ") - model_dict = {"dropout": args.drop_out, "n_classes": args.n_classes} - - if args.model_size is not None and args.model_type != "mil": - model_dict.update({"size_arg": args.model_size}) - - if args.model_type in ["clam_sb", "clam_mb"]: - if args.subtyping: - model_dict.update({"subtyping": True}) - - if args.B > 0: - model_dict.update({"k_sample": args.B}) - - if args.inst_loss == "svm": - from topk.svm import SmoothTop1SVM - - instance_loss_fn = SmoothTop1SVM(n_classes=2) - if device.type == "cuda": - instance_loss_fn = instance_loss_fn.cuda() - else: - instance_loss_fn = nn.CrossEntropyLoss() - - if args.model_type == "clam_sb": - model = CLAM_SB(**model_dict, instance_loss_fn=instance_loss_fn) - elif args.model_type == "clam_mb": - model = CLAM_MB(**model_dict, instance_loss_fn=instance_loss_fn) - 
else: - raise NotImplementedError - - else: # args.model_type == 'mil' - if args.n_classes > 2: - model = MIL_fc_mc(**model_dict) - else: - model = MIL_fc(**model_dict) - - model.relocate() - print("Done!") - print_network(model) - - print("\nInit optimizer ...", end=" ") - optimizer = get_optim(model, args) - print("Done!") - - print("\nInit Loaders...", end=" ") - train_loader = get_split_loader( - train_split, training=True, testing=args.testing, weighted=args.weighted_sample - ) - val_loader = get_split_loader(val_split, testing=args.testing) - test_loader = get_split_loader(test_split, testing=args.testing) - print("Done!") - - print("\nSetup EarlyStopping...", end=" ") - if args.early_stopping: - early_stopping = EarlyStopping(patience=20, stop_epoch=50, verbose=True) - - else: - early_stopping = None - print("Done!") - - for epoch in range(args.max_epochs): - if args.model_type in ["clam_sb", "clam_mb"] and not args.no_inst_cluster: - train_loop_clam( - epoch, - model, - train_loader, - optimizer, - args.n_classes, - args.bag_weight, - writer, - loss_fn, - ) - stop = validate_clam( - cur, - epoch, - model, - val_loader, - args.n_classes, - early_stopping, - writer, - loss_fn, - args.results_dir, - ) - - else: - train_loop( - epoch, model, train_loader, optimizer, args.n_classes, writer, loss_fn - ) - stop = validate( - cur, - epoch, - model, - val_loader, - args.n_classes, - early_stopping, - writer, - loss_fn, - args.results_dir, - ) - - if stop: - break - - if args.early_stopping: - model.load_state_dict( - torch.load(os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur))) - ) - else: - torch.save( - model.state_dict(), - os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)), - ) - - _, val_error, val_auc, _ = summary(model, val_loader, args.n_classes) - print("Val error: {:.4f}, ROC AUC: {:.4f}".format(val_error, val_auc)) - - results_dict, test_error, test_auc, acc_logger = summary( - model, test_loader, args.n_classes - ) - print("Test error: {:.4f}, ROC AUC: {:.4f}".format(test_error, test_auc)) - - for i in range(args.n_classes): - acc, correct, count = acc_logger.get_summary(i) - print("class {}: acc {}, correct {}/{}".format(i, acc, correct, count)) - - if writer: - writer.add_scalar("final/test_class_{}_acc".format(i), acc, 0) - - if writer: - writer.add_scalar("final/val_error", val_error, 0) - writer.add_scalar("final/val_auc", val_auc, 0) - writer.add_scalar("final/test_error", test_error, 0) - writer.add_scalar("final/test_auc", test_auc, 0) - writer.close() - return results_dict, test_auc, val_auc, 1 - test_error, 1 - val_error - - -def train_loop_clam( - epoch, model, loader, optimizer, n_classes, bag_weight, writer=None, loss_fn=None -): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.train() - acc_logger = Accuracy_Logger(n_classes=n_classes) - inst_logger = Accuracy_Logger(n_classes=n_classes) - - train_loss = 0.0 - train_error = 0.0 - train_inst_loss = 0.0 - inst_count = 0 - - print("\n") - for batch_idx, (data, label) in enumerate(loader): - data, label = data.to(device), label.to(device) - logits, Y_prob, Y_hat, _, instance_dict = model( - data, label=label, instance_eval=True - ) - - acc_logger.log(Y_hat, label) - loss = loss_fn(logits, label) - loss_value = loss.item() - - instance_loss = instance_dict["instance_loss"] - inst_count += 1 - instance_loss_value = instance_loss.item() - train_inst_loss += instance_loss_value - - total_loss = bag_weight * loss + (1 - bag_weight) * instance_loss - - inst_preds = 
instance_dict["inst_preds"] - inst_labels = instance_dict["inst_labels"] - inst_logger.log_batch(inst_preds, inst_labels) - - train_loss += loss_value - if (batch_idx + 1) % 20 == 0: - print( - "batch {}, loss: {:.4f}, instance_loss: {:.4f}, weighted_loss: {:.4f}, ".format( - batch_idx, loss_value, instance_loss_value, total_loss.item() - ) - + "label: {}, bag_size: {}".format(label.item(), data.size(0)) - ) - - error = calculate_error(Y_hat, label) - train_error += error - - # backward pass - total_loss.backward() - # step - optimizer.step() - optimizer.zero_grad() - - # calculate loss and error for epoch - train_loss /= len(loader) - train_error /= len(loader) - - if inst_count > 0: - train_inst_loss /= inst_count - print("\n") - for i in range(2): - acc, correct, count = inst_logger.get_summary(i) - print( - "class {} clustering acc {}: correct {}/{}".format(i, acc, correct, count) - ) - - print( - "Epoch: {}, train_loss: {:.4f}, train_clustering_loss: {:.4f}, train_error: {:.4f}".format( - epoch, train_loss, train_inst_loss, train_error - ) - ) - for i in range(n_classes): - acc, correct, count = acc_logger.get_summary(i) - print("class {}: acc {}, correct {}/{}".format(i, acc, correct, count)) - if writer and acc is not None: - writer.add_scalar("train/class_{}_acc".format(i), acc, epoch) - - if writer: - writer.add_scalar("train/loss", train_loss, epoch) - writer.add_scalar("train/error", train_error, epoch) - writer.add_scalar("train/clustering_loss", train_inst_loss, epoch) - - -def train_loop(epoch, model, loader, optimizer, n_classes, writer=None, loss_fn=None): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.train() - acc_logger = Accuracy_Logger(n_classes=n_classes) - train_loss = 0.0 - train_error = 0.0 - - print("\n") - for batch_idx, (data, label) in enumerate(loader): - data, label = data.to(device), label.to(device) - - logits, Y_prob, Y_hat, _, _ = model(data) - - acc_logger.log(Y_hat, label) - loss = loss_fn(logits, label) - loss_value = loss.item() - - train_loss += loss_value - if (batch_idx + 1) % 20 == 0: - print( - "batch {}, loss: {:.4f}, label: {}, bag_size: {}".format( - batch_idx, loss_value, label.item(), data.size(0) - ) - ) - - error = calculate_error(Y_hat, label) - train_error += error - - # backward pass - loss.backward() - # step - optimizer.step() - optimizer.zero_grad() - - # calculate loss and error for epoch - train_loss /= len(loader) - train_error /= len(loader) - - print( - "Epoch: {}, train_loss: {:.4f}, train_error: {:.4f}".format( - epoch, train_loss, train_error - ) - ) - for i in range(n_classes): - acc, correct, count = acc_logger.get_summary(i) - print("class {}: acc {}, correct {}/{}".format(i, acc, correct, count)) - if writer: - writer.add_scalar("train/class_{}_acc".format(i), acc, epoch) - - if writer: - writer.add_scalar("train/loss", train_loss, epoch) - writer.add_scalar("train/error", train_error, epoch) - - -def validate( - cur, - epoch, - model, - loader, - n_classes, - early_stopping=None, - writer=None, - loss_fn=None, - results_dir=None, -): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.eval() - acc_logger = Accuracy_Logger(n_classes=n_classes) - # loader.dataset.update_mode(True) - val_loss = 0.0 - val_error = 0.0 - - prob = np.zeros((len(loader), n_classes)) - labels = np.zeros(len(loader)) - - with torch.no_grad(): - for batch_idx, (data, label) in enumerate(loader): - data, label = data.to(device, non_blocking=True), label.to( - device, non_blocking=True - ) 
- - logits, Y_prob, Y_hat, _, _ = model(data) - - acc_logger.log(Y_hat, label) - - loss = loss_fn(logits, label) - - prob[batch_idx] = Y_prob.cpu().numpy() - labels[batch_idx] = label.item() - - val_loss += loss.item() - error = calculate_error(Y_hat, label) - val_error += error - - val_error /= len(loader) - val_loss /= len(loader) - - if n_classes == 2: - auc = roc_auc_score(labels, prob[:, 1]) - - else: - auc = roc_auc_score(labels, prob, multi_class="ovr") - - if writer: - writer.add_scalar("val/loss", val_loss, epoch) - writer.add_scalar("val/auc", auc, epoch) - writer.add_scalar("val/error", val_error, epoch) - - print( - "\nVal Set, val_loss: {:.4f}, val_error: {:.4f}, auc: {:.4f}".format( - val_loss, val_error, auc - ) - ) - for i in range(n_classes): - acc, correct, count = acc_logger.get_summary(i) - print("class {}: acc {}, correct {}/{}".format(i, acc, correct, count)) - - if early_stopping: - assert results_dir - early_stopping( - epoch, - val_loss, - model, - ckpt_name=os.path.join(results_dir, "s_{}_checkpoint.pt".format(cur)), - ) - - if early_stopping.early_stop: - print("Early stopping") - return True - - return False - - -def validate_clam( - cur, - epoch, - model, - loader, - n_classes, - early_stopping=None, - writer=None, - loss_fn=None, - results_dir=None, -): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.eval() - acc_logger = Accuracy_Logger(n_classes=n_classes) - inst_logger = Accuracy_Logger(n_classes=n_classes) - val_loss = 0.0 - val_error = 0.0 - - val_inst_loss = 0.0 - val_inst_acc = 0.0 - inst_count = 0 - - prob = np.zeros((len(loader), n_classes)) - labels = np.zeros(len(loader)) - sample_size = model.k_sample - with torch.no_grad(): - for batch_idx, (data, label) in enumerate(loader): - data, label = data.to(device), label.to(device) - logits, Y_prob, Y_hat, _, instance_dict = model( - data, label=label, instance_eval=True - ) - acc_logger.log(Y_hat, label) - - loss = loss_fn(logits, label) - - val_loss += loss.item() - - instance_loss = instance_dict["instance_loss"] - - inst_count += 1 - instance_loss_value = instance_loss.item() - val_inst_loss += instance_loss_value - - inst_preds = instance_dict["inst_preds"] - inst_labels = instance_dict["inst_labels"] - inst_logger.log_batch(inst_preds, inst_labels) - - prob[batch_idx] = Y_prob.cpu().numpy() - labels[batch_idx] = label.item() - - error = calculate_error(Y_hat, label) - val_error += error - - val_error /= len(loader) - val_loss /= len(loader) - - if n_classes == 2: - auc = roc_auc_score(labels, prob[:, 1]) - aucs = [] - else: - aucs = [] - binary_labels = label_binarize(labels, classes=[i for i in range(n_classes)]) - for class_idx in range(n_classes): - if class_idx in labels: - fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], prob[:, class_idx]) - aucs.append(calc_auc(fpr, tpr)) - else: - aucs.append(float("nan")) - - auc = np.nanmean(np.array(aucs)) - - print( - "\nVal Set, val_loss: {:.4f}, val_error: {:.4f}, auc: {:.4f}".format( - val_loss, val_error, auc - ) - ) - if inst_count > 0: - val_inst_loss /= inst_count - for i in range(2): - acc, correct, count = inst_logger.get_summary(i) - print( - "class {} clustering acc {}: correct {}/{}".format(i, acc, correct, count) - ) - - if writer: - writer.add_scalar("val/loss", val_loss, epoch) - writer.add_scalar("val/auc", auc, epoch) - writer.add_scalar("val/error", val_error, epoch) - writer.add_scalar("val/inst_loss", val_inst_loss, epoch) - - for i in range(n_classes): - acc, correct, count = 
acc_logger.get_summary(i) - print("class {}: acc {}, correct {}/{}".format(i, acc, correct, count)) - - if writer and acc is not None: - writer.add_scalar("val/class_{}_acc".format(i), acc, epoch) - - if early_stopping: - assert results_dir - early_stopping( - epoch, - val_loss, - model, - ckpt_name=os.path.join(results_dir, "s_{}_checkpoint.pt".format(cur)), - ) - - if early_stopping.early_stop: - print("Early stopping") - return True - - return False - - -def summary(model, loader, n_classes): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - acc_logger = Accuracy_Logger(n_classes=n_classes) - model.eval() - test_loss = 0.0 - test_error = 0.0 - - all_probs = np.zeros((len(loader), n_classes)) - all_labels = np.zeros(len(loader)) - - slide_ids = loader.dataset.slide_data["slide_id"] - patient_results = {} - - for batch_idx, (data, label) in enumerate(loader): - data, label = data.to(device), label.to(device) - slide_id = slide_ids.iloc[batch_idx] - with torch.no_grad(): - logits, Y_prob, Y_hat, _, _ = model(data) - - acc_logger.log(Y_hat, label) - probs = Y_prob.cpu().numpy() - all_probs[batch_idx] = probs - all_labels[batch_idx] = label.item() - - patient_results.update( - { - slide_id: { - "slide_id": np.array(slide_id), - "prob": probs, - "label": label.item(), - } - } - ) - error = calculate_error(Y_hat, label) - test_error += error - - test_error /= len(loader) - - if n_classes == 2: - auc = roc_auc_score(all_labels, all_probs[:, 1]) - aucs = [] - else: - aucs = [] - binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) - for class_idx in range(n_classes): - if class_idx in all_labels: - fpr, tpr, _ = roc_curve( - binary_labels[:, class_idx], all_probs[:, class_idx] - ) - aucs.append(calc_auc(fpr, tpr)) - else: - aucs.append(float("nan")) - - auc = np.nanmean(np.array(aucs)) - - return patient_results, test_error, auc, acc_logger diff --git a/wsi_core/dataset_h5.py b/wsi_core/dataset_h5.py deleted file mode 100644 index aba6691..0000000 --- a/wsi_core/dataset_h5.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import print_function, division - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import Dataset -from torchvision import transforms - -from PIL import Image -import h5py - - -def eval_transforms(pretrained=False): - if pretrained: - mean = (0.485, 0.456, 0.406) - std = (0.229, 0.224, 0.225) - - else: - mean = (0.5, 0.5, 0.5) - std = (0.5, 0.5, 0.5) - - trnsfrms_val = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)] - ) - - return trnsfrms_val - - -class Whole_Slide_Bag(Dataset): - def __init__( - self, - file_path, - pretrained=False, - custom_transforms=None, - target_patch_size=-1, - target=None, - ): - """ - Args: - file_path (string): Path to the .h5 file containing patched data. 
- pretrained (bool): Use ImageNet transforms - custom_transforms (callable, optional): Optional transform to be applied on a sample - """ - self.target = target - - self.pretrained = pretrained - if target_patch_size > 0: - self.target_patch_size = (target_patch_size, target_patch_size) - else: - self.target_patch_size = None - - if not custom_transforms: - self.roi_transforms = eval_transforms(pretrained=pretrained) - else: - self.roi_transforms = custom_transforms - - self.file_path = file_path - - with h5py.File(self.file_path, "r") as f: - dset = f["imgs"] - self.length = len(dset) - - self.summary() - - def __len__(self): - return self.length - - def summary(self): - hdf5_file = h5py.File(self.file_path, "r") - dset = hdf5_file["imgs"] - for name, value in dset.attrs.items(): - print(name, value) - - print("pretrained:", self.pretrained) - print("transformations:", self.roi_transforms) - if self.target_patch_size is not None: - print("target_size: ", self.target_patch_size) - - def __getitem__(self, idx): - with h5py.File(self.file_path, "r") as hdf5_file: - img = hdf5_file["imgs"][idx] - coord = hdf5_file["coords"][idx] - - img = Image.fromarray(img) - if self.target_patch_size is not None: - img = img.resize(self.target_patch_size) - img = self.roi_transforms(img).unsqueeze(0) - if self.target is None: - return img, coord - return img, self.target - - -class Whole_Slide_Bag_FP(Dataset): - def __init__( - self, - file_path, - wsi, - pretrained=False, - custom_transforms=None, - custom_downsample=1, - target_patch_size=-1, - target=None, - ): - """ - Args: - file_path (string): Path to the .h5 file containing patched data. - pretrained (bool): Use ImageNet transforms - custom_transforms (callable, optional): Optional transform to be applied on a sample - custom_downsample (int): Custom defined downscale factor (overruled by target_patch_size) - target_patch_size (int): Custom defined image size before embedding - """ - self.target = target - - self.pretrained = pretrained - self.wsi = wsi - if not custom_transforms: - self.roi_transforms = eval_transforms(pretrained=pretrained) - else: - self.roi_transforms = custom_transforms - - self.file_path = file_path - - with h5py.File(self.file_path, "r") as f: - dset = f["coords"] - self.patch_level = f["coords"].attrs["patch_level"] - self.patch_size = f["coords"].attrs["patch_size"] - self.length = len(dset) - if target_patch_size > 0: - self.target_patch_size = (target_patch_size,) * 2 - elif custom_downsample > 1: - self.target_patch_size = (self.patch_size // custom_downsample,) * 2 - else: - self.target_patch_size = None - self.summary() - - def __len__(self): - return self.length - - def summary(self): - hdf5_file = h5py.File(self.file_path, "r") - dset = hdf5_file["coords"] - for name, value in dset.attrs.items(): - print(name, value) - - # print("\nfeature extraction settings") - # print("target patch size: ", self.target_patch_size) - # print("pretrained: ", self.pretrained) - # print("transformations: ", self.roi_transforms) - - def __getitem__(self, idx): - with h5py.File(self.file_path, "r") as hdf5_file: - coord = hdf5_file["coords"][idx] - img = self.wsi.read_region( - coord, self.patch_level, (self.patch_size, self.patch_size) - ).convert("RGB") - - if self.target_patch_size is not None: - img = img.resize(self.target_patch_size) - img = self.roi_transforms(img).unsqueeze(0) - if self.target is None: - return img, coord - return img, self.target - - -class Dataset_All_Bags(Dataset): - def __init__(self, csv_path): - self.df = 
pd.read_csv(csv_path) - - def __len__(self): - return len(self.df) - - def __getitem__(self, idx): - return self.df["slide_id"][idx] diff --git a/wsi_core/file_utils.py b/wsi_core/file_utils.py deleted file mode 100755 index 69ec6ad..0000000 --- a/wsi_core/file_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -import pickle - -import h5py - - -def save_pkl(filename, save_object): - writer = open(filename, "wb") - pickle.dump(save_object, writer) - writer.close() - - -def load_pkl(filename): - loader = open(filename, "rb") - file = pickle.load(loader) - loader.close() - return file - - -def save_hdf5(output_path, asset_dict, attr_dict=None, mode="a"): - file = h5py.File(output_path, mode) - for key, val in asset_dict.items(): - data_shape = val.shape - if key not in file: - data_type = val.dtype - chunk_shape = (1,) + data_shape[1:] - maxshape = (None,) + data_shape[1:] - dset = file.create_dataset( - key, - shape=data_shape, - maxshape=maxshape, - chunks=chunk_shape, - dtype=data_type, - ) - dset[:] = val - if attr_dict is not None: - if key in attr_dict.keys(): - for attr_key, attr_val in attr_dict[key].items(): - dset.attrs[attr_key] = attr_val - else: - dset = file[key] - dset.resize(len(dset) + data_shape[0], axis=0) - dset[-data_shape[0] :] = val - file.close() - return output_path diff --git a/wsi_core/util_classes.py b/wsi_core/util_classes.py deleted file mode 100644 index 4fed467..0000000 --- a/wsi_core/util_classes.py +++ /dev/null @@ -1,159 +0,0 @@ -import os - -import numpy as np -from PIL import Image -import cv2 - - -class Mosaic_Canvas(object): - def __init__( - self, - patch_size=256, - n=100, - downscale=4, - n_per_row=10, - bg_color=(0, 0, 0), - alpha=-1, - ): - self.patch_size = patch_size - self.downscaled_patch_size = int(np.ceil(patch_size / downscale)) - self.n_rows = int(np.ceil(n / n_per_row)) - self.n_cols = n_per_row - w = self.n_cols * self.downscaled_patch_size - h = self.n_rows * self.downscaled_patch_size - if alpha < 0: - canvas = Image.new(size=(w, h), mode="RGB", color=bg_color) - else: - canvas = Image.new( - size=(w, h), mode="RGBA", color=bg_color + (int(255 * alpha),) - ) - - self.canvas = canvas - self.dimensions = np.array([w, h]) - self.reset_coord() - - def reset_coord(self): - self.coord = np.array([0, 0]) - - def increment_coord(self): - # print('current coord: {} x {} / {} x {}'.format(self.coord[0], self.coord[1], self.dimensions[0], self.dimensions[1])) - assert np.all(self.coord <= self.dimensions) - if ( - self.coord[0] + self.downscaled_patch_size - <= self.dimensions[0] - self.downscaled_patch_size - ): - self.coord[0] += self.downscaled_patch_size - else: - self.coord[0] = 0 - self.coord[1] += self.downscaled_patch_size - - def save(self, save_path, **kwargs): - self.canvas.save(save_path, **kwargs) - - def paste_patch(self, patch): - assert patch.size[0] == self.patch_size - assert patch.size[1] == self.patch_size - self.canvas.paste( - patch.resize(tuple([self.downscaled_patch_size, self.downscaled_patch_size])), - tuple(self.coord), - ) - self.increment_coord() - - def get_painting(self): - return self.canvas - - -class Contour_Checking_fn(object): - # Defining __call__ method - def __call__(self, pt): - raise NotImplementedError - - -class isInContourV1(Contour_Checking_fn): - def __init__(self, contour): - self.cont = contour - - def __call__(self, pt): - return ( - 1 - if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) - >= 0 - else 0 - ) - - -class isInContourV2(Contour_Checking_fn): - def __init__(self, contour, 
patch_size): - self.cont = contour - self.patch_size = patch_size - - def __call__(self, pt): - pt = np.array( - (pt[0] + self.patch_size // 2, pt[1] + self.patch_size // 2) - ).astype(float) - return ( - 1 - if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) - >= 0 - else 0 - ) - - -# Easy version of 4pt contour checking function - 1 of 4 points need to be in the contour for test to pass -class isInContourV3_Easy(Contour_Checking_fn): - def __init__(self, contour, patch_size, center_shift=0.5): - self.cont = contour - self.patch_size = patch_size - self.shift = int(patch_size // 2 * center_shift) - - def __call__(self, pt): - center = (pt[0] + self.patch_size // 2, pt[1] + self.patch_size // 2) - if self.shift > 0: - all_points = [ - (center[0] - self.shift, center[1] - self.shift), - (center[0] + self.shift, center[1] + self.shift), - (center[0] + self.shift, center[1] - self.shift), - (center[0] - self.shift, center[1] + self.shift), - ] - else: - all_points = [center] - - for points in all_points: - if ( - cv2.pointPolygonTest( - self.cont, tuple(np.array(points).astype(float)), False - ) - >= 0 - ): - return 1 - return 0 - - -# Hard version of 4pt contour checking function - all 4 points need to be in the contour for test to pass -class isInContourV3_Hard(Contour_Checking_fn): - def __init__(self, contour, patch_size, center_shift=0.5): - self.cont = contour - self.patch_size = patch_size - self.shift = int(patch_size // 2 * center_shift) - - def __call__(self, pt): - center = (pt[0] + self.patch_size // 2, pt[1] + self.patch_size // 2) - if self.shift > 0: - all_points = [ - (center[0] - self.shift, center[1] - self.shift), - (center[0] + self.shift, center[1] + self.shift), - (center[0] + self.shift, center[1] - self.shift), - (center[0] - self.shift, center[1] + self.shift), - ] - else: - all_points = [center] - - for points in all_points: - if ( - cv2.pointPolygonTest( - self.cont, tuple(np.array(points).astype(float)), False - ) - < 0 - ): - return 0 - return 1 diff --git a/wsi_core/utils.py b/wsi_core/utils.py deleted file mode 100755 index e1ca294..0000000 --- a/wsi_core/utils.py +++ /dev/null @@ -1,314 +0,0 @@ -from __future__ import annotations -import os -import typing as tp -import math -from itertools import islice -import collections -import pathlib - -import torch -import numpy as np -import torch.nn as nn -from torch.utils.data import ( - DataLoader, - Sampler, - WeightedRandomSampler, - RandomSampler, - SequentialSampler, - sampler, -) -import torch.optim as optim - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -class Path(pathlib.Path): - """ - A pathlib.Path child class that allows concatenation with strings - by overloading the addition operator. - - In addition, it implements the ``startswith`` and ``endswith`` methods - just like in the base :obj:`str` type. - - The ``replace_`` implementation is meant to be an implementation closer - to the :obj:`str` type. - - Iterating over a directory with ``iterdir`` that does not exists - will return an empty iterator instead of throwing an error. - - Creating a directory with ``mkdir`` allows existing directory and - creates parents by default. 
- """ - - _flavour = ( - pathlib._windows_flavour # type: ignore[attr-defined] # pylint: disable=W0212 - if os.name == "nt" - else pathlib._posix_flavour # type: ignore[attr-defined] # pylint: disable=W0212 - ) - - def __add__(self, string: str) -> Path: - return Path(str(self) + string) - - def startswith(self, string: str) -> bool: - return str(self).startswith(string) - - def endswith(self, string: str) -> bool: - return str(self).endswith(string) - - def replace_(self, patt: str, repl: str) -> Path: - return Path(str(self).replace(patt, repl)) - - def iterdir(self) -> tp.Generator: - if self.exists(): - yield from [Path(x) for x in pathlib.Path(str(self)).iterdir()] - yield from [] - - def unlink(self, missing_ok: bool = True) -> Path: - super().unlink(missing_ok=missing_ok) - return self - - def mkdir(self, mode=0o777, parents: bool = True, exist_ok: bool = True) -> Path: - super().mkdir(mode=mode, parents=parents, exist_ok=exist_ok) - return self - - def glob(self, pattern: str) -> tp.Generator: - # to support ** with symlinks: https://bugs.python.org/issue33428 - from glob import glob - - if "**" in pattern: - sep = "/" if self.is_dir() else "" - yield from map( - Path, - glob(self.as_posix() + sep + pattern, recursive=True), - ) - else: - yield from super().glob(pattern) - - -class SubsetSequentialSampler(Sampler): - """Samples elements sequentially from a given list of indices, without replacement. - - Arguments: - indices (sequence): a sequence of indices - """ - - def __init__(self, indices): - self.indices = indices - - def __iter__(self): - return iter(self.indices) - - def __len__(self): - return len(self.indices) - - -def filter_kwargs_by_callable( - kwargs: tp.Dict[str, tp.Any], callabl: tp.Callable, exclude: tp.List[str] = None -) -> tp.Dict[str, tp.Any]: - """Filter a dictionary keeping only the keys which are part of a function signature.""" - from inspect import signature - - args = signature(callabl).parameters.keys() - return {k: v for k, v in kwargs.items() if (k in args) and k not in (exclude or [])} - - -def collate_MIL(batch): - img = torch.cat([item[0] for item in batch], dim=0) - label = torch.LongTensor([item[1] for item in batch]) - return [img, label] - - -def collate_features(batch, with_coords: bool = False): - img = torch.cat([item[0] for item in batch], dim=0) - if not with_coords: - return img - coords = np.vstack([item[1] for item in batch]) - return [img, coords] - - -def get_simple_loader(dataset, batch_size=1, num_workers=1): - kwargs = ( - {"num_workers": 4, "pin_memory": False, "num_workers": num_workers} - if device.type == "cuda" - else {} - ) - loader = DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler.SequentialSampler(dataset), - collate_fn=collate_MIL, - **kwargs, - ) - return loader - - -def get_split_loader(split_dataset, training=False, testing=False, weighted=False): - """ - return either the validation loader or training loader - """ - kwargs = {"num_workers": 4} if device.type == "cuda" else {} - if not testing: - if training: - if weighted: - weights = make_weights_for_balanced_classes_split(split_dataset) - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=WeightedRandomSampler(weights, len(weights)), - collate_fn=collate_MIL, - **kwargs, - ) - else: - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=RandomSampler(split_dataset), - collate_fn=collate_MIL, - **kwargs, - ) - else: - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=SequentialSampler(split_dataset), - 
collate_fn=collate_MIL, - **kwargs, - ) - - else: - ids = np.random.choice( - np.arange(len(split_dataset), int(len(split_dataset) * 0.1)), replace=False - ) - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=SubsetSequentialSampler(ids), - collate_fn=collate_MIL, - **kwargs, - ) - - return loader - - -def get_optim(model, args): - if args.opt == "adam": - optimizer = optim.Adam( - filter(lambda p: p.requires_grad, model.parameters()), - lr=args.lr, - weight_decay=args.reg, - ) - elif args.opt == "sgd": - optimizer = optim.SGD( - filter(lambda p: p.requires_grad, model.parameters()), - lr=args.lr, - momentum=0.9, - weight_decay=args.reg, - ) - else: - raise NotImplementedError - return optimizer - - -def print_network(net): - num_params = 0 - num_params_train = 0 - print(net) - - for param in net.parameters(): - n = param.numel() - num_params += n - if param.requires_grad: - num_params_train += n - - print("Total number of parameters: %d" % num_params) - print("Total number of trainable parameters: %d" % num_params_train) - - -def generate_split( - cls_ids, - val_num, - test_num, - samples, - n_splits=5, - seed=7, - label_frac=1.0, - custom_test_ids=None, -): - indices = np.arange(samples).astype(int) - - if custom_test_ids is not None: - indices = np.setdiff1d(indices, custom_test_ids) - - np.random.seed(seed) - for i in range(n_splits): - all_val_ids = [] - all_test_ids = [] - sampled_train_ids = [] - - if custom_test_ids is not None: # pre-built test split, do not need to sample - all_test_ids.extend(custom_test_ids) - - for c in range(len(val_num)): - possible_indices = np.intersect1d( - cls_ids[c], indices - ) # all indices of this class - val_ids = np.random.choice( - possible_indices, val_num[c], replace=False - ) # validation ids - - remaining_ids = np.setdiff1d( - possible_indices, val_ids - ) # indices of this class left after validation - all_val_ids.extend(val_ids) - - if custom_test_ids is None: # sample test split - test_ids = np.random.choice(remaining_ids, test_num[c], replace=False) - remaining_ids = np.setdiff1d(remaining_ids, test_ids) - all_test_ids.extend(test_ids) - - if label_frac == 1: - sampled_train_ids.extend(remaining_ids) - - else: - sample_num = math.ceil(len(remaining_ids) * label_frac) - slice_ids = np.arange(sample_num) - sampled_train_ids.extend(remaining_ids[slice_ids]) - - yield sampled_train_ids, all_val_ids, all_test_ids - - -def nth(iterator, n, default=None): - if n is None: - return collections.deque(iterator, maxlen=0) - else: - return next(islice(iterator, n, None), default) - - -def calculate_error(Y_hat, Y): - error = 1.0 - Y_hat.float().eq(Y.float()).float().mean().item() - - return error - - -def make_weights_for_balanced_classes_split(dataset): - N = float(len(dataset)) - weight_per_class = [ - N / len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids)) - ] - weight = [0] * int(N) - for idx in range(len(dataset)): - y = dataset.getlabel(idx) - weight[idx] = weight_per_class[y] - - return torch.DoubleTensor(weight) - - -def initialize_weights(module): - for m in module.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight) - m.bias.data.zero_() - - elif isinstance(m, nn.BatchNorm1d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) diff --git a/wsi_core/wsi_utils.py b/wsi_core/wsi_utils.py deleted file mode 100755 index 64451f7..0000000 --- a/wsi_core/wsi_utils.py +++ /dev/null @@ -1,500 +0,0 @@ -import os -import math - -import h5py -import numpy as np -from PIL import Image 
-import cv2 - -from .util_classes import Mosaic_Canvas - - -def isWhitePatch(patch, satThresh=5): - patch_hsv = cv2.cvtColor(patch, cv2.COLOR_RGB2HSV) - return True if np.mean(patch_hsv[:, :, 1]) < satThresh else False - - -def isBlackPatch(patch, rgbThresh=40): - return True if np.all(np.mean(patch, axis=(0, 1)) < rgbThresh) else False - - -def isBlackPatch_S(patch, rgbThresh=20, percentage=0.05): - num_pixels = patch.size[0] * patch.size[1] - return ( - True - if np.all(np.array(patch) < rgbThresh, axis=(2)).sum() > num_pixels * percentage - else False - ) - - -def isWhitePatch_S(patch, rgbThresh=220, percentage=0.2): - num_pixels = patch.size[0] * patch.size[1] - return ( - True - if np.all(np.array(patch) > rgbThresh, axis=(2)).sum() > num_pixels * percentage - else False - ) - - -def coord_generator(x_start, x_end, x_step, y_start, y_end, y_step, args_dict=None): - for x in range(x_start, x_end, x_step): - for y in range(y_start, y_end, y_step): - if args_dict is not None: - process_dict = args_dict.copy() - process_dict.update({"pt": (x, y)}) - yield process_dict - else: - yield (x, y) - - -def savePatchIter_bag_hdf5(patch): - ( - x, - y, - cont_idx, - patch_level, - downsample, - downsampled_level_dim, - level_dim, - img_patch, - name, - save_path, - ) = tuple(patch.values()) - img_patch = np.array(img_patch)[np.newaxis, ...] - img_shape = img_patch.shape - - file_path = os.path.join(save_path, name) + ".h5" - file = h5py.File(file_path, "a") - - dset = file["imgs"] - dset.resize(len(dset) + img_shape[0], axis=0) - dset[-img_shape[0] :] = img_patch - - if "coords" in file: - coord_dset = file["coords"] - coord_dset.resize(len(coord_dset) + img_shape[0], axis=0) - coord_dset[-img_shape[0] :] = (x, y) - - file.close() - - -def save_hdf5(output_path, asset_dict, attr_dict=None, mode="a"): - file = h5py.File(output_path, mode) - for key, val in asset_dict.items(): - data_shape = val.shape - if key not in file: - data_type = val.dtype - chunk_shape = (1,) + data_shape[1:] - maxshape = (None,) + data_shape[1:] - dset = file.create_dataset( - key, - shape=data_shape, - maxshape=maxshape, - chunks=chunk_shape, - dtype=data_type, - ) - dset[:] = val - if attr_dict is not None: - if key in attr_dict.keys(): - for attr_key, attr_val in attr_dict[key].items(): - dset.attrs[attr_key] = attr_val - else: - dset = file[key] - dset.resize(len(dset) + data_shape[0], axis=0) - dset[-data_shape[0] :] = val - file.close() - return output_path - - -def initialize_hdf5_bag(first_patch, save_coord=False): - ( - x, - y, - cont_idx, - patch_level, - downsample, - downsampled_level_dim, - level_dim, - img_patch, - name, - save_path, - ) = tuple(first_patch.values()) - file_path = save_path / name + ".h5" - file = h5py.File(file_path, "w") - img_patch = np.array(img_patch)[np.newaxis, ...] 
- dtype = img_patch.dtype - - # Initialize a resizable dataset to hold the output - img_shape = img_patch.shape - maxshape = (None,) + img_shape[ - 1: - ] # maximum dimensions up to which dataset maybe resized (None means unlimited) - dset = file.create_dataset( - "imgs", shape=img_shape, maxshape=maxshape, chunks=img_shape, dtype=dtype - ) - - dset[:] = img_patch - dset.attrs["patch_level"] = patch_level - dset.attrs["wsi_name"] = name - dset.attrs["downsample"] = downsample - dset.attrs["level_dim"] = level_dim - dset.attrs["downsampled_level_dim"] = downsampled_level_dim - - if save_coord: - coord_dset = file.create_dataset( - "coords", shape=(1, 2), maxshape=(None, 2), chunks=(1, 2), dtype=np.int32 - ) - coord_dset[:] = (x, y) - - file.close() - return file_path - - -def sample_indices(scores, k, start=0.48, end=0.52, convert_to_percentile=False, seed=1): - np.random.seed(seed) - if convert_to_percentile: - end_value = np.quantile(scores, end) - start_value = np.quantile(scores, start) - else: - end_value = end - start_value = start - score_window = np.logical_and(scores >= start_value, scores <= end_value) - indices = np.where(score_window)[0] - if len(indices) < 1: - return -1 - else: - return np.random.choice(indices, min(k, len(indices)), replace=False) - - -def top_k(scores, k, invert=False): - if invert: - top_k_ids = scores.argsort()[:k] - else: - top_k_ids = scores.argsort()[::-1][:k] - return top_k_ids - - -def to_percentiles(scores): - from scipy.stats import rankdata - - scores = rankdata(scores, "average") / len(scores) * 100 - return scores - - -def screen_coords(scores, coords, top_left, bot_right): - bot_right = np.array(bot_right) - top_left = np.array(top_left) - mask = np.logical_and( - np.all(coords >= top_left, axis=1), np.all(coords <= bot_right, axis=1) - ) - scores = scores[mask] - coords = coords[mask] - return scores, coords - - -def sample_rois( - scores, - coords, - k=5, - mode="range_sample", - seed=1, - score_start=0.45, - score_end=0.55, - top_left=None, - bot_right=None, -): - - if len(scores.shape) == 2: - scores = scores.flatten() - - scores = to_percentiles(scores) - if top_left is not None and bot_right is not None: - scores, coords = screen_coords(scores, coords, top_left, bot_right) - - if mode == "range_sample": - sampled_ids = sample_indices( - scores, - start=score_start, - end=score_end, - k=k, - convert_to_percentile=False, - seed=seed, - ) - elif mode == "topk": - sampled_ids = top_k(scores, k, invert=False) - elif mode == "reverse_topk": - sampled_ids = top_k(scores, k, invert=True) - else: - raise NotImplementedError - coords = coords[sampled_ids] - scores = scores[sampled_ids] - - asset = {"sampled_coords": coords, "sampled_scores": scores} - return asset - - -def DrawGrid(img, coord, shape, thickness=2, color=(0, 0, 0, 255)): - cv2.rectangle( - img, - tuple(np.maximum([0, 0], coord - thickness // 2)), - tuple(coord - thickness // 2 + np.array(shape)), - (0, 0, 0, 255), - thickness=thickness, - ) - return img - - -def DrawMap( - canvas, patch_dset, coords, patch_size, indices=None, verbose=1, draw_grid=True -): - if indices is None: - indices = np.arange(len(coords)) - total = len(indices) - if verbose > 0: - ten_percent_chunk = math.ceil(total * 0.1) - print("start stitching {}".format(patch_dset.attrs["wsi_name"])) - - for idx in range(total): - if verbose > 0: - if idx % ten_percent_chunk == 0: - print("progress: {}/{} stitched".format(idx, total)) - - patch_id = indices[idx] - patch = patch_dset[patch_id] - patch = cv2.resize(patch, 
patch_size) - coord = coords[patch_id] - canvas_crop_shape = canvas[ - coord[1] : coord[1] + patch_size[1], coord[0] : coord[0] + patch_size[0], :3 - ].shape[:2] - canvas[ - coord[1] : coord[1] + patch_size[1], coord[0] : coord[0] + patch_size[0], :3 - ] = patch[: canvas_crop_shape[0], : canvas_crop_shape[1], :] - if draw_grid: - DrawGrid(canvas, coord, patch_size) - - return Image.fromarray(canvas) - - -def DrawMapFromCoords( - canvas, - wsi_object, - coords, - patch_size, - vis_level, - indices=None, - verbose=1, - draw_grid=True, -): - downsamples = wsi_object.wsi.level_downsamples[vis_level] - if indices is None: - indices = np.arange(len(coords)) - total = len(indices) - if verbose > 0: - ten_percent_chunk = math.ceil(total * 0.1) - - patch_size = tuple( - np.ceil((np.array(patch_size) / np.array(downsamples))).astype(np.int32) - ) - print("downscaled patch size: {}x{}".format(patch_size[0], patch_size[1])) - - for idx in range(total): - if verbose > 0: - if idx % ten_percent_chunk == 0: - print("progress: {}/{} stitched".format(idx, total)) - - patch_id = indices[idx] - coord = coords[patch_id] - patch = np.array( - wsi_object.wsi.read_region(tuple(coord), vis_level, patch_size).convert("RGB") - ) - coord = np.ceil(coord / downsamples).astype(np.int32) - canvas_crop_shape = canvas[ - coord[1] : coord[1] + patch_size[1], coord[0] : coord[0] + patch_size[0], :3 - ].shape[:2] - canvas[ - coord[1] : coord[1] + patch_size[1], coord[0] : coord[0] + patch_size[0], :3 - ] = patch[: canvas_crop_shape[0], : canvas_crop_shape[1], :] - if draw_grid: - DrawGrid(canvas, coord, patch_size) - - return Image.fromarray(canvas) - - -def StitchPatches( - hdf5_file_path, downscale=16, draw_grid=False, bg_color=(0, 0, 0), alpha=-1 -): - file = h5py.File(hdf5_file_path, "r") - dset = file["imgs"] - coords = file["coords"][:] - if "downsampled_level_dim" in dset.attrs.keys(): - w, h = dset.attrs["downsampled_level_dim"] - else: - w, h = dset.attrs["level_dim"] - print("original size: {} x {}".format(w, h)) - w = w // downscale - h = h // downscale - coords = (coords / downscale).astype(np.int32) - print("downscaled size for stiching: {} x {}".format(w, h)) - print("number of patches: {}".format(len(dset))) - img_shape = dset[0].shape - print("patch shape: {}".format(img_shape)) - downscaled_shape = (img_shape[1] // downscale, img_shape[0] // downscale) - - if w * h > Image.MAX_IMAGE_PIXELS: - raise Image.DecompressionBombError( - "Visualization Downscale %d is too large" % downscale - ) - - if alpha < 0 or alpha == -1: - heatmap = Image.new(size=(w, h), mode="RGB", color=bg_color) - else: - heatmap = Image.new( - size=(w, h), mode="RGBA", color=bg_color + (int(255 * alpha),) - ) - - heatmap = np.array(heatmap) - heatmap = DrawMap( - heatmap, dset, coords, downscaled_shape, indices=None, draw_grid=draw_grid - ) - - file.close() - return heatmap - - -def StitchCoords( - hdf5_file_path, - wsi_object, - downscale=16, - draw_grid=False, - bg_color=(0, 0, 0), - alpha=-1, -): - wsi = wsi_object.getOpenSlide() - vis_level = wsi.get_best_level_for_downsample(downscale) - file = h5py.File(hdf5_file_path, "r") - dset = file["coords"] - coords = dset[:] - w, h = wsi.level_dimensions[0] - - print("start stitching {}".format(dset.attrs["name"])) - print("original size: {} x {}".format(w, h)) - - w, h = wsi.level_dimensions[vis_level] - - print("downscaled size for stiching: {} x {}".format(w, h)) - print("number of patches: {}".format(len(coords))) - - patch_size = dset.attrs["patch_size"] - patch_level = 
dset.attrs["patch_level"] - print("patch size: {}x{} patch level: {}".format(patch_size, patch_size, patch_level)) - patch_size = tuple( - (np.array((patch_size, patch_size)) * wsi.level_downsamples[patch_level]).astype( - np.int32 - ) - ) - print("ref patch size: {}x{}".format(patch_size, patch_size)) - - if w * h > Image.MAX_IMAGE_PIXELS: - raise Image.DecompressionBombError( - "Visualization Downscale %d is too large" % downscale - ) - - if alpha < 0 or alpha == -1: - heatmap = Image.new(size=(w, h), mode="RGB", color=bg_color) - else: - heatmap = Image.new( - size=(w, h), mode="RGBA", color=bg_color + (int(255 * alpha),) - ) - - heatmap = np.array(heatmap) - heatmap = DrawMapFromCoords( - heatmap, - wsi_object, - coords, - patch_size, - vis_level, - indices=None, - draw_grid=draw_grid, - ) - - file.close() - return heatmap - - -def SamplePatches( - coords_file_path, - save_file_path, - wsi_object, - patch_level=0, - custom_downsample=1, - patch_size=256, - sample_num=100, - seed=1, - stitch=True, - verbose=1, - mode="w", -): - file = h5py.File(coords_file_path, "r") - dset = file["coords"] - coords = dset[:] - - h5_patch_size = dset.attrs["patch_size"] - h5_patch_level = dset.attrs["patch_level"] - - if verbose > 0: - print("in .h5 file: total number of patches: {}".format(len(coords))) - print( - "in .h5 file: patch size: {}x{} patch level: {}".format( - h5_patch_size, h5_patch_size, h5_patch_level - ) - ) - - if patch_level < 0: - patch_level = h5_patch_level - - if patch_size < 0: - patch_size = h5_patch_size - - np.random.seed(seed) - indices = np.random.choice( - np.arange(len(coords)), min(len(coords), sample_num), replace=False - ) - - target_patch_size = np.array([patch_size, patch_size]) - - if custom_downsample > 1: - target_patch_size = ( - np.array([patch_size, patch_size]) / custom_downsample - ).astype(np.int32) - - if stitch: - canvas = Mosaic_Canvas( - patch_size=target_patch_size[0], - n=sample_num, - downscale=4, - n_per_row=10, - bg_color=(0, 0, 0), - alpha=-1, - ) - else: - canvas = None - - for idx in indices: - coord = coords[idx] - patch = wsi_object.wsi.read_region( - coord, patch_level, tuple([patch_size, patch_size]) - ).convert("RGB") - if custom_downsample > 1: - patch = patch.resize(tuple(target_patch_size)) - - # if isBlackPatch_S(patch, rgbThresh=20, percentage=0.05) or isWhitePatch_S(patch, rgbThresh=220, percentage=0.25): - # continue - - if stitch: - canvas.paste_patch(patch) - - asset_dict = {"imgs": np.array(patch)[np.newaxis, ...], "coords": coord} - save_hdf5(save_file_path, asset_dict, mode=mode) - mode = "a" - - return canvas, len(coords), len(indices)