Skip to content

Commit

Permalink
Add raise_if_computes for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Aug 8, 2024
1 parent 41692c8 commit b0dded3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cubed/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def gensym(name="array"):
return f"{name}-{sym_counter:03}"


# see cubed.testing
compute_should_raise = False


# type representing either a CoreArray or a public-facing Array
T_ChunkedArray = TypeVar("T_ChunkedArray", bound="CoreArray")

Expand Down Expand Up @@ -269,6 +273,8 @@ def compute(
If True, intermediate arrays that have already been computed won't be
recomputed. Default is False.
"""
if compute_should_raise:
raise RuntimeError("'compute' was called")
spec = check_array_specs(arrays) # guarantees all arrays have same spec
plan = arrays_to_plan(*arrays)
if executor is None:
Expand Down
13 changes: 13 additions & 0 deletions cubed/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from contextlib import contextmanager

from cubed.core import array


@contextmanager
def raise_if_computes():
"""Returns a context manager for testing that ``compute`` is not called."""
array.compute_should_raise = True
try:
yield
finally:
array.compute_should_raise = False
21 changes: 21 additions & 0 deletions cubed/tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import cubed.array_api as xp
from cubed.testing import raise_if_computes


def test_raise_if_computes():
# shouldn't raise since compute has not been called
with raise_if_computes():
a = xp.ones((3, 3), chunks=(2, 2))
b = xp.negative(a)

# should raise since compute is called
with pytest.raises(RuntimeError):
with raise_if_computes():
b.compute()

# shouldn't raise since we are outside the context manager
assert_array_equal(b.compute(), -np.ones((3, 3)))

0 comments on commit b0dded3

Please sign in to comment.