Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding CLI for restore #875

Merged
merged 3 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from natsort import natsorted
from tqdm import tqdm
from cellpose import utils, models, io, version_str, train
from cellpose import utils, models, io, version_str, train, denoise
from cellpose.cli import get_arg_parser

try:
Expand Down Expand Up @@ -90,9 +90,18 @@ def main():
else:
pretrained_model = args.pretrained_model

restore_type = args.restore_type
if restore_type is not None:
try:
denoise.model_path(restore_type)
except Exception as e:
raise ValueError("restore_type invalid")
if args.train or args.train_size:
raise ValueError("restore_type cannot be used with training on CLI yet")

model_type = None
if pretrained_model and not os.path.exists(pretrained_model):
model_type = pretrained_model if pretrained_model is not None else "cyto"
model_type = pretrained_model if pretrained_model is not None else "cyto3"
model_strings = models.get_user_models()
all_models = models.MODEL_NAMES.copy()
all_models.extend(model_strings)
Expand Down Expand Up @@ -127,26 +136,39 @@ def main():
">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s"
% (nimg, cstr0[channels[0]], cstr1[channels[1]]))

# handle built-in model exceptions; bacterial ones get no size model
if builtin_size:
# handle built-in model exceptions
if builtin_size and restore_type is None:
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type)
else:
builtin_size = False
if args.all_channels:
channels = None
pretrained_model = None if model_type is not None else pretrained_model
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type)
if restore_type is None:
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type)
else:
model = denoise.CellposeDenoiseModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type,
restore_type=restore_type,
chan2_restore=args.chan2_restore)

# handle diameters
if args.diameter == 0:
if builtin_size:
diameter = None
logger.info(">>>> estimating diameter for each image")
else:
logger.info(
">>>> not using cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
)
if restore_type is None:
logger.info(
">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
)
else:
logger.info(
">>>> cannot auto-estimate diameter for image restoration"
)
diameter = model.diam_labels
logger.info(">>>> using diameter %0.3f for all images" % diameter)
else:
Expand All @@ -168,17 +190,26 @@ def main():
channel_axis=args.channel_axis, z_axis=args.z_axis,
anisotropy=args.anisotropy, niter=args.niter)
masks, flows = out[:2]
if len(out) > 3:
if len(out) > 3 and restore_type is None:
diams = out[-1]
else:
diams = diameter
ratio = 1.
if restore_type is not None:
imgs_dn = out[-1]
ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1.
diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams
else:
imgs_dn = None
if args.exclude_on_edges:
masks = utils.remove_edge_masks(masks)
if not args.no_npy:
io.masks_flows_to_seg(image, masks, flows, image_name,
channels=channels, diams=diams)
io.masks_flows_to_seg(image, masks, flows, image_name, imgs_restore=imgs_dn,
channels=channels, diams=diams,
restore_type=restore_type, ratio=1.)
if saving_something:
io.save_masks(image, masks, flows, image_name, png=args.save_png,
io.save_masks(image, masks, flows, image_name,
png=args.save_png,
tif=args.save_tif, save_flows=args.save_flows,
save_outlines=args.save_outlines,
dir_above=args.dir_above, savedir=args.savedir,
Expand Down
5 changes: 5 additions & 0 deletions cellpose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def get_arg_parser():
model_args.add_argument("--pretrained_model", required=False, default="cyto",
type=str,
help="model to use for running or starting training")
model_args.add_argument("--restore_type", required=False, default=None,
type=str,
help="model to use for image restoration")
model_args.add_argument("--chan2_restore", action="store_true",
help="use nuclei restore model for second channel")
model_args.add_argument(
"--add_model", required=False, default=None, type=str,
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
Expand Down
16 changes: 9 additions & 7 deletions cellpose/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,19 +464,19 @@ def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
class CellposeDenoiseModel():
""" model to run Cellpose and Image restoration """
def __init__(self, gpu=False, pretrained_model=False, model_type=None,
restore_type="denoise_cyto3", chan2_denoise=False,
restore_type="denoise_cyto3", chan2_restore=False,
device=None):

self.dn = DenoiseModel(gpu=gpu, model_type=restore_type,
chan2=chan2_denoise, device=device)
chan2=chan2_restore, device=device)
self.cp = CellposeModel(gpu=gpu, model_type=model_type,
pretrained_model=pretrained_model, device=device)

def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1,
resample=True, invert=False, flow_threshold=0.4, cellprob_threshold=0.0,
do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15,
niter=None, interp=True):
augment=False, resample=True, invert=False, flow_threshold=0.4,
cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
min_size=15, niter=None, interp=True):
"""
Restore array or list of images using the image restoration model, and then segment.

Expand Down Expand Up @@ -510,6 +510,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True.
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
Expand Down Expand Up @@ -549,7 +550,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
masks, flows, styles = self.cp.eval(img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
normalize=normalize_params, rescale=rescale, diameter=diameter,
tile=tile, tile_overlap=tile_overlap, resample=resample, invert=invert,
tile=tile, tile_overlap=tile_overlap, augment=augment,
resample=resample, invert=invert,
flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold,
do_3D=do_3D, anisotropy=anisotropy, stitch_threshold=stitch_threshold,
min_size=min_size, niter=niter, interp=interp)
Expand Down Expand Up @@ -644,7 +646,7 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
)
if chan2 and builtin:
chan2_path = model_path(os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
print(f"loading model for chan2: {os.path.split(str(chan2_path)[-1])}")
print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
mkldnn=self.mkldnn, max_pool=True,
diam_mean=17.).to(self.device)
Expand Down
Loading
Loading