Skip to content

Commit

Permalink
Numpy fill accepts also variables (#1420)
Browse files Browse the repository at this point in the history
This PR is for addressing issue
[#1389](#1389).

---------

Co-authored-by: acalotoiu <[email protected]>
Co-authored-by: BenWeber42 <[email protected]>
  • Loading branch information
3 people authored Nov 17, 2023
1 parent 43ca982 commit 40ed438
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 25 deletions.
7 changes: 1 addition & 6 deletions dace/frontend/common/op_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ def _get_all_bases(class_or_name: Union[str, Type]) -> List[str]:
"""
if isinstance(class_or_name, str):
return [class_or_name]

classes = [class_or_name.__name__]
for base in class_or_name.__bases__:
classes.extend(_get_all_bases(base))

return deduplicate(classes)
return [base.__name__ for base in class_or_name.__mro__]


class Replacements(object):
Expand Down
25 changes: 14 additions & 11 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,10 @@ class ExtNodeTransformer(ast.NodeTransformer):
bodies in order to discern DaCe statements from others.
"""
def visit_TopLevel(self, node):
clsname = type(node).__name__
if getattr(self, "visit_TopLevel" + clsname, False):
return getattr(self, "visit_TopLevel" + clsname)(node)
visitor_name = "visit_TopLevel" + type(node).__name__
if hasattr(self, visitor_name):
visitor = getattr(self, visitor_name)
return visitor(node)
else:
return self.visit(node)

Expand Down Expand Up @@ -480,21 +481,23 @@ class ExtNodeVisitor(ast.NodeVisitor):
top-level expressions in bodies in order to discern DaCe statements
from others. """
def visit_TopLevel(self, node):
clsname = type(node).__name__
if getattr(self, "visit_TopLevel" + clsname, False):
getattr(self, "visit_TopLevel" + clsname)(node)
visitor_name = "visit_TopLevel" + type(node).__name__
if hasattr(self, visitor_name):
visitor = getattr(self, visitor_name)
return visitor(node)
else:
self.visit(node)
return self.visit(node)

def generic_visit(self, node):
for field, old_value in ast.iter_fields(node):
if isinstance(old_value, list):
for value in old_value:
if isinstance(value, ast.AST):
if (field == 'body' or field == 'orelse'):
clsname = type(value).__name__
if getattr(self, "visit_TopLevel" + clsname, False):
getattr(self, "visit_TopLevel" + clsname)(value)
if field == 'body' or field == 'orelse':
visitor_name = "visit_TopLevel" + type(value).__name__
if hasattr(self, visitor_name):
visitor = getattr(self, visitor_name)
visitor(value)
else:
self.visit(value)
else:
Expand Down
45 changes: 37 additions & 8 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,11 +605,10 @@ def _elementwise(pv: 'ProgramVisitor',
else:
state.add_mapped_tasklet(
name="_elementwise_",
map_ranges={'__i%d' % i: '0:%s' % n
for i, n in enumerate(inparr.shape)},
inputs={'__inp': Memlet.simple(in_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
map_ranges={f'__i{dim}': f'0:{N}' for dim, N in enumerate(inparr.shape)},
inputs={'__inp': Memlet.simple(in_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))},
code=code,
outputs={'__out': Memlet.simple(out_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))},
outputs={'__out': Memlet.simple(out_array, ','.join([f'__i{dim}' for dim in range(len(inparr.shape))]))},
external_edges=True)

return out_array
Expand Down Expand Up @@ -4232,10 +4231,40 @@ def _ndarray_copy(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str) ->
@oprepo.replaces_method('Array', 'fill')
@oprepo.replaces_method('Scalar', 'fill')
@oprepo.replaces_method('View', 'fill')
def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, value: Number) -> str:
if not isinstance(value, (Number, np.bool_)):
raise mem_parser.DaceSyntaxError(pv, None, "Fill value {f} must be a number!".format(f=value))
return _elementwise(pv, sdfg, state, "lambda x: {}".format(value), arr, arr)
def _ndarray_fill(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, value: Union[str, Number,
sp.Expr]) -> str:
assert arr in sdfg.arrays

if isinstance(value, sp.Expr):
raise NotImplementedError(
f"{arr}.fill is not implemented for symbolic expressions ({value}).") # Look at `full`.

if isinstance(value, (Number, np.bool_)):
body = value
inputs = {}
elif isinstance(value, str) and value in sdfg.arrays:
value_array = sdfg.arrays[value]
if not isinstance(value_array, data.Scalar):
raise mem_parser.DaceSyntaxError(
pv, None, f"{arr}.fill requires a scalar argument, but {type(value_array)} was given.")
body = '__inp'
inputs = {'__inp': dace.Memlet(data=value, subset='0')}
else:
raise mem_parser.DaceSyntaxError(pv, None, f"Unsupported argument '{value}' for {arr}.fill.")

shape = sdfg.arrays[arr].shape
state.add_mapped_tasklet(
'_numpy_fill_',
map_ranges={
f"__i{dim}": f"0:{s}"
for dim, s in enumerate(shape)
},
inputs=inputs,
code=f"__out = {body}",
outputs={'__out': dace.Memlet.simple(arr, ",".join([f"__i{dim}" for dim in range(len(shape))]))},
external_edges=True)

return arr


@oprepo.replaces_method('Array', 'reshape')
Expand Down
14 changes: 14 additions & 0 deletions tests/numpy/ndarray_attributes_methods_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def test_fill(A: dace.int32[M, N]):
return A # return A.fill(5) doesn't work because A is not copied


@compare_numpy_output()
def test_fill2(A: dace.int32[M, N], a: dace.int32):
A.fill(a)
return A # return A.fill(5) doesn't work because A is not copied


@compare_numpy_output()
def test_fill3(A: dace.int32[M, N], a: dace.int32):
A.fill(a + 1)
return A


@compare_numpy_output()
def test_reshape(A: dace.float32[N, N]):
return A.reshape([1, N * N])
Expand Down Expand Up @@ -124,6 +136,8 @@ def test_any():
test_copy()
test_astype()
test_fill()
test_fill2()
test_fill3()
test_reshape()
test_transpose1()
test_transpose2()
Expand Down

0 comments on commit 40ed438

Please sign in to comment.