diff --git a/supervisely_integration/serve/src/main.py b/supervisely_integration/serve/src/main.py
index d31fff8..d54c23b 100644
--- a/supervisely_integration/serve/src/main.py
+++ b/supervisely_integration/serve/src/main.py
@@ -114,7 +114,7 @@ def load_on_device(
         # variable for storing image ids from previous inference iterations
         self.previous_image_id = None
         # dict for storing model variables to avoid unnecessary calculations
-        self.cache = Cache(maxsize=100, ttl=5 * 60)
+        self.model_cache = Cache(maxsize=100, ttl=5 * 60)
         # set variables for smart tool mode
         self._inference_image_lock = threading.Lock()
         self._inference_image_cache = Cache(maxsize=100, ttl=60)
@@ -139,9 +139,9 @@ def model_meta(self):
 
     def set_image_data(self, input_image, settings):
         if settings["input_image_id"] != self.previous_image_id:
-            if settings["input_image_id"] not in self.cache:
+            if settings["input_image_id"] not in self.model_cache:
                 self.predictor.set_image(input_image)
-                self.cache.set(
+                self.model_cache.set(
                     settings["input_image_id"],
                     {
                         "features": self.predictor.features,
@@ -150,7 +150,7 @@ def set_image_data(self, input_image, settings):
                     },
                 )
             else:
-                cached_data = self.cache.get(settings["input_image_id"])
+                cached_data = self.model_cache.get(settings["input_image_id"])
                 self.predictor.features = cached_data["features"]
                 self.predictor.input_size = cached_data["input_size"]
                 self.predictor.original_size = cached_data["original_size"]
@@ -287,15 +287,15 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
             self.set_image_data(input_image, settings)
             # get predicted masks
             if (
-                settings["input_image_id"] in self.cache
+                settings["input_image_id"] in self.model_cache
                 and (
-                    self.cache.get(settings["input_image_id"]).get("previous_bbox")
+                    self.model_cache.get(settings["input_image_id"]).get("previous_bbox")
                     == bbox_coordinates
                 ).all()
                 and self.previous_image_id == settings["input_image_id"]
             ):
                 # get mask from previous predicton and use at as an input for new prediction
-                mask_input = self.cache.get(settings["input_image_id"])["mask_input"]
+                mask_input = self.model_cache.get(settings["input_image_id"])["mask_input"]
                 masks, scores, logits = self.predictor.predict(
                     point_coords=point_coordinates,
                     point_labels=point_labels,
@@ -311,12 +311,12 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
                     multimask_output=False,
                 )
             # save bbox ccordinates and mask to cache
-            if settings["input_image_id"] in self.cache:
+            if settings["input_image_id"] in self.model_cache:
                 input_image_id = settings["input_image_id"]
-                cached_data = self.cache.get(input_image_id)
+                cached_data = self.model_cache.get(input_image_id)
                 cached_data["previous_bbox"] = bbox_coordinates
                 cached_data["mask_input"] = logits[0]
-                self.cache.set(input_image_id, cached_data)
+                self.model_cache.set(input_image_id, cached_data)
             # update previous_image_id variable
             self.previous_image_id = settings["input_image_id"]
             mask = masks[0]
@@ -326,7 +326,6 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
     def serve(self):
         super().serve()
         server = self._app.get_server()
-        self.add_cache_endpoint(server)
 
         @server.post("/smart_segmentation")
         def smart_segmentation(response: Response, request: Request):
@@ -387,9 +386,9 @@ def smart_segmentation(response: Response, request: Request):
                         smtool_state,
                         api,
                         app_dir,
-                        cache_load_img=self.download_image,
-                        cache_load_frame=self.download_frame,
-                        cache_load_img_hash=self.download_image_by_hash,
+                        cache_load_img=self.cache.download_image,
+                        cache_load_frame=self.cache.download_frame,
+                        cache_load_img_hash=self.cache.download_image_by_hash,
                     )
                     self._inference_image_cache.set(hash_str, image_np)
                 else:
@@ -470,7 +469,7 @@ def smart_segmentation(response: Response, request: Request):
         def is_online(response: Response, request: Request):
            response = {"is_online": True}
            return response
-
+
        @server.post("/smart_segmentation_batched")
        def smart_segmentation_batched(response: Response, request: Request):
            response_batch = {}
@@ -487,7 +486,6 @@ def smart_segmentation_batched(response: Response, request: Request):
            return response_batch
 
 
-
 model = SegmentAnythingHQModel(
     use_gui=True,
     custom_inference_settings="./supervisely_integration/serve/custom_settings.yaml",