Skip to content

Commit

Permalink
Merge pull request #961 from maelso/fix-redeclaration
Browse files Browse the repository at this point in the history
Scheduler: Fix redeclaration
  • Loading branch information
FabioLuporini authored Dec 2, 2019
2 parents 4692baa + 185f894 commit 959479e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 16 deletions.
25 changes: 15 additions & 10 deletions devito/ir/iet/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ExpressionBundle, Transformer, FindNodes, FindSymbols,
MapExprStmts, XSubs, iet_analyze)
from devito.symbolics import IntDiv, ccode, xreplace_indices
from devito.tools import as_mapper, as_tuple
from devito.tools import as_mapper, as_tuple, flatten
from devito.types import ConditionalDimension

__all__ = ['iet_build', 'iet_insert_decls', 'iet_insert_casts']
Expand Down Expand Up @@ -168,7 +168,7 @@ def iet_insert_decls(iet, external):
continue
elif i._mem_stack:
# On the stack
allocator.push_object_on_stack(iet[0], i)
allocator.push_array_on_stack(iet[0], i)
else:
# On the heap
allocator.push_array_on_heap(i)
Expand Down Expand Up @@ -199,16 +199,21 @@ def __init__(self):
self.stack = OrderedDict()

def push_object_on_stack(self, scope, obj):
"""Define an Array or a composite type (e.g., a struct) on the stack."""
"""Define a LocalObject on the stack."""
handle = self.stack.setdefault(scope, OrderedDict())
handle[obj] = Element(c.Value(obj._C_typename, obj.name))

if obj.is_LocalObject:
handle[obj] = Element(c.Value(obj._C_typename, obj.name))
else:
shape = "".join("[%s]" % ccode(i) for i in obj.symbolic_shape)
alignment = "__attribute__((aligned(%d)))" % obj._data_alignment
value = "%s%s %s" % (obj.name, shape, alignment)
handle[obj] = Element(c.POD(obj.dtype, value))
def push_array_on_stack(self, scope, obj):
"""Define an Array on the stack."""
handle = self.stack.setdefault(scope, OrderedDict())

if obj in flatten(self.stack.values()):
return

shape = "".join("[%s]" % ccode(i) for i in obj.symbolic_shape)
alignment = "__attribute__((aligned(%d)))" % obj._data_alignment
value = "%s%s %s" % (obj.name, shape, alignment)
handle[obj] = Element(c.POD(obj.dtype, value))

def push_scalar_on_stack(self, scope, expr):
"""Define a Scalar on the stack."""
Expand Down
22 changes: 20 additions & 2 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension,
NODE, CELL, dimensions, configuration, TensorFunction,
TensorTimeFunction, VectorFunction, VectorTimeFunction)
from devito.ir.iet import (Expression, Iteration, FindNodes, IsPerfectIteration,
from devito.ir.equations import ClusterizedEq
from devito.ir.iet import (Conditional, Expression, Iteration, FindNodes,
IsPerfectIteration, derive_parameters, iet_insert_decls,
retrieve_iteration_tree)
from devito.ir.support import Any, Backward, Forward
from devito.symbolics import indexify, retrieve_indexed
from devito.symbolics import ListInitializer, indexify, retrieve_indexed
from devito.tools import flatten
from devito.types import Array, Scalar

Expand Down Expand Up @@ -1143,6 +1145,22 @@ def test_stack_vector_temporaries(self):
timers->section0 += (double)(end_section0.tv_sec-start_section0.tv_sec)\
+(double)(end_section0.tv_usec-start_section0.tv_usec)/1000000;""" in str(operator)

def test_conditional_declarations(self):
a = Array(name='a', dimensions=(x,), dtype=np.int32, scope='stack')
list_initialize = Expression(ClusterizedEq(Eq(a, ListInitializer([0, 0]))))
iet = Conditional(x < 3, list_initialize, list_initialize)
parameters = derive_parameters(iet, True)
iet = iet_insert_decls(iet, parameters)
assert str(iet[0]) == """\
if (x < 3)
{
int a[x_size] = {0, 0};
}
else
{
int a[x_size] = {0, 0};
}"""


class TestLoopScheduling(object):

Expand Down
23 changes: 19 additions & 4 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# thus invalidating all of the future tests. This is guaranteed by the
# `pytestmark` above
from devito import Eq, Function, Grid, Operator, TimeFunction, configuration # noqa
from devito.ir.equations import ClusterizedEq # noqa
from devito.ir.iet import Conditional, Expression, derive_parameters, iet_insert_decls # noqa
from devito.ops.node_factory import OPSNodeFactory # noqa
from devito.ops.transformer import create_ops_arg, create_ops_dat, make_ops_ast, to_ops_stencil # noqa
from devito.ops.types import OpsAccessible, OpsDat, OpsStencil, OpsBlock # noqa
from devito.ops.types import Array, OpsAccessible, OpsDat, OpsStencil, OpsBlock # noqa
from devito.ops.utils import namespace, AccessibleInfo, OpsDatDecl, OpsArgDecl # noqa
from devito.symbolics import Byref, Literal, indexify # noqa
from devito.symbolics import Byref, ListInitializer, Literal, indexify # noqa
from devito.tools import dtype_to_cstr # noqa
from devito.types import Buffer, Constant, Symbol # noqa
from devito.types import Buffer, Constant, DefaultDimension, Symbol # noqa


class TestOPSExpression(object):
Expand Down Expand Up @@ -272,11 +274,24 @@ def test_create_ops_block(self, equation, expected):
])
def test_upper_bound(self, equation, expected):
grid = Grid((5, 5))
u = TimeFunction(name='u', grid=grid) # noqa
u = TimeFunction(name='u', grid=grid) # noqa
op = Operator(eval(equation))

assert expected in str(op.ccode)

@pytest.mark.parametrize('equation, declaration', [
('Eq(u.forward, u+1)',
'int OPS_Kernel_0_range[4]')
])
def test_single_declaration(self, equation, declaration):
grid = Grid((5, 5))
u = TimeFunction(name='u', grid=grid) # noqa
op = Operator(eval(equation))

occurrences = [i for i in str(op.ccode).split('\n') if declaration in i]

assert len(occurrences) == 1

@pytest.mark.parametrize('equation,expected', [
('Eq(u_2d.forward, u_2d + 1)',
'[\'ops_dat_fetch_data(u_dat[(time_M)%(2)],0,&(u[(time_M)%(2)]));\','
Expand Down

0 comments on commit 959479e

Please sign in to comment.