Commit

feat: add result seg file formats from geopandas

okunator committed Oct 10, 2023
1 parent 30a7510 commit 351c232
Showing 8 changed files with 2,143 additions and 1,482 deletions.
212 changes: 118 additions & 94 deletions cellseg_models_pytorch/inference/_base_inferer.py
@@ -1,4 +1,3 @@
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from itertools import chain
@@ -33,9 +32,6 @@ def __init__(
normalization: str = None,
device: str = "cuda",
n_devices: int = 1,
save_intermediate: bool = False,
save_dir: Union[Path, str] = None,
save_format: str = ".mat",
checkpoint_path: Union[Path, str] = None,
n_images: int = None,
type_post_proc: Callable = None,
@@ -46,57 +42,49 @@ def __init__(
Parameters
----------
model : nn.Module
A segmentation model.
input_path : Path | str
Path to a folder of images or to hdf5 db.
out_activations : Dict[str, str]
Dictionary of head names mapped to a string value that specifies the
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
Allowed values: "softmax", "sigmoid", "tanh", None.
out_boundary_weights : Dict[str, bool]
Dictionary of head names mapped to a boolean value. If the value is
True, a weight matrix is applied after the prediction that assigns more
weight to pixels in the center and less weight to pixels at the tile
boundaries. This helps deal with prediction artefacts at the boundaries.
E.g. {"type": False, "cellpose": True}
patch_size : Tuple[int, int]
The size of the input patches that are fed to the segmentation model.
instance_postproc : str
The post-processing method for the instance segmentation mask. One of:
"cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
padding : int, optional
The amount of reflection padding for the input images.
batch_size : int, default=8
Number of images loaded from the folder at every batch.
normalization : str, optional
Apply image normalization in the forward pass (same as during training).
One of: "dataset", "minmax", "norm", "percentile", None.
device : str, default="cuda"
The device of the input and model. One of: "cuda", "cpu"
n_devices : int, default=1
Number of devices (cpus/gpus) used for inference.
The model will be copied into these devices.
save_dir : Path | str, optional
Path to the save directory. If None, no masks will be saved to disk as .mat
or .json files. Instead, the masks will be stored in `self.out_masks`.
save_intermediate : bool, default=False
If True, intermediate soft masks will be saved into the `self.soft_masks` attribute.
save_format : str, default=".mat"
The file format for the saved output masks. One of (".mat", ".json").
The ".json" option will save masks into geojson format.
checkpoint_path : Path | str, optional
Path to the model weight checkpoints.
n_images : int, optional
Only the first n images from the `input_path` are used.
type_post_proc : Callable, optional
A post-processing function for the type maps. If not None, overrides
the default.
sem_post_proc : Callable, optional
A post-processing function for the semantic seg maps. If not None,
overrides the default.
**kwargs:
Arbitrary keyword arguments, especially for post-processing and saving.
model : nn.Module
A segmentation model.
input_path : Path | str
Path to a folder of images or to hdf5 db.
out_activations : Dict[str, str]
Dictionary of head names mapped to a string value that specifies the
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
Allowed values: "softmax", "sigmoid", "tanh", None.
out_boundary_weights : Dict[str, bool]
Dictionary of head names mapped to a boolean value. If the value is
True, a weight matrix is applied after the prediction that assigns more
weight to pixels in the center and less weight to pixels at the tile
boundaries. This helps deal with prediction artefacts at the boundaries.
E.g. {"type": False, "cellpose": True}
patch_size : Tuple[int, int]
The size of the input patches that are fed to the segmentation model.
instance_postproc : str
The post-processing method for the instance segmentation mask. One of:
"cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
padding : int, optional
The amount of reflection padding for the input images.
batch_size : int, default=8
Number of images loaded from the folder at every batch.
normalization : str, optional
Apply image normalization in the forward pass (same as during training).
One of: "dataset", "minmax", "norm", "percentile", None.
device : str, default="cuda"
The device of the input and model. One of: "cuda", "cpu"
n_devices : int, default=1
Number of devices (cpus/gpus) used for inference.
The model will be copied into these devices.
checkpoint_path : Path | str, optional
Path to the model weight checkpoints.
n_images : int, optional
Only the first n images from the `input_path` are used.
type_post_proc : Callable, optional
A post-processing function for the type maps. If not None, overrides
the default.
sem_post_proc : Callable, optional
A post-processing function for the semantic seg maps. If not None,
overrides the default.
**kwargs:
Arbitrary keyword arguments for post-processing.
"""
# basic inits
self.model = model
@@ -109,22 +97,10 @@ def __init__(
self.head_kwargs = self._check_and_set_head_args()
self.kwargs = kwargs

self.save_dir = Path(save_dir) if save_dir is not None else None
self.save_intermediate = save_intermediate
self.save_format = save_format

# dataset & dataloader
self.path = Path(input_path)
if self.path.is_dir():
ds = FolderDatasetInfer(self.path, n_images=n_images)
if self.save_dir is None and len(ds.fnames) > 40 and n_images is None:
warnings.warn(
"`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
"class attribute. If the input folder contains many images, running"
" inference will likely flood the memory depending on the size and "
"number of the images. Consider saving outputs to disk by providing"
" `save_dir` argument."
)
elif self.path.is_file() and self.path.suffix in (".h5", ".hdf5"):
from .hdf5_dataset_infer import HDF5DatasetInfer

@@ -167,10 +143,10 @@ def __init__(

# try loading the weights to the model
try:
msg = self.model.load_state_dict(state_dict, strict=True)
msg = self.model.load_state_dict(state_dict, strict=False)
except RuntimeError:
new_ckpt = self._strip_state_dict(state_dict)
msg = self.model.load_state_dict(new_ckpt, strict=True)
msg = self.model.load_state_dict(new_ckpt, strict=False)
except BaseException as e:
raise RuntimeError(f"Error when loading checkpoint: {e}")
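The fallback branch above retries loading after `_strip_state_dict`, whose body is not shown in this hunk. As a rough illustration only (not the actual implementation in this file), such a helper typically removes a wrapper prefix like `"module."` that `nn.DataParallel` adds to checkpoint keys:

```python
from collections import OrderedDict
from typing import Dict

def strip_state_dict(ckpt: Dict, prefix: str = "module.") -> OrderedDict:
    """Drop a wrapper prefix (e.g. added by nn.DataParallel) from checkpoint keys."""
    new_ckpt = OrderedDict()
    for key, value in ckpt.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        new_ckpt[new_key] = value
    return new_ckpt
```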

@@ -218,34 +194,74 @@ def from_yaml(cls, model: nn.Module, yaml_path: str):
def _infer_batch(self):
raise NotImplementedError

def infer(self, mixed_precision: bool = False) -> None:
"""Run inference and post-processing for the images.
def infer(
self,
save_dir: Union[Path, str] = None,
save_format: str = ".mat",
save_intermediate: bool = False,
classes_type: Dict[str, int] = None,
classes_sem: Dict[str, int] = None,
offsets: bool = False,
mixed_precision: bool = False,
) -> None:
"""Run inference and post-processing for the image(s) inside `input_path`.
NOTE:
- Saves outputs in class attributes or to disk as .mat/.json files.
- If masks are saved to .json (geojson) files, more keyword arguments
need to be given at class initialization. Namely: `geo_format`,
`classes_type`, `classes_sem`, `offsets`. See more in the
`FileHandler.save_masks` docs.
NOTE: If `save_dir` is None, the output masks will be cached in a class
attribute `self.out_masks`. Otherwise the masks will be saved to disk.
WARNING: Running inference without setting `save_dir` can take a lot of memory
if the input directory contains many images.
Parameters
----------
mixed_precision : bool, default=False
If True, inference is performed with mixed precision.
save_dir : Path | str, optional
Path to the save directory. If None, no masks will be saved to disk.
Instead, the masks will be cached in the class attribute `self.out_masks`.
save_format : str, default=".mat"
The file format for the saved output masks. One of: ".mat", ".geojson",
".feather", ".parquet".
save_intermediate : bool, default=False
If True, intermediate soft masks will be saved into `self.soft_masks`
class attribute. WARNING: This can take a lot of memory if the input
directory contains many images.
classes_type : Dict[str, int], optional
Cell type dictionary. E.g. {"inflam": 1, "epithelial": 2, "connec": 3}.
This is required only if `save_format` is one of the following formats:
".geojson", ".parquet", ".feather".
classes_sem : Dict[str, int], optional
Tissue type dictionary. E.g. {"tissue1": 1, "tissue2": 2, "tissue3": 3}.
This is required only if `save_format` is one of the following formats:
".geojson", ".parquet", ".feather".
offsets : bool, default=False
If True, the geojson coordinates are shifted by the offsets that are encoded
in the filenames (e.g. "x-1000_y-4000.png"). Ignored if `save_format` is ".mat".
mixed_precision : bool, default=False
If True, inference is performed with mixed precision.
Attributes
----------
- out_masks : Dict[str, Dict[str, np.ndarray]]
The output masks for each image. The keys are the image names and the
values are dictionaries of the masks. E.g.
{"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
- soft_masks : Dict[str, Dict[str, np.ndarray]]
NOTE: This attribute is set only if `save_intermediate = True`.
The soft masks for each image, i.e. the soft predictions of the trained
model. The keys are the image names and the values are dictionaries of
the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
- out_masks : Dict[str, Dict[str, np.ndarray]]
The output masks for each image. The keys are the image names and the
values are dictionaries of the masks. E.g.
{"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
- soft_masks : Dict[str, Dict[str, np.ndarray]]
NOTE: This attribute is set only if `save_intermediate = True`.
The soft masks for each image, i.e. the soft predictions of the trained
model. The keys are the image names and the values are dictionaries of
the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
"""
# check save_dir and save_format
save_dir = Path(save_dir) if save_dir is not None else None
if save_dir is not None:
allowed_formats = (".mat", ".geojson", ".feather", ".parquet")
if save_format not in allowed_formats:
raise ValueError(
f"Given `save_format`: {save_format} is not one of the allowed "
f"formats: {allowed_formats}"
)

self.soft_masks = {}
self.out_masks = {}
self.elapsed = []
@@ -271,7 +287,7 @@ def infer(self, mixed_precision: bool = False) -> None:
self.elapsed.append(loader.format_dict["elapsed"])
self.rate.append(loader.format_dict["rate"])

if self.save_intermediate:
if save_intermediate:
for n, m in zip(names, soft_masks):
self.soft_masks[n] = m

@@ -283,25 +299,33 @@ def infer(self, mixed_precision: bool = False) -> None:
seg["soft_sem"] = soft["sem"]

# save to cache or disk
if self.save_dir is None:
if save_dir is None:
for n, m in zip(names, seg_results):
self.out_masks[n] = m
else:
loader.set_postfix_str("Saving results to disk")
if self.batch_size > 1:
fnames = [Path(self.save_dir) / n for n in names]
fnames = [Path(save_dir) / n for n in names]
FileHandler.save_masks_parallel(
maps=seg_results,
fnames=fnames,
**{**self.kwargs, "format": self.save_format},
format=save_format,
classes_type=classes_type,
classes_sem=classes_sem,
offsets=offsets,
pooltype="thread",
maptype="amap",
)
else:
for n, m in zip(names, seg_results):
fname = Path(self.save_dir) / n
fname = Path(save_dir) / n
FileHandler.save_masks(
fname=fname,
maps=m,
**{**self.kwargs, "format": self.save_format},
format=save_format,
classes_type=classes_type,
classes_sem=classes_sem,
offsets=offsets,
)
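Because the new formats are geodata formats, the saved per-image results can be read back with geopandas for downstream analysis. A small sketch, assuming a file written with `save_format=".parquet"` (the file path is hypothetical; the reader functions are standard geopandas API):

```python
import geopandas as gpd

# Load one saved result as a GeoDataFrame of the segmented objects.
gdf = gpd.read_parquet("results/sample1_x-1000_y-4000.parquet")  # hypothetical path
# For ".geojson" use gpd.read_file(...); for ".feather" use gpd.read_feather(...).
print(gdf.head())
```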

def _strip_state_dict(self, ckpt: Dict) -> OrderedDict: