diff --git a/src/compute/layers/processing/SplitDataLayer.py b/src/compute/layers/processing/SplitDataLayer.py index f91bae92..e2ad8a57 100644 --- a/src/compute/layers/processing/SplitDataLayer.py +++ b/src/compute/layers/processing/SplitDataLayer.py @@ -15,13 +15,14 @@ class SplitDataLayer(Layer): "properties": { "settings": { "type": "object", - "required": ["split_method", "split_ratio", "split_num"], + "required": ["split_method", "split_ratio", "split_num", "split_parts"], "properties": { "split_method": { "type": "string", }, "split_ratio": {"type": "number"}, "split_num": {"type": "number"}, + "split_parts": {"type": "number"}, }, } }, @@ -34,8 +35,9 @@ def validate(self): split_method = self.settings["split_method"] split_ratio = self.settings["split_ratio"] split_num = self.settings["split_num"] + split_parts = self.settings["split_parts"] - allowed_methods = ["percent", "number", "classes", "tags"] + allowed_methods = ["percent", "number", "classes", "tags", "parts"] if split_method not in allowed_methods: raise BadSettingsError(f"Unknown split method selected: {split_method}") @@ -43,10 +45,15 @@ def validate(self): raise BadSettingsError( f"Invalid percentage value: {split_ratio}. Split percentage must be between 1 and 100" ) - if split_num < 1 or split_num > 10000: + if split_num < 1: raise BadSettingsError( - f"Invalid split value: {split_num}. Split value must be between 1 and 10000" + f"Invalid split value: {split_num}. Split value must be 1 or greater" ) + if split_parts < 1: + raise BadSettingsError( + f"Invalid split value: {split_parts}. Split value must be 1 or greater" + ) + super().validate() def requires_item(self): @@ -73,6 +80,11 @@ def _split_by_num() -> List[Tuple[ImageDescriptor, Annotation]]: split_index = int(item_idx / split_num) + (item_idx % split_num > 0) return [(replace_ds_name(f"split_{split_index}"), ann)] + def _split_by_parts() -> List[Tuple[ImageDescriptor, Annotation]]: + split_parts = self.settings["split_parts"] + split_index = item_idx % split_parts + return [(replace_ds_name(f"split_{split_index}"), ann)] + def _split_by_class() -> List[Tuple[ImageDescriptor, Annotation]]: image_labels = ann.labels if len(image_labels) == 0: @@ -106,7 +118,9 @@ def _split_by_tags() -> List[Tuple[ImageDescriptor, Annotation]]: "number": _split_by_num, "classes": _split_by_class, "tags": _split_by_tags, + "parts": _split_by_parts, } + split_method = self.settings["split_method"] func = split_func_map.get(split_method) items = func() diff --git a/src/ui/dtl/actions/other/split_data/layout/split_data_sidebar.py b/src/ui/dtl/actions/other/split_data/layout/split_data_sidebar.py index 7026cd72..9cb50639 100644 --- a/src/ui/dtl/actions/other/split_data/layout/split_data_sidebar.py +++ b/src/ui/dtl/actions/other/split_data/layout/split_data_sidebar.py @@ -21,7 +21,7 @@ def create_sidebar_widgets(): "Select percentage by which to distribute images across datasets", ) - sidebar_number_input = InputNumber(min=1, max=10000, value=50) + sidebar_number_input = InputNumber(min=1, max=None, value=50) sidebar_number_field = Field( sidebar_number_input, "Select number of images", @@ -29,17 +29,32 @@ def create_sidebar_widgets(): ) sidebar_number_field.hide() + sidebar_parts_input = InputNumber(min=1, max=None, value=5) + sidebar_parts_field = Field( + sidebar_parts_input, + "Select number of parts", + "Select number of datasets to split data into. Resulting datasets will have equal number of images", + ) + sidebar_parts_field.hide() + sidebar_items = [ Select.Item("percent", "by percent"), Select.Item("number", "by number"), Select.Item("classes", "by classes"), Select.Item("tags", "by tags"), + Select.Item("parts", "by parts"), ] sidebar_selector = Select(sidebar_items) sidebar_save_button = create_save_btn() sidebar_container = Container( - [sidebar_selector, sidebar_percent_field, sidebar_number_field, sidebar_save_button] + [ + sidebar_selector, + sidebar_percent_field, + sidebar_number_field, + sidebar_parts_field, + sidebar_save_button, + ] ) sidebar_selector_field = Field( sidebar_container, @@ -54,5 +69,7 @@ def create_sidebar_widgets(): sidebar_percent_field, sidebar_number_input, sidebar_number_field, + sidebar_parts_input, + sidebar_parts_field, sidebar_save_button, ) diff --git a/src/ui/dtl/actions/other/split_data/split_data.py b/src/ui/dtl/actions/other/split_data/split_data.py index 29efa1c5..62821b5a 100644 --- a/src/ui/dtl/actions/other/split_data/split_data.py +++ b/src/ui/dtl/actions/other/split_data/split_data.py @@ -26,6 +26,8 @@ def create_new_layer(cls, layer_id: Optional[str] = None): sidebar_percent_field, sidebar_number_input, sidebar_number_field, + sidebar_parts_input, + sidebar_parts_field, sidebar_save_button, ) = create_sidebar_widgets() @@ -43,14 +45,21 @@ def create_new_layer(cls, layer_id: Optional[str] = None): @sidebar_selector.value_changed def selector_cb(value): if value == "percent": - sidebar_percent_field.show() sidebar_number_field.hide() + sidebar_parts_field.hide() + sidebar_percent_field.show() elif value == "number": sidebar_percent_field.hide() + sidebar_parts_field.hide() sidebar_number_field.show() + elif value == "parts": + sidebar_percent_field.hide() + sidebar_number_field.hide() + sidebar_parts_field.show() else: sidebar_percent_field.hide() sidebar_number_field.hide() + sidebar_parts_field.hide() @sidebar_save_button.click def save_cb(): @@ -58,6 +67,7 @@ def save_cb(): "split_method": sidebar_selector.get_value(), "split_ratio": sidebar_percent_slider.get_value(), "split_num": sidebar_number_input.get_value(), + "split_parts": sidebar_parts_input.get_value(), } _set_settings_from_json(saved_settings) @@ -70,11 +80,13 @@ def _set_settings_from_json(settings: dict): method = settings.get("split_method", "percent") ratio = settings.get("split_ratio", 50) num = settings.get("split_num", 50) + parts = settings.get("split_parts", 5) saved_settings = { "split_method": method, "split_ratio": ratio, "split_num": num, + "split_parts": parts, } layout_current_method.set(f"Current method: {method}", "text") @@ -90,6 +102,12 @@ def _set_settings_from_json(settings: dict): f"Split value: {num} items per dataset", "text", ) + elif method == "parts": + layout_current_value.show() + layout_current_value.set( + f"Split value: {parts} parts", + "text", + ) else: layout_current_value.hide()