-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Index builtin #1699
Changes from 3 commits
e1e30b2
181b912
60d2100
19998af
a8c99b8
8031bf9
655e260
b53897a
1540e8d
a61ef9a
f558e35
0c50850
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,9 +31,9 @@ | |
INT_LITERAL: SIGNED_INT | ||
FLOAT_LITERAL: SIGNED_FLOAT | ||
OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ" | ||
_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | ||
AXIS_LITERAL: CNAME ("ᵥ" | "ₕ") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Renamed for consistency. |
||
_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Otherwise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes |
||
ID_NAME: CNAME | ||
AXIS_NAME: CNAME ("ᵥ" | "ₕ") | ||
|
||
?prec0: prec1 | ||
| "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam | ||
|
@@ -84,7 +84,7 @@ | |
else_branch_seperator: "else" | ||
if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" | ||
|
||
named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" | ||
named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")" | ||
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" | ||
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" | ||
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" | ||
|
@@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: | |
def ID_NAME(self, value: lark_lexer.Token) -> str: | ||
return value.value | ||
|
||
def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral: | ||
def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral: | ||
name = value.value[:-1] | ||
kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL | ||
return ir.AxisLiteral(value=name, kind=kind) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -260,6 +260,17 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll | |||||||||||||||||||||||
#include <gridtools/fn/${grid_type_str}.hpp> | ||||||||||||||||||||||||
#include <gridtools/fn/sid_neighbor_table.hpp> | ||||||||||||||||||||||||
#include <gridtools/stencil/global_parameter.hpp> | ||||||||||||||||||||||||
#include <gridtools/stencil/positional.hpp> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// TODO(havogt): move to gtfn? | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd propose to wait until GridTools/gridtools#1806 is merged next week. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once #1720 is in, we can remove this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I removed that code. |
||||||||||||||||||||||||
namespace gridtools{ | ||||||||||||||||||||||||
namespace fn{ | ||||||||||||||||||||||||
template <class T> | ||||||||||||||||||||||||
auto index(T){ | ||||||||||||||||||||||||
return stencil::positional<std::decay_t<T>>();} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
namespace generated{ | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -611,7 +611,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E | |
tuple_constructor=lambda *elements: SidComposite(values=list(elements)), | ||
) | ||
|
||
assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check was superfluous from the beginning since the format is checked anyway when the |
||
lowered_inputs.append(lowered_input_as_sid) | ||
|
||
backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -357,7 +357,10 @@ def visit_StencilClosure( | |||||
closure_state = closure_sdfg.add_state("closure_entry") | ||||||
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) | ||||||
|
||||||
input_names = [str(inp.id) for inp in node.inputs] | ||||||
assert all( | ||||||
isinstance(inp, SymRef) for inp in node.inputs | ||||||
) # backend only supports SymRef inputs, not `index` calls | ||||||
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
or similar |
||||||
neighbor_tables = get_used_connectivities(node, self.offset_provider) | ||||||
connectivity_names = [ | ||||||
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() | ||||||
|
@@ -565,7 +568,7 @@ def _visit_scan_stencil_closure( | |||||
|
||||||
assert isinstance(node.output, SymRef) | ||||||
neighbor_tables = get_used_connectivities(node, self.offset_provider) | ||||||
input_names = [str(inp.id) for inp in node.inputs] | ||||||
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls | ||||||
connectivity_names = [ | ||||||
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() | ||||||
] | ||||||
|
@@ -732,7 +735,7 @@ def _visit_parallel_stencil_closure( | |||||
], | ||||||
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: | ||||||
neighbor_tables = get_used_connectivities(node, self.offset_provider) | ||||||
input_names = [str(inp.id) for inp in node.inputs] | ||||||
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls | ||||||
connectivity_names = [ | ||||||
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() | ||||||
] | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We moved the comment to a better place. With respect to the comment itself: The
IndexField
class is rather strange. It has two modes of operation: Either it is a field with (conceptually)field(domain) == domain
or it is a zero-dimensional field. Both modes don't share any implementation similarities, but are mushed into the same class. The way the class behaves is then controlled using_cur_index
. It would be much simpler to just makefield[index]
return a zero-dimensional field which is exactly what we want instead of re-implementing it here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand what you are saying...