Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SDK Version #7

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions supervisely_integration/serve/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def load_on_device(
# variable for storing image ids from previous inference iterations
self.previous_image_id = None
# dict for storing model variables to avoid unnecessary calculations
self.cache = Cache(maxsize=100, ttl=5 * 60)
self.model_cache = Cache(maxsize=100, ttl=5 * 60)
# set variables for smart tool mode
self._inference_image_lock = threading.Lock()
self._inference_image_cache = Cache(maxsize=100, ttl=60)
Expand All @@ -139,9 +139,9 @@ def model_meta(self):

def set_image_data(self, input_image, settings):
if settings["input_image_id"] != self.previous_image_id:
if settings["input_image_id"] not in self.cache:
if settings["input_image_id"] not in self.model_cache:
self.predictor.set_image(input_image)
self.cache.set(
self.model_cache.set(
settings["input_image_id"],
{
"features": self.predictor.features,
Expand All @@ -150,7 +150,7 @@ def set_image_data(self, input_image, settings):
},
)
else:
cached_data = self.cache.get(settings["input_image_id"])
cached_data = self.model_cache.get(settings["input_image_id"])
self.predictor.features = cached_data["features"]
self.predictor.input_size = cached_data["input_size"]
self.predictor.original_size = cached_data["original_size"]
Expand Down Expand Up @@ -287,15 +287,15 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
self.set_image_data(input_image, settings)
# get predicted masks
if (
settings["input_image_id"] in self.cache
settings["input_image_id"] in self.model_cache
and (
self.cache.get(settings["input_image_id"]).get("previous_bbox")
self.model_cache.get(settings["input_image_id"]).get("previous_bbox")
== bbox_coordinates
).all()
and self.previous_image_id == settings["input_image_id"]
):
# get mask from previous predicton and use at as an input for new prediction
mask_input = self.cache.get(settings["input_image_id"])["mask_input"]
mask_input = self.model_cache.get(settings["input_image_id"])["mask_input"]
masks, scores, logits = self.predictor.predict(
point_coords=point_coordinates,
point_labels=point_labels,
Expand All @@ -311,12 +311,12 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
multimask_output=False,
)
# save bbox ccordinates and mask to cache
if settings["input_image_id"] in self.cache:
if settings["input_image_id"] in self.model_cache:
input_image_id = settings["input_image_id"]
cached_data = self.cache.get(input_image_id)
cached_data = self.model_cache.get(input_image_id)
cached_data["previous_bbox"] = bbox_coordinates
cached_data["mask_input"] = logits[0]
self.cache.set(input_image_id, cached_data)
self.model_cache.set(input_image_id, cached_data)
# update previous_image_id variable
self.previous_image_id = settings["input_image_id"]
mask = masks[0]
Expand All @@ -326,7 +326,6 @@ def predict(self, image_path: str, settings: Dict[str, Any]) -> List[sly.nn.Pred
def serve(self):
super().serve()
server = self._app.get_server()
self.add_cache_endpoint(server)

@server.post("/smart_segmentation")
def smart_segmentation(response: Response, request: Request):
Expand Down Expand Up @@ -387,9 +386,9 @@ def smart_segmentation(response: Response, request: Request):
smtool_state,
api,
app_dir,
cache_load_img=self.download_image,
cache_load_frame=self.download_frame,
cache_load_img_hash=self.download_image_by_hash,
cache_load_img=self.cache.download_image,
cache_load_frame=self.cache.download_frame,
cache_load_img_hash=self.cache.download_image_by_hash,
)
self._inference_image_cache.set(hash_str, image_np)
else:
Expand Down Expand Up @@ -470,7 +469,7 @@ def smart_segmentation(response: Response, request: Request):
def is_online(response: Response, request: Request):
response = {"is_online": True}
return response

@server.post("/smart_segmentation_batched")
def smart_segmentation_batched(response: Response, request: Request):
response_batch = {}
Expand All @@ -487,7 +486,6 @@ def smart_segmentation_batched(response: Response, request: Request):
return response_batch



model = SegmentAnythingHQModel(
use_gui=True,
custom_inference_settings="./supervisely_integration/serve/custom_settings.yaml",
Expand Down