diff --git a/nixio/block.py b/nixio/block.py index fcbe715b..da4d8097 100644 --- a/nixio/block.py +++ b/nixio/block.py @@ -84,8 +84,6 @@ def create_multi_tag(self, name="", type_="", positions=0, util.check_entity_name_and_type(name, type_) util.check_entity_input(positions) - if not isinstance(positions, DataArray): - raise TypeError("DataArray expected for 'positions'") multi_tags = self._h5group.open_group("multi_tags") if name in multi_tags: raise exceptions.DuplicateName("create_multi_tag") diff --git a/nixio/multi_tag.py b/nixio/multi_tag.py index b5e0ed0a..89940bf9 100644 --- a/nixio/multi_tag.py +++ b/nixio/multi_tag.py @@ -16,7 +16,7 @@ from .data_view import DataView from .link_type import LinkType from .exceptions import (OutOfBounds, IncompatibleDimensions, - UninitializedEntity) + UninitializedEntity, DuplicateName) from .section import Section @@ -52,7 +52,14 @@ def positions(self, da): raise TypeError("MultiTag.positions cannot be None.") if "positions" in self._h5group: del self._h5group["positions"] - self._h5group.create_link(da, "positions") + pos = da + if not isinstance(da, DataArray): + blk = self._parent + name = "f{self.name}-positions" + if name in blk.data_arrays: + del blk.data_arrays[name] + pos = blk.create_data_array(name, "multi_tag_positions", data=da) + self._h5group.create_link(pos, "positions") if self._parent._parent.time_auto_update: self.force_updated_at() @@ -74,7 +81,14 @@ def extents(self, da): if da is None: del self._h5group["extents"] else: - self._h5group.create_link(da, "extents") + ext = da + if not isinstance(da, DataArray): + blk = self._parent + name = "f{self.name}-extents" + if name in blk.data_arrays: + del blk.data_arrays[name] + ext = blk.create_data_array(name, "multi_tag_positions", data=da) + self._h5group.create_link(ext, "extents") if self._parent._parent.time_auto_update: self.force_updated_at() diff --git a/nixio/test/test_multi_tag.py b/nixio/test/test_multi_tag.py index 7d045079..972f89e2 100644 --- a/nixio/test/test_multi_tag.py +++ b/nixio/test/test_multi_tag.py @@ -102,6 +102,19 @@ def tearDown(self): self.file.close() self.tmpdir.cleanup() + def test_multi_tag_new_constructor(self): + pos = np.random.random((2,3)) + ext = np.random.random((2,3)) + mt = self.block.create_multi_tag("conv_test", "test", pos) + mt.extents = ext + np.testing.assert_almost_equal(pos, mt.positions[:]) + np.testing.assert_almost_equal(ext, mt.extents[:]) + # try reset positions and ext + pos_new = np.random.random((2,3)) + ext_new = np.random.random((2,3)) + mt.positions = pos_new + mt.extents = ext_new + def test_multi_tag_flex(self): pos1d = self.block.create_data_array("pos1", "pos", data=[[0], [1]]) pos1d1d = self.block.create_data_array("pos1d1d", "pos", data=[0, 1])