diff --git a/ear/core/importance.py b/ear/core/importance.py index 1d90fb8..dda315b 100644 --- a/ear/core/importance.py +++ b/ear/core/importance.py @@ -18,6 +18,7 @@ def filter_by_importance(rendering_items, Yields: RenderingItem """ f = mute_audioBlockFormat_by_importance(rendering_items, threshold) + f = mute_hoa_channels_by_importance(f, threshold) f = filter_audioObject_by_importance(f, threshold) f = filter_audioPackFormat_by_importance(f, threshold) return f @@ -115,3 +116,28 @@ def mute_unimportant_block(type_metadata): item.metadata_source, mute_unimportant_block ), ) + + +def mute_hoa_channels_by_importance(rendering_items, threshold): + def mute_unimportant_channels(type_metadata): + if min(type_metadata.importances) < threshold: + new_gains = [ + 0.0 if importance < threshold else gain + for (gain, importance) in zip( + type_metadata.gains, type_metadata.importances + ) + ] + return evolve(type_metadata, gains=new_gains) + else: + return type_metadata + + for item in rendering_items: + if isinstance(item, HOARenderingItem): + yield evolve( + item, + metadata_source=MetadataSourceMap( + item.metadata_source, mute_unimportant_channels + ), + ) + else: + yield item diff --git a/ear/core/test/test_importance.py b/ear/core/test/test_importance.py index ec0bbb1..198adf2 100644 --- a/ear/core/test/test_importance.py +++ b/ear/core/test/test_importance.py @@ -4,6 +4,7 @@ DirectSpeakersRenderingItem, DirectSpeakersTypeMetadata, HOARenderingItem, + HOATypeMetadata, DirectTrackSpec, ImportanceData, MetadataSourceIter, @@ -20,6 +21,7 @@ filter_audioObject_by_importance, filter_audioPackFormat_by_importance, ) +from attrs import evolve from fractions import Fraction import pytest @@ -185,3 +187,38 @@ def test_importance_filter_blocks_single_channel(make_type_metadata, make_render rendering_items_out = filter_by_importance(rendering_items, 6) [rendering_item_out] = rendering_items_out assert get_blocks(rendering_item_out.metadata_source) == expected + + +@pytest.mark.parametrize( + "gains", + [ + [1.0, 1.0, 1.0, 1.0], + [0.5, 0.25, 0.25, 0.25], + ], +) +def test_importance_filter_hoa(gains): + type_metadatas = [ + HOATypeMetadata( + orders=[0, 1, 1, 1], + degrees=[0, -1, 0, 1], + importances=[6, 5, 5, 5], + normalization="SN3D", + gains=gains, + ), + ] + expected = [ + evolve( + type_metadatas[0], + gains=[gains[0], 0.0, 0.0, 0.0], + ), + ] + rendering_items = [ + HOARenderingItem( + track_specs=[DirectTrackSpec(i) for i in range(4)], + metadata_source=MetadataSourceIter(type_metadatas), + ), + ] + + rendering_items_out = filter_by_importance(rendering_items, 6) + [rendering_item_out] = rendering_items_out + assert get_blocks(rendering_item_out.metadata_source) == expected