diff --git a/cellseg_models_pytorch/inference/_base_inferer.py b/cellseg_models_pytorch/inference/_base_inferer.py index 57e9f68..3f0b061 100644 --- a/cellseg_models_pytorch/inference/_base_inferer.py +++ b/cellseg_models_pytorch/inference/_base_inferer.py @@ -143,10 +143,10 @@ def __init__( # try loading the weights to the model try: - msg = self.model.load_state_dict(state_dict, strict=False) + msg = self.model.load_state_dict(state_dict, strict=True) except RuntimeError: new_ckpt = self._strip_state_dict(state_dict) - msg = self.model.load_state_dict(new_ckpt, strict=False) + msg = self.model.load_state_dict(new_ckpt, strict=True) except BaseException as e: raise RuntimeError(f"Error when loading checkpoint: {e}") diff --git a/cellseg_models_pytorch/utils/file_manager.py b/cellseg_models_pytorch/utils/file_manager.py index d3c5451..759bf2f 100644 --- a/cellseg_models_pytorch/utils/file_manager.py +++ b/cellseg_models_pytorch/utils/file_manager.py @@ -318,9 +318,6 @@ def get_gson( x-coordinate offset. (to set geojson to .mrxs wsi coordinates) y_offset : int, default=0 y-coordinate offset. (to set geojson to .mrxs wsi coordinates) - geo_format : str, default="qupath" - The format for the geo object. "qupath" format allows the result file - to be read with QuPath. "simple" format allows for geopandas etc. Returns ------- @@ -357,25 +354,38 @@ def get_gson( inst_type_soft[key] = float(inst_type_soft[key]) # get the cell contour coordinates - contours = cv2.findContours(inst, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + contours = cv2.findContours(inst, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) contours = contours[0] if len(contours) == 2 else contours[1] + shell = contours[0] # exterior + holes = [cont for cont in contours[1:]] + # got a line instead of a polygon - if contours[0].shape[0] < 3: + if shell.shape[0] < 3: continue # shift coordinates based on the offsets if x_offset: - contours[0][..., 0] += x_offset + shell[..., 0] += x_offset + if holes: + for cont in holes: + cont[..., 0] += x_offset + if y_offset: - contours[0][..., 1] += y_offset + shell[..., 1] += y_offset + if holes: + for cont in holes: + cont[..., 1] += y_offset - poly = contours[0].squeeze().tolist() - poly.append(poly[0]) # close the polygon + # convert to list for shapely Polygon + shell = shell.squeeze().tolist() + if holes: + holes = [cont.squeeze().tolist() for cont in holes] + # shell.append(shell[0]) # close the polygon features.append( FileHandler.geo_obj( - poly=Polygon(poly), + poly=Polygon(shell=shell, holes=holes), uid=inst_id, class_name=inst_type, class_probs=inst_type_soft, @@ -390,6 +400,7 @@ def to_gson( features: List[Dict[str, Any]], format: str = ".feather", show_bbox: bool = True, + silence_warnings: bool = True, ) -> None: """Write a geojson/feather/parquet file from a list of geojson features. @@ -403,6 +414,8 @@ def to_gson( The output format. One of ".feather", ".parquet", ".geojson". show_bbox : bool, default=True If True, the bbox is added to the geojson object. + silence_warnings : bool, default=True + If True, warnings are silenced. """ out_fn = Path(out_fn) if format not in (".feather", ".parquet", ".geojson"): @@ -442,7 +455,8 @@ def to_gson( if show_bbox: geo["bbox"] = tuple(gdf.total_bounds) else: - warnings.warn(f"The {out_fn.name} file is empty.") + if not silence_warnings: + warnings.warn(f"The {out_fn.name} file is empty.") if format == ".feather": gdf.to_feather(out_fn.with_suffix(".feather"))