Skip to content

Commit

Permalink
[Lang] User-friendly exception when copying between ti.field (#3442)
Browse files Browse the repository at this point in the history
Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
ntlm1686 and taichi-gardener authored Dec 20, 2021
1 parent 3818ba3 commit 420e6b6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/taichi/lang/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,11 @@ def copy_from(self, other):
Args:
other (Field): The source field.
"""
assert isinstance(other, Field)
assert len(self.shape) == len(other.shape)
if not isinstance(other, Field):
raise TypeError('Cannot copy from a non-field object')
if self.shape != other.shape:
raise ValueError(f"ti.field shape {self.shape} does not match"
f" the source field shape {other.shape}")
taichi.lang.meta.tensor_to_tensor(self, other)

@python_scope
Expand Down
32 changes: 32 additions & 0 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,35 @@ def test_field_name():
for i in range(10):
d.append(ti.field(dtype=ti.f32, shape=(2, 3), name=f'd{i}'))
assert d[i].name == f'd{i}'


@ti.test()
@pytest.mark.parametrize('shape', field_shapes)
@pytest.mark.parametrize('dtype', [ti.i32, ti.f32])
def test_field_copy_from(shape, dtype):
x = ti.field(dtype=ti.f32, shape=shape)
other = ti.field(dtype=dtype, shape=shape)
other.fill(1)
x.copy_from(other)
convert = lambda arr: arr[0] if len(arr) == 1 else arr
assert (convert(x.shape) == shape)
assert (x.dtype == ti.f32)
assert ((x.to_numpy() == 1).all())


@ti.test()
def test_field_copy_from_with_mismatch_shape():
x = ti.field(dtype=ti.f32, shape=(2, 3))
for other_shape in [(2, ), (2, 2), (2, 3, 4)]:
other = ti.field(dtype=ti.f16, shape=other_shape)
with pytest.raises(ValueError):
x.copy_from(other)


@ti.test()
def test_field_copy_from_with_non_filed_object():
import numpy as np
x = ti.field(dtype=ti.f32, shape=(2, 3))
other = np.zeros((2, 3))
with pytest.raises(TypeError):
x.copy_from(other)

0 comments on commit 420e6b6

Please sign in to comment.