Skip to content

Commit

Permalink
Use Dask's 'broadcast trick' to save memory for single-valued arrays (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jan 26, 2024
1 parent 2e28072 commit 7314d08
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
10 changes: 7 additions & 3 deletions cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
from cubed.utils import memory_repr
from cubed.utils import broadcast_trick, memory_repr


class VirtualEmptyArray:
Expand All @@ -33,7 +33,8 @@ def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
return nxp.empty(indexer.shape, dtype=self.dtype)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)

@property
def oindex(self):
Expand Down Expand Up @@ -68,7 +69,10 @@ def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
return nxp.full(indexer.shape, fill_value=self.fill_value, dtype=self.dtype)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.full)(
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
)

@property
def oindex(self):
Expand Down
16 changes: 16 additions & 0 deletions cubed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from cubed.backend_array_api import namespace as nxp
from cubed.utils import (
block_id_to_offset,
broadcast_trick,
chunk_memory,
extract_stack_summaries,
join_path,
Expand Down Expand Up @@ -153,3 +156,16 @@ def test_map_nested_iterators():
assert count == 2
assert list(out1) == [4, 5]
assert count == 4


def test_broadcast_trick():
a = nxp.ones((10, 10), dtype=nxp.int8)
b = broadcast_trick(nxp.ones)((10, 10), dtype=nxp.int8)

assert_array_equal(a, b)
assert a.nbytes == 100
assert b.base.nbytes == 1

a = nxp.ones((), dtype=nxp.int8)
b = broadcast_trick(nxp.ones)((), dtype=nxp.int8)
assert_array_equal(a, b)
21 changes: 21 additions & 0 deletions cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import traceback
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from itertools import islice
from math import prod
from operator import add
Expand All @@ -17,6 +18,7 @@
import numpy as np
import tlz as toolz

from cubed.backend_array_api import namespace as nxp
from cubed.types import T_DType, T_RectangularChunks, T_RegularChunks
from cubed.vendor.dask.array.core import _check_regular_chunks

Expand Down Expand Up @@ -289,3 +291,22 @@ def map_nested(func, seq):
return map(lambda item: map_nested(func, item), seq)
else:
return func(seq)


def _broadcast_trick_inner(func, shape, *args, **kwargs):
# cupy-specific hack. numpy is happy with hardcoded shape=().
null_shape = () if shape == () else 1

return nxp.broadcast_to(func(*args, shape=null_shape, **kwargs), shape)


def broadcast_trick(func):
"""Apply Dask's broadcast trick to array API functions that produce arrays
containing a single value to save space in memory.
Note that this should only be used for arrays that never mutated.
"""
inner = partial(_broadcast_trick_inner, func)
inner.__doc__ = func.__doc__
inner.__name__ = func.__name__
return inner

0 comments on commit 7314d08

Please sign in to comment.