Skip to content

Commit

Permalink
make it work for nd hists
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp authored and henryiii committed Aug 22, 2024
1 parent 51647e3 commit 95ddb0f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
11 changes: 6 additions & 5 deletions src/boost_histogram/_internal/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
slices: list[_core.algorithm.reduce_command] = []
pick_each: dict[int, int] = {}
pick_set: dict[int, list[int]] = {}
reduced: CppHistogram | None = None

# Compute needed slices and projections
for i, ind in enumerate(indexes):
Expand Down Expand Up @@ -890,7 +891,8 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
)
# rebinning with groups
elif len(groups) != 0:
reduced = self._hist
if not reduced:
reduced = self._hist
axes = [reduced.axis(x) for x in range(reduced.rank())]
reduced_view = reduced.view(flow=True)
new_axes_indices = [axes[i].edges[0]]
Expand All @@ -913,7 +915,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
j = 1
for new_j, group in enumerate(groups):
for _ in range(group):
pos = [slice] * (i)
pos = [slice(None)] * (i)
new_view[(*pos, new_j + 1, ...)] += reduced_view[ # type: ignore[arg-type]
(*pos, j, ...) # type: ignore[arg-type]
]
Expand All @@ -922,10 +924,9 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
reduced = new_reduced

# Will be updated below
if slices or pick_set or pick_each or integrations:
if (slices or pick_set or pick_each or integrations) and not reduced:
reduced = self._hist
elif len(groups) == 0:
logger.debug("Reduce actions are all empty, just making a copy")
elif not reduced:
reduced = copy.copy(self._hist)

if pick_each:
Expand Down
52 changes: 51 additions & 1 deletion tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,6 @@ def test_rebin_1d():

hs = h[{0: slice(None, None, bh.tag.Rebinner(4))}]
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])
print("Here")

hs = h[{0: bh.tag.Rebinner(4)}]
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])
Expand Down Expand Up @@ -664,8 +663,59 @@ def test_rebin_nd():
assert h[{1: s[:: bh.rebin(2)]}].axes.size == (20, 15, 40)
assert h[{2: s[:: bh.rebin(2)]}].axes.size == (20, 30, 20)

assert h[{0: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (3, 30, 40)
assert h[{1: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (20, 3, 40)
assert h[{2: s[:: bh.rebin(groups=[1, 2, 3])]}].axes.size == (20, 30, 3)
assert np.all(
np.isclose(
h[{0: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[0].edges, [1.0, 1.1, 1.3, 1.6]
)
)
assert np.all(
np.isclose(
h[{1: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[1].edges,
[1.0, 1.06666667, 1.2, 1.4],
)
)
assert np.all(
np.isclose(
h[{2: s[:: bh.rebin(groups=[1, 2, 3])]}].axes[2].edges,
[1.0, 1.05, 1.15, 1.3],
)
)

assert h[{0: s[:: bh.rebin(2)], 2: s[:: bh.rebin(2)]}].axes.size == (10, 30, 20)

assert h[
{0: s[:: bh.rebin(groups=[1, 2, 3])], 2: s[:: bh.rebin(groups=[1, 2, 3])]}
].axes.size == (3, 30, 3)
assert np.all(
np.isclose(
h[
{
0: s[:: bh.rebin(groups=[1, 2, 3])],
2: s[:: bh.rebin(groups=[1, 2, 3])],
}
]
.axes[0]
.edges,
[1.0, 1.1, 1.3, 1.6],
)
)
assert np.all(
np.isclose(
h[
{
0: s[:: bh.rebin(groups=[1, 2, 3])],
2: s[:: bh.rebin(groups=[1, 2, 3])],
}
]
.axes[2]
.edges,
[1.0, 1.05, 1.15, 1.3],
)
)

assert h[{1: s[:: bh.sum]}].axes.size == (20, 40)
assert h[{1: bh.sum}].axes.size == (20, 40)

Expand Down

0 comments on commit 95ddb0f

Please sign in to comment.