diff --git a/src/frdc/load/label_studio.py b/src/frdc/load/label_studio.py index 5486e92..8fd5ec2 100644 --- a/src/frdc/load/label_studio.py +++ b/src/frdc/load/label_studio.py @@ -16,8 +16,6 @@ def get_bounds_and_labels(self) -> tuple[list[tuple[int, int]], list[str]]: bounds = [] labels = [] - # for ann_ix, ann in enumerate(self["annotations"]): - ann = self["annotations"][0] results = ann["result"] for r_ix, r in enumerate(results): @@ -60,24 +58,15 @@ def get_task( project_id: int = 1, ): proj = LABEL_STUDIO_CLIENT.get_project(project_id) - # Get the task that has the file name - filter = Filters.create( - Filters.AND, - [ - Filters.item( - # The GS path is in the image column, so we can just filter on that - Column.data("image"), - Operator.CONTAINS, - Type.String, - Path(file_name).as_posix(), - ) - ], - ) - tasks = proj.get_tasks(filter) - - if len(tasks) > 1: + task_ids = [ + task["id"] + for task in proj.get_tasks() + if file_name.as_posix() in task["storage_filename"] + ] + + if len(task_ids) > 1: warn(f"More than 1 task found for {file_name}, using the first one") - elif len(tasks) == 0: + elif len(task_ids) == 0: raise ValueError(f"No task found for {file_name}") - return Task(tasks[0]) + return Task(proj.get_task(task_ids[0]))