Skip to content

Commit

Permalink
Rl/process dic timelapses (#19)
Browse files Browse the repository at this point in the history
* Update stack_processing.py

- add dask-compatible function for Li thresholding
- change default number of central frames to extract during processing

* enable processing for DIC time lapses

This is a sort of crude and minimal implementation for processing DIC time lapses of agar microchambers such that it could be used to (quickly) process old in-house datasets.

* Fixes from code review #19
  • Loading branch information
lanery authored Jul 19, 2024
1 parent a077f21 commit 298033c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 27 deletions.
4 changes: 4 additions & 0 deletions src/chlamytracker/pool_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ def segment_pools(self, filled_ratio_threshold=0.1):
(Nx, Ny): (T, Y, X) bool array
}
"""
# set imaging modality for segmentation
modality = "brightfield" if self.is_brightfield else "DIC"

# convert minimum cell diameter to pixelated area
min_area = self.convert_um_to_px2_circle(self.min_cell_diameter_um)

Expand All @@ -341,6 +344,7 @@ def segment_pools(self, filled_ratio_threshold=0.1):
if pool.has_cells():
try:
pool_segmented = pool.segment(
modality=modality,
min_area=min_area,
filled_ratio_threshold=filled_ratio_threshold,
)
Expand Down
87 changes: 62 additions & 25 deletions src/chlamytracker/pool_processor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import skimage as ski

from .stack_processing import (
circular_alpha_mask,
gaussian_filter_3d_parallel,
get_central_frames,
li_threshold_3d,
otsu_threshold_3d,
remove_small_objects_3d_parallel,
rescale_to_float,
)
Expand All @@ -13,16 +15,16 @@
class PoolSegmenter:
"""Class for processing timelapse microscopy data of an individual agar microchamber pool.
TODO: detailed description of what processing steps this class seeks to accomplish.
Primary application is for segmenting cells within an individual agar
microchamber pool [1]. Performs background subtraction prior to segmentation.
Background is estimated as the mean intensity projection.
Parameters
----------
raw_data_pool : (T, Y, X) uint16 array
Input timelapse microscopy image data of an individual agar microchamber
pool that has been tightly cropped to either manually or e.g. after
being detected with `PoolFinder.find_pools()`.
gaussian_filter_sigma : scalar (optional)
Sigma of Gaussian filter for blurring the alpha mask (preprocessing).
num_workers : int (optional)
Number of processors to dedicate for multiprocessing.
Expand All @@ -37,39 +39,74 @@ class PoolSegmenter:
[1] https://doi.org/10.57844/arcadia-v1bg-6b60
"""

def __init__(
self,
raw_data_pool,
gaussian_filter_sigma=4,
num_workers=6,
):
def __init__(self, raw_data_pool, num_workers=6):
self.raw_data = raw_data_pool.copy()
self.gaussian_filter_sigma = gaussian_filter_sigma
self.num_workers = num_workers

def has_cells(self, contrast_threshold=0.05):
def has_cells(self, contrast_threshold=1e-3, num_central_frames=200):
"""Determine whether pool contains cells.
Determination is based on the amount of contrast in the standard
deviation projection, using the variance of intensity values as a proxy
for contrast.
deviation projection, using the standard deviation of intensity
values as a proxy for contrast.
Default values for `contrast_threshold` and `num_central_frames`
were derived empirically from visual inspection of several time
lapses of pools with and without cells.
"""
# get dtype limits for normalization
# (0, 65535) is expected but safer to check
dtype_limit_max = max(ski.util.dtype_limits(self.raw_data))
# compute the standard deviation projection
std_intensity_projection = self.raw_data.std(axis=0)
# use variance of intensity as measure of contrast
normalized_contrast = std_intensity_projection.var() / dtype_limit_max
# std projection on smoothed substack
num_central_frames = min(num_central_frames, self.raw_data.shape[0])
central_frames = get_central_frames(self.raw_data, num_central_frames)
central_frames_smoothed = gaussian_filter_3d_parallel(central_frames, sigma=3)
std_projection = central_frames_smoothed.std(axis=0)
# use std dev of intensity as measure of contrast
normalized_contrast = std_projection.std()
return normalized_contrast > contrast_threshold

@timeit
def segment(self, min_area=150, filled_ratio_threshold=0.1, li_threshold=0.1):
""""""
def segment(
self,
modality="brightfield",
min_area=150,
filled_ratio_threshold=0.1,
li_threshold=0.1,
otsu_thresholding_scale_factor=0.66,
):
"""Segment cells within a pool.
Parameters
----------
modality : str (optional)
Imaging modality: either "brightfield" or "DIC".
min_area : float (optional)
Minimum area (px^2) for object removal.
filled_ratio_threshold : float (optional)
Threshold used for discarding noisy segmentation results.
For reliable cell tracking it is assumed that the pools are
quite sparsely populated with cells.
li_threshold : float (optional)
Initial guess parameter for Li thresholding [1].
otsu_thresholding_scale_factor : float (optional)
Value for (somewhat arbitrarily) scaling the Otsu threshold.
Default value of 0.66 was found to work empirically on a sample
of test DIC time lapses.
References
----------
[1] https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.threshold_li
"""
# background subtraction
background_subtracted = self.subtract_background()
# segment cells based on Li thresholding -- more forgiving than Otsu
threshold = ski.filters.threshold_li(background_subtracted, initial_guess=li_threshold)

# set threshold for segmentation based on modality -- Li thresholding
# was empirically determined to provide better results on brightfield
# microscopy data, while Otsu performed better on DIC
if modality == "brightfield":
threshold = li_threshold_3d(background_subtracted, initial_guess=li_threshold)
else:
threshold = otsu_threshold_3d(background_subtracted) * otsu_thresholding_scale_factor

# apply threshold
segmentation = background_subtracted > threshold

# apply circular alpha mask to segmentation
Expand Down
22 changes: 20 additions & 2 deletions src/chlamytracker/stack_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@timeit
def get_central_frames(stack, num_central_frames=10):
def get_central_frames(stack, num_central_frames=100):
"""Crops `num_central_frames` from the center of an image stack.
Parameters
Expand Down Expand Up @@ -104,7 +104,7 @@ def rescale_to_float(stack):


@timeit
def otsu_threshold_3d(stack, num_central_frames=10):
def otsu_threshold_3d(stack, num_central_frames=100):
"""Wrapper for `ski.filters.threshold_otsu` better equipped for handling
large image stacks.
Expand All @@ -121,6 +121,24 @@ def otsu_threshold_3d(stack, num_central_frames=10):
return threshold


@timeit
def li_threshold_3d(stack, num_central_frames=100, initial_guess=0.1):
"""Wrapper for `ski.filters.threshold_li` better equipped for handling
large image stacks.
Parameters
----------
stack : (Z, Y, X) array
Input image stack of arbitrary dtype.
num_central_frames : int
Number of central frames to use for determining the threshold. Useful
for speeding up computation time when large stacks when
"""
central_frames = get_central_frames(stack, num_central_frames)
threshold = ski.filters.threshold_li(central_frames, initial_guess=initial_guess)
return threshold


def otsu_threshold_dask(dask_array):
"""Otsu thresholding function compatible with dask.
Expand Down
5 changes: 5 additions & 0 deletions src/chlamytracker/timelapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def __init__(self, nd2_file, use_dask=False, load=True):

# extract relevant metadata from nd2 headers
with nd2.ND2File(nd2_file) as nd2f:
metadata = nd2f.metadata
voxels_um = nd2f.voxel_size() # in microns
sizes = nd2f.sizes # e.g. {'T': 10, 'C': 2, 'Y': 256, 'X': 256}
events = nd2f.events()

# convert metadata fields to useful attributes
self.metadata = metadata
self.dimensions = sizes
self.pixelsize_um = (voxels_um.x + voxels_um.y) / 2
self.frametimes = np.diff([event["Time [s]"] for event in events])
Expand All @@ -55,6 +57,9 @@ def __init__(self, nd2_file, use_dask=False, load=True):

# determine whether timelapse is also a zstack
self.is_zstack = self.dimensions.get("Z") is not None
# determine imaging modality
self.modality_flags = self.metadata.channels[0].microscope.modalityFlags[0]
self.is_brightfield = "brightfield" in self.modality_flags

# load data from nd2 file
if load:
Expand Down

0 comments on commit 298033c

Please sign in to comment.