Commit d1f3d5b

Merge pull request #7 from supervisely-ecosystem/niko-test
Update SDK Version
NikolaiPetukhov authored Jun 11, 2024
2 parents f71c24b + ced558b commit d1f3d5b
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions supervisely_integration/serve/src/main.py
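
In short: the TTL cache that stores per-image model state, formerly self.cache, becomes self.model_cache. Read together with the later hunks, the rename appears to free the self.cache name for the image cache provided by the updated Supervisely SDK, whose download helpers the smart-segmentation handler now calls (self.cache.download_image and friends), and the manual self.add_cache_endpoint(server) registration is dropped.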
@@ -114,7 +114,7 @@ def load_on_device(
         # variable for storing the image id from the previous inference iteration
         self.previous_image_id = None
         # dict for storing model variables to avoid unnecessary calculations
-        self.cache = Cache(maxsize=100, ttl=5 * 60)
+        self.model_cache = Cache(maxsize=100, ttl=5 * 60)
         # set variables for smart tool mode
         self._inference_image_lock = threading.Lock()
         self._inference_image_cache = Cache(maxsize=100, ttl=60)
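
As an aside, a minimal sketch of how this TTL cache behaves, assuming Cache is cacheout.Cache (an assumption, but its maxsize/ttl constructor and its set/get/membership API match the usage in this file):

from cacheout import Cache  # assumed source of the Cache class used above

model_cache = Cache(maxsize=100, ttl=5 * 60)  # entries expire after 5 minutes

image_id = 42  # hypothetical key; main.py keys on settings["input_image_id"]
if image_id not in model_cache:
    # the expensive computation runs only on a cache miss
    model_cache.set(image_id, {"features": "placeholder"})
cached_data = model_cache.get(image_id)  # returns None once the entry expires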
@@ -139,9 +139,9 @@ def model_meta(self):
 
     def set_image_data(self, input_image, settings):
         if settings["input_image_id"] != self.previous_image_id:
-            if settings["input_image_id"] not in self.cache:
+            if settings["input_image_id"] not in self.model_cache:
                 self.predictor.set_image(input_image)
-                self.cache.set(
+                self.model_cache.set(
                     settings["input_image_id"],
                     {
                         "features": self.predictor.features,
@@ -150,7 +150,7 @@ def set_image_data(self, input_image, settings):
                     },
                 )
             else:
-                cached_data = self.cache.get(settings["input_image_id"])
+                cached_data = self.model_cache.get(settings["input_image_id"])
                 self.predictor.features = cached_data["features"]
                 self.predictor.input_size = cached_data["input_size"]
                 self.predictor.original_size = cached_data["original_size"]
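
Restoring features, input_size, and original_size is what makes a cache hit cheap: in segment-anything's SamPredictor these three attributes are essentially the state that set_image computes with the heavy image encoder, so reassigning them lets subsequent predict calls run without re-encoding the image.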
@@ -287,15 +287,15 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
         self.set_image_data(input_image, settings)
         # get predicted masks
         if (
-            settings["input_image_id"] in self.cache
+            settings["input_image_id"] in self.model_cache
             and (
-                self.cache.get(settings["input_image_id"]).get("previous_bbox")
+                self.model_cache.get(settings["input_image_id"]).get("previous_bbox")
                 == bbox_coordinates
             ).all()
             and self.previous_image_id == settings["input_image_id"]
         ):
             # get mask from previous prediction and use it as an input for new prediction
-            mask_input = self.cache.get(settings["input_image_id"])["mask_input"]
+            mask_input = self.model_cache.get(settings["input_image_id"])["mask_input"]
             masks, scores, logits = self.predictor.predict(
                 point_coords=point_coordinates,
                 point_labels=point_labels,
@@ -311,12 +311,12 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
                 multimask_output=False,
             )
             # save bbox coordinates and mask to cache
-            if settings["input_image_id"] in self.cache:
+            if settings["input_image_id"] in self.model_cache:
                 input_image_id = settings["input_image_id"]
-                cached_data = self.cache.get(input_image_id)
+                cached_data = self.model_cache.get(input_image_id)
                 cached_data["previous_bbox"] = bbox_coordinates
                 cached_data["mask_input"] = logits[0]
-                self.cache.set(input_image_id, cached_data)
+                self.model_cache.set(input_image_id, cached_data)
         # update previous_image_id variable
         self.previous_image_id = settings["input_image_id"]
         mask = masks[0]
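
For orientation, a hedged sketch of the refinement loop this enables, following segment-anything's public SamPredictor.predict contract (the exact call in main.py is truncated in this diff, so the names here are illustrative):

def refine(predictor, point_coords, point_labels, prev_logits=None):
    # prev_logits is the (1, 256, 256) low-res output of an earlier call,
    # fed back through mask_input to refine the previous prediction
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,    # (N, 2) pixel coordinates
        point_labels=point_labels,    # (N,) 1 = foreground, 0 = background
        mask_input=prev_logits,       # None on the first iteration
        multimask_output=False,       # single mask, as in the diff above
    )
    return masks[0], logits[0:1]      # mask plus logits to keep for the next call

Caching logits[0] per image id, as the hunk above does, is what lets a later request with the same bounding box reuse the previous mask as a starting point instead of predicting from scratch.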
@@ -326,7 +326,6 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
     def serve(self):
         super().serve()
         server = self._app.get_server()
-        self.add_cache_endpoint(server)
 
         @server.post("/smart_segmentation")
         def smart_segmentation(response: Response, request: Request):
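
The dropped self.add_cache_endpoint(server) call, read together with the next hunk, suggests (though the diff alone does not prove it) that after the SDK update the inference image cache and its endpoint are managed by the base class rather than wired up by hand.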
@@ -387,9 +386,9 @@ def smart_segmentation(response: Response, request: Request):
                     smtool_state,
                     api,
                     app_dir,
-                    cache_load_img=self.download_image,
-                    cache_load_frame=self.download_frame,
-                    cache_load_img_hash=self.download_image_by_hash,
+                    cache_load_img=self.cache.download_image,
+                    cache_load_frame=self.cache.download_frame,
+                    cache_load_img_hash=self.cache.download_image_by_hash,
                 )
                 self._inference_image_cache.set(hash_str, image_np)
             else:
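
Note the direction of this change: the download helpers previously looked up on the model itself (self.download_image, self.download_frame, self.download_image_by_hash) are now taken from self.cache, consistent with self.cache now naming the SDK-provided image cache while the model's own memoization lives in self.model_cache.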
@@ -470,7 +469,7 @@ def smart_segmentation(response: Response, request: Request):
         def is_online(response: Response, request: Request):
             response = {"is_online": True}
             return response
-
+
         @server.post("/smart_segmentation_batched")
         def smart_segmentation_batched(response: Response, request: Request):
             response_batch = {}
@@ -487,7 +486,6 @@ def smart_segmentation_batched(response: Response, request: Request):
             return response_batch
 
 
-
 model = SegmentAnythingHQModel(
     use_gui=True,
     custom_inference_settings="./supervisely_integration/serve/custom_settings.yaml",