Skip to content

Commit

Permalink
Allow supplying np_arrays or list for multi_tags positions
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchekc committed Mar 6, 2020
1 parent b19e5ad commit 41ceed4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
2 changes: 0 additions & 2 deletions nixio/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ def create_multi_tag(self, name="", type_="", positions=0,

util.check_entity_name_and_type(name, type_)
util.check_entity_input(positions)
if not isinstance(positions, DataArray):
raise TypeError("DataArray expected for 'positions'")
multi_tags = self._h5group.open_group("multi_tags")
if name in multi_tags:
raise exceptions.DuplicateName("create_multi_tag")
Expand Down
20 changes: 17 additions & 3 deletions nixio/multi_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .data_view import DataView
from .link_type import LinkType
from .exceptions import (OutOfBounds, IncompatibleDimensions,
UninitializedEntity)
UninitializedEntity, DuplicateName)
from .section import Section


Expand Down Expand Up @@ -52,7 +52,14 @@ def positions(self, da):
raise TypeError("MultiTag.positions cannot be None.")
if "positions" in self._h5group:
del self._h5group["positions"]
self._h5group.create_link(da, "positions")
pos = da
if not isinstance(da, DataArray):
blk = self._parent
name = "f{self.name}-positions"
if name in blk.data_arrays:
del blk.data_arrays[name]
pos = blk.create_data_array(name, "multi_tag_positions", data=da)
self._h5group.create_link(pos, "positions")
if self._parent._parent.time_auto_update:
self.force_updated_at()

Expand All @@ -74,7 +81,14 @@ def extents(self, da):
if da is None:
del self._h5group["extents"]
else:
self._h5group.create_link(da, "extents")
ext = da
if not isinstance(da, DataArray):
blk = self._parent
name = "f{self.name}-extents"
if name in blk.data_arrays:
del blk.data_arrays[name]
ext = blk.create_data_array(name, "multi_tag_positions", data=da)
self._h5group.create_link(ext, "extents")
if self._parent._parent.time_auto_update:
self.force_updated_at()

Expand Down
13 changes: 13 additions & 0 deletions nixio/test/test_multi_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ def tearDown(self):
self.file.close()
self.tmpdir.cleanup()

def test_multi_tag_new_constructor(self):
pos = np.random.random((2,3))
ext = np.random.random((2,3))
mt = self.block.create_multi_tag("conv_test", "test", pos)
mt.extents = ext
np.testing.assert_almost_equal(pos, mt.positions[:])
np.testing.assert_almost_equal(ext, mt.extents[:])
# try reset positions and ext
pos_new = np.random.random((2,3))
ext_new = np.random.random((2,3))
mt.positions = pos_new
mt.extents = ext_new

def test_multi_tag_flex(self):
pos1d = self.block.create_data_array("pos1", "pos", data=[[0], [1]])
pos1d1d = self.block.create_data_array("pos1d1d", "pos", data=[0, 1])
Expand Down

0 comments on commit 41ceed4

Please sign in to comment.