diff --git a/plugins/io/__init__.py b/plugins/io/__init__.py index 35c83a43..31888ae5 100644 --- a/plugins/io/__init__.py +++ b/plugins/io/__init__.py @@ -2117,11 +2117,16 @@ def _is_valid_csv_field(field): return isinstance(field, fof._PRIMITIVE_FIELDS) -def _get_fields_with_type(view, type): +def _get_fields_with_type(view, type, media_type="image"): + get_field_schema = ( + view.get_frame_field_schema + if media_type == fom.VIDEO + else view.get_field_schema + ) if issubclass(type, fo.Field): - return view.get_field_schema(ftype=type).keys() + return get_field_schema(ftype=type).keys() - return view.get_field_schema(embedded_doc_type=type).keys() + return get_field_schema(embedded_doc_type=type).keys() def _get_export_types(view, export_type, allow_coercion=False): @@ -2523,6 +2528,8 @@ def execute(self, ctx): target = ctx.params.get("target", None) output_dir = _parse_path(ctx, "output_dir") label_fields = ctx.params.get("label_fields", None) + if ctx.dataset.media_type == fom.VIDEO: + label_fields = [f"frames.{field}" for field in label_fields] overwrite = ctx.params.get("overwrite", False) target_view = _get_target_view(ctx, target) @@ -2573,7 +2580,10 @@ def _draw_labels_inputs(ctx, inputs): target_view = _get_target_view(ctx, target) label_field_choices = types.Dropdown(multiple=True) - for field in _get_fields_with_type(target_view, fo.Label): + label_fields = _get_fields_with_type( + target_view, fo.Label, media_type=ctx.dataset.media_type + ) + for field in label_fields: label_field_choices.add_choice(field, label=field) inputs.list(