Skip to content

Commit

Permalink
Nicely error out when a data-dependent numpy.full variant is used.
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Oct 29, 2024
1 parent 3cadf6f commit 7b2e2bb
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 73 deletions.
14 changes: 10 additions & 4 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4630,10 +4630,16 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
self._add_state('call_%d' % node.lineno)
self.last_block.set_default_lineinfo(self.current_lineinfo)

if found_ufunc:
result = func(self, node, self.sdfg, self.last_block, ufunc_name, args, keywords)
else:
result = func(self, self.sdfg, self.last_block, *args, **keywords)
try:
if found_ufunc:
result = func(self, node, self.sdfg, self.last_block, ufunc_name, args, keywords)
else:
result = func(self, self.sdfg, self.last_block, *args, **keywords)
except DaceSyntaxError as ex:
# Attach source information to exception
if ex.node is None:
ex.node = node
raise

self.last_block.set_default_lineinfo(None)

Expand Down
108 changes: 43 additions & 65 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,29 +324,28 @@ def _numpy_full(pv: ProgramVisitor,
dtype = dtype or vtype

# Handle one-dimensional inputs
try:
iter(shape)
except TypeError:
if isinstance(shape, (Number, str)) or symbolic.issymbolic(shape):
shape = [shape]

if any(isinstance(s, str) for s in shape):
raise DaceSyntaxError(
pv, None, f'Data-dependent shape {shape} is currently not allowed. Only constants '
'and symbolic values can be used.')

name, _ = sdfg.add_temp_transient(shape, dtype)

if is_data:
state.add_mapped_tasklet(
'_numpy_full_', {
"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)
},
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)},
dict(__inp=dace.Memlet(data=fill_value, subset='0')),
"__out = __inp",
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
else:
state.add_mapped_tasklet(
'_numpy_full_', {
"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)
}, {},
'_numpy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)}, {},
"__out = {}".format(fill_value),
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
Expand Down Expand Up @@ -466,10 +465,8 @@ def _numpy_flip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, axis
inpidx = ','.join([f'__i{i}' for i in range(ndim)])
outidx = ','.join([f'{s} - __i{i} - 1' if a else f'__i{i}' for i, (a, s) in enumerate(zip(axis, desc.shape))])
state.add_mapped_tasklet(name="_numpy_flip_",
map_ranges={
f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)
},
map_ranges={f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)},
inputs={'__inp': Memlet(f'{arr}[{inpidx}]')},
code='__out = __inp',
outputs={'__out': Memlet(f'{arr_copy}[{outidx}]')},
Expand Down Expand Up @@ -539,10 +536,8 @@ def _numpy_rot90(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, k=1

outidx = ','.join(out_indices)
state.add_mapped_tasklet(name="_rot90_",
map_ranges={
f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)
},
map_ranges={f'__i{i}': f'0:{s}:1'
for i, s in enumerate(desc.shape)},
inputs={'__inp': Memlet(f'{arr}[{inpidx}]')},
code='__out = __inp',
outputs={'__out': Memlet(f'{arr_copy}[{outidx}]')},
Expand Down Expand Up @@ -651,7 +646,8 @@ def _elementwise(pv: 'ProgramVisitor',
else:
state.add_mapped_tasklet(
name="_elementwise_",
map_ranges={f'__i{dim}': f'0:{N}' for dim, N in enumerate(inparr.shape)},
map_ranges={f'__i{dim}': f'0:{N}'
for dim, N in enumerate(inparr.shape)},
inputs={'__inp': Memlet.simple(in_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))},
code=code,
outputs={'__out': Memlet.simple(out_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))},
Expand Down Expand Up @@ -701,10 +697,8 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype:
else:
state.add_mapped_tasklet(
name=func,
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)
},
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)},
inputs={'__inp': Memlet.simple(inpname, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
code='__out = {f}(__inp)'.format(f=func),
outputs={'__out': Memlet.simple(outname, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
Expand Down Expand Up @@ -1053,27 +1047,22 @@ def _argminmax(pv: ProgramVisitor,
code = "__init = _val_and_idx(val={}, idx=-1)".format(
dtypes.min_value(a_arr.dtype) if func == 'max' else dtypes.max_value(a_arr.dtype))

nest.add_state().add_mapped_tasklet(name="_arg{}_convert_".format(func),
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis
},
inputs={},
code=code,
outputs={
'__init':
Memlet.simple(
reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
},
external_edges=True)
nest.add_state().add_mapped_tasklet(
name="_arg{}_convert_".format(func),
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
inputs={},
code=code,
outputs={
'__init': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
},
external_edges=True)

nest.add_state().add_mapped_tasklet(
name="_arg{}_reduce_".format(func),
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape)
},
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape)},
inputs={'__in': Memlet.simple(a, ','.join('__i%d' % i for i in range(len(a_arr.shape))))},
code="__out = _val_and_idx(idx={}, val=__in)".format("__i%d" % axis),
outputs={
Expand All @@ -1093,10 +1082,8 @@ def _argminmax(pv: ProgramVisitor,

nest.add_state().add_mapped_tasklet(
name="_arg{}_extract_".format(func),
map_ranges={
'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis
},
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(a_arr.shape) if i != axis},
inputs={
'__in': Memlet.simple(reduced_structs,
','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis))
Expand Down Expand Up @@ -1219,10 +1206,9 @@ def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
opcode = 'not'

name, _ = sdfg.add_temp_transient(arr1.shape, restype, arr1.storage)
state.add_mapped_tasklet("_%s_" % opname, {
'__i%d' % i: '0:%s' % s
for i, s in enumerate(arr1.shape)
}, {'__in1': Memlet.simple(op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))},
state.add_mapped_tasklet("_%s_" % opname, {'__i%d' % i: '0:%s' % s
for i, s in enumerate(arr1.shape)},
{'__in1': Memlet.simple(op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))},
'__out = %s __in1' % opcode,
{'__out': Memlet.simple(name, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))},
external_edges=True)
Expand Down Expand Up @@ -4323,10 +4309,8 @@ def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, va
shape = sdfg.arrays[arr].shape
state.add_mapped_tasklet(
'_numpy_fill_',
map_ranges={
f"__i{dim}": f"0:{s}"
for dim, s in enumerate(shape)
},
map_ranges={f"__i{dim}": f"0:{s}"
for dim, s in enumerate(shape)},
inputs=inputs,
code=f"__out = {body}",
outputs={'__out': dace.Memlet.simple(arr, ",".join([f"__i{dim}" for dim in range(len(shape))]))},
Expand Down Expand Up @@ -4550,12 +4534,14 @@ def _ndarray_astype(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str,
dtype = dtypes.typeclass(dtype)
return _datatype_converter(sdfg, state, arr, dtype)[0]


@oprepo.replaces_operator('Array', 'MatMult', otherclass='StorageType')
def _cast_storage(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, stype: dace.StorageType) -> str:
desc = sdfg.arrays[arr]
desc.storage = stype
return arr


# Replacements that need ufuncs ###############################################
# TODO: Fix by separating to different modules and importing

Expand Down Expand Up @@ -4759,13 +4745,7 @@ def _tensordot(pv: 'ProgramVisitor',

@oprepo.replaces("cupy._core.core.ndarray")
@oprepo.replaces("cupy.ndarray")
def _define_cupy_local(
pv: "ProgramVisitor",
sdfg: SDFG,
state: SDFGState,
shape: Shape,
dtype: typeclass
):
def _define_cupy_local(pv: "ProgramVisitor", sdfg: SDFG, state: SDFGState, shape: Shape, dtype: typeclass):
"""Defines a local array in a DaCe program."""
if not isinstance(shape, (list, tuple)):
shape = [shape]
Expand Down Expand Up @@ -4793,10 +4773,8 @@ def _cupy_full(pv: ProgramVisitor,
name, _ = sdfg.add_temp_transient(shape, dtype, storage=dtypes.StorageType.GPU_Global)

state.add_mapped_tasklet(
'_cupy_full_', {
"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)
}, {},
'_cupy_full_', {"__i{}".format(i): "0: {}".format(s)
for i, s in enumerate(shape)}, {},
"__out = {}".format(fill_value),
dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))),
external_edges=True)
Expand Down
8 changes: 4 additions & 4 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,13 +761,13 @@ def add_symbol(self, name, stype, find_new_name: bool = False):
if name in self.symbols:
raise FileExistsError(f'Symbol "{name}" already exists in SDFG')
if name in self.arrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a data descriptor.')
raise FileExistsError(f'Cannot create symbol "{name}", the name is used by a data descriptor.')
if name in self._subarrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a subarray.')
raise FileExistsError(f'Cannot create symbol "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a RedistrArray.')
raise FileExistsError(f'Cannot create symbol "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a ProcessGrid.')
raise FileExistsError(f'Cannot create symbol "{name}", the name is used by a ProcessGrid.')
if not isinstance(stype, dtypes.typeclass):
stype = dtypes.dtype_to_typeclass(stype)
self.symbols[name] = stype
Expand Down
28 changes: 28 additions & 0 deletions tests/numpy/array_creation_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace.frontend.python.common import DaceSyntaxError
import numpy as np
from common import compare_numpy_output
import pytest

# M = dace.symbol('M')
# N = dace.symbol('N')
Expand Down Expand Up @@ -218,6 +220,30 @@ def zeros_symbolic_size():
assert (out.dtype == np.uint32)


def test_ones_scalar_size_scalar():

@dace.program
def ones_scalar_size(k: dace.int32):
a = np.ones(k, dtype=np.uint32)
return np.sum(a)

with pytest.raises(DaceSyntaxError):
out = ones_scalar_size(20)
assert out == 20


def test_ones_scalar_size():

@dace.program
def ones_scalar_size(k: dace.int32):
a = np.ones((k, k), dtype=np.uint32)
return np.sum(a)

with pytest.raises(DaceSyntaxError):
out = ones_scalar_size(20)
assert out == 20 * 20


if __name__ == "__main__":
test_empty()
test_empty_like1()
Expand Down Expand Up @@ -246,3 +272,5 @@ def zeros_symbolic_size():
test_strides_2()
test_strides_3()
test_zeros_symbolic_size_scalar()
test_ones_scalar_size_scalar()
test_ones_scalar_size()

0 comments on commit 7b2e2bb

Please sign in to comment.