Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/[cartesian]: Absolute k indexation, Part 2: field.at(K=...) syntax #1680

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Assign,
AxisBound,
AxisInterval,
AbsoluteKIndex,
BinaryOperator,
BinOpExpr,
BlockStmt,
Expand Down Expand Up @@ -557,12 +558,15 @@ def visit_VarDecl(self, node: VarDecl) -> gtir.ScalarDecl:
)

def transform_offset(
self, offset: Dict[str, Union[int, Expr]], **kwargs: Any
self, offset: Dict[str, Union[int, Expr, AbsoluteKIndex]], **kwargs: Any
) -> Union[common.CartesianOffset, gtir.VariableKOffset]:
if isinstance(offset, AbsoluteKIndex):
k_to_gtir = self.visit(offset.k)
return gtir.AbsoluteKIndex(k=k_to_gtir)
k_val = offset.get("K", 0)
if isinstance(k_val, numbers.Integral):
return common.CartesianOffset(i=offset.get("I", 0), j=offset.get("J", 0), k=k_val)
elif isinstance(k_val, Expr):
return gtir.VariableKOffset(k=self.visit(k_val, **kwargs))
else:
raise TypeError("Unrecognized vertical offset type")
raise TypeError("Unrecognized vertical indexing type")
65 changes: 51 additions & 14 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def visit_Return(self, node: ast.Return, *, target_node: ast.AST) -> ast.Assign:
)


def _filter_absolute_K_index_method(node: ast.Call) -> bool:
return isinstance(node.func, ast.Attribute) and node.func.attr == "at"


class CallInliner(ast.NodeTransformer):
"""Inlines calls to gtscript.function calls.

Expand Down Expand Up @@ -403,19 +407,27 @@ def visit_While(self, node: ast.While):
return node

def visit_Assign(self, node: ast.Assign):
if (
isinstance(node.value, ast.Call)
and gt_meta.get_qualified_name_from_node(node.value.func) not in gtscript.MATH_BUILTINS
):
assert len(node.targets) == 1
self.visit(node.value, target_node=node.targets[0])
# This node can be now removed since the trivial assignment has been already done
# in the Call visitor
return None
if isinstance(node.value, ast.Call):
if _filter_absolute_K_index_method(node.value):
return node
elif (
gt_meta.get_qualified_name_from_node(node.value.func) not in gtscript.MATH_BUILTINS
):
assert len(node.targets) == 1
self.visit(node.value, target_node=node.targets[0])
# This node can be now removed since the trivial assignment has been already done
# in the Call visitor
return None
else:
return self.generic_visit(node)
else:
return self.generic_visit(node)

def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complexity too high
# Filter for absolute indexation method '.at'
if self._filter_absolute_K_index_method(node):
return self._absolute_K_index_method(node)

call_name = gt_meta.get_qualified_name_from_node(node.func)

if call_name in self.call_stack:
Expand Down Expand Up @@ -1111,7 +1123,7 @@ def _eval_new_spatial_index(

def _eval_index(
self, node: ast.Subscript, field_axes: Optional[Set[Literal["I", "J", "K"]]] = None
) -> Optional[List[int]]:
) -> Optional[Union[List[int], nodes.AbsoluteKIndex]]:
tuple_or_expr = node.slice.value if isinstance(node.slice, ast.Index) else node.slice
index_nodes = gtc_utils.listify(
tuple_or_expr.elts if isinstance(tuple_or_expr, ast.Tuple) else tuple_or_expr
Expand Down Expand Up @@ -1160,16 +1172,18 @@ def visit_Subscript(self, node: ast.Subscript):
assert index is not None
result.index = index[0]
else:
if isinstance(node.value, ast.Name):
if isinstance(index, nodes.AbsoluteKIndex):
result.offset = index
elif isinstance(node.value, ast.Name):
field_axes = self.fields[result.name].axes
if index is not None:
if len(field_axes) != len(index):
ro_field_message = ""
if len(field_axes) == 0:
ro_field_message = f"Did you mean .A{index}?"
ro_field_message = f"Did you mean absolute indexing via .A{index}?"
raise GTScriptSyntaxError(
f"Incorrect offset specification detected. Found {index}, "
f"but the field has dimensions ({', '.join(field_axes)}). "
f"Incorrect offset specification detected for {result.name}. "
f"Found index={index}, but {result.name} field has dimensions ({', '.join(field_axes)}). "
f"{ro_field_message}"
)
result.offset = {axis: value for axis, value in zip(field_axes, index)}
Expand Down Expand Up @@ -1377,7 +1391,30 @@ def visit_While(self, node: ast.While) -> list:

return result

def _absolute_K_index_method(self, node: ast.Call):
assert _filter_absolute_K_index_method(node)
if len(node.keywords) != 1:
raise GTScriptSyntaxError(
message="Absolute K index bad syntax. Must be of the form`.at(K=...)`",
loc=nodes.Location.from_ast_node(node),
)
if node.keywords[0].arg != "K":
raise GTScriptSyntaxError(
message="Absolute K index: bad syntax, only `K` is a valid parameter to `at`. "
"Must be of the form`.at(K=...)`",
loc=nodes.Location.from_ast_node(node),
)
field: nodes.FieldRef = self.visit(node.func.value)
assert isinstance(field, nodes.FieldRef)
field.offset = nodes.AbsoluteKIndex(k=self.visit(node.keywords[0].value))
return field

def visit_Call(self, node: ast.Call):
# We check for am absolute Field index in K
if _filter_absolute_K_index_method(node):
return self._absolute_K_index_method(node)

# We expect the Call is a native function to carry forward
native_fcn = nodes.NativeFunction.PYTHON_SYMBOL_TO_IR_OP[node.func.id]

args = [self.visit(arg) for arg in node.args]
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,17 @@ class VarRef(Ref):
loc = attribute(of=Location, optional=True)


@attribclass
class AbsoluteKIndex(Expr):
"""See gtc.common.AbsoluteKIndex"""

k = attribute(of=Any)


@attribclass
class FieldRef(Ref):
name = attribute(of=str)
offset = attribute(of=DictOf[str, UnionOf[int, Expr]])
offset = attribute(of=DictOf[str, UnionOf[int, Expr, AbsoluteKIndex]])
data_index = attribute(of=ListOf[Expr], factory=list)
loc = attribute(of=Location, optional=True)

Expand Down
30 changes: 28 additions & 2 deletions src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import enum
import functools
import typing
import numbers
from typing import (
Any,
ClassVar,
Expand Down Expand Up @@ -118,7 +119,7 @@ def isbool(self):
return self == self.BOOL

def isinteger(self):
return self in (self.INT8, self.INT32, self.INT64)
return self in (self.INT8, self.INT16, self.INT32, self.INT64)

def isfloat(self):
return self in (self.FLOAT32, self.FLOAT64)
Expand Down Expand Up @@ -331,14 +332,39 @@ def offset_expr_is_int(self, attribute: datamodels.Attribute, value: Any) -> Non
raise ValueError("Variable vertical index must be an integer expression")


class AbsoluteKIndex(eve.GenericNode, Generic[ExprT]):
"""Access a field with absolute K

Restrictions:
- Centered I/J
- No data dimensions
- Read-only
"""

k: Union[int, ExprT]

def to_dict(self) -> Dict[str, Optional[int]]:
return {"i": 0, "j": 0, "k": None}

@datamodels.validator("k")
def offset_expr_is_int(self, attribute: datamodels.Attribute, value: Any) -> None:
if isinstance(value, numbers.Real):
if not isinstance(value, int):
raise ValueError("Absolute vertical index literal must be an integer")
else:
value = typing.cast(Expr, value)
if value.dtype is not DataType.AUTO and not value.dtype.isinteger():
raise ValueError("Absolute vertical index must be an integer expression")


class ScalarAccess(LocNode):
name: eve.Coerced[eve.SymbolRef]
kind: ExprKind = ExprKind.SCALAR


class FieldAccess(eve.GenericNode, Generic[ExprT, VariableKOffsetT]):
name: eve.Coerced[eve.SymbolRef]
offset: Union[CartesianOffset, VariableKOffsetT]
offset: Union[CartesianOffset, VariableKOffsetT, AbsoluteKIndex]
data_index: List[ExprT] = eve.field(default_factory=list)
kind: ExprKind = ExprKind.FIELD

Expand Down
17 changes: 16 additions & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,22 @@ class VariableKOffset(common.VariableKOffset[Expr]):


class IndexAccess(common.FieldAccess, Expr):
offset: Optional[Union[common.CartesianOffset, VariableKOffset]]
offset: Optional[
Union[
common.CartesianOffset,
VariableKOffset,
common.AbsoluteKIndex,
Literal,
ScalarAccess, # For field index
]
]

@datamodels.validator("offset")
def offset_is_integer(self, attribute: datamodels.Attribute, v: Expr) -> None:
if (isinstance(v, ScalarAccess) or isinstance(v, Literal)) and not v.dtype.isinteger():
raise ValueError(
f"Index access, when ScalarAcces/Literal, must be an integer, got {v.dtype}."
)


class AssignStmt(common.AssignStmt[Union[ScalarAccess, IndexAccess], Expr], Stmt):
Expand Down
21 changes: 20 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,27 @@ def visit_FieldAccess(
)
name = get_tasklet_symbol(node.name, node.offset, is_target=is_target)
if node.data_index:
if isinstance(node.offset, common.AbsoluteKIndex):
raise RuntimeError("Absolute K indexing cannot work with data index")
res = dcir.IndexAccess(
name=name, offset=None, data_index=node.data_index, dtype=node.dtype
name=name,
offset=None,
data_index=node.data_index,
dtype=node.dtype,
)
elif isinstance(node.offset, common.AbsoluteKIndex):
res = dcir.IndexAccess(
name=name,
offset=self.visit(
node.offset.k,
is_target=is_target,
targets=targets,
var_offset_fields=var_offset_fields,
K_write_with_offset=K_write_with_offset,
**kwargs,
),
data_index=[],
dtype=node.dtype,
)
else:
res = dcir.ScalarAccess(name=name, dtype=node.dtype)
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs):
def visit_VariableKOffset(self, node: common.CartesianOffset, **kwargs):
return self._visit_offset(node, **kwargs)

def visit_AbsoluteKIndex(self, node: common.AbsoluteKIndex, **kwargs):
idx = self.visit(self.visit(node.k))
return str(idx)

def visit_IndexAccess(
self,
node: dcir.IndexAccess,
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,17 @@ def _make_access_info(
grid_subset,
is_write,
) -> dcir.FieldAccessInfo:
"""Compute how the field get accessed on the grid"""

# Check we have expression offsets in K
# OR write offsets in K
# OR absolute indexing in K
offset = [offset_node.to_dict()[k] for k in "ijk"]
if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write):
if (
isinstance(offset_node, oir.VariableKOffset)
or (offset[2] != 0 and is_write)
or isinstance(offset_node, oir.AbsoluteKIndex)
):
variable_offset_axes = [dcir.Axis.K]
else:
variable_offset_axes = []
Expand Down
54 changes: 44 additions & 10 deletions src/gt4py/cartesian/gtc/gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class VariableKOffset(common.VariableKOffset[Expr]):
pass


class AbsoluteKIndex(common.AbsoluteKIndex[Expr]):
"""See gtc.common.AbsoluteKIndex"""

pass


class ScalarAccess(common.ScalarAccess, Expr): # type: ignore
pass

Expand Down Expand Up @@ -83,14 +89,26 @@ def no_write_and_read_with_offset_of_same_field(
) -> None:
if isinstance(instance.left, FieldAccess):
offset_reads = (
eve.walk_values(instance.right)
.filter(_cartesian_fieldaccess)
.filter(lambda acc: acc.offset.i != 0 or acc.offset.j != 0)
.getattr("name")
.to_set()
) | eve.walk_values(instance.right).filter(_variablek_fieldaccess).getattr(
"name"
).to_set()
(
eve.walk_values(instance.right)
.filter(_cartesian_fieldaccess)
.filter(lambda acc: acc.offset.i != 0 or acc.offset.j != 0)
.getattr("name")
.to_set()
)
| (
eve.walk_values(instance.right)
.filter(_absolutekindex_fieldaccess)
.getattr("name")
.to_set()
)
| (
eve.walk_values(instance.right)
.filter(_variablek_fieldaccess)
.getattr("name")
.to_set()
)
)
if instance.left.name in offset_reads:
raise ValueError("Self-assignment with offset is illegal.")

Expand Down Expand Up @@ -246,11 +264,27 @@ def param_names(self) -> List[str]:


def _cartesian_fieldaccess(node) -> bool:
return isinstance(node, FieldAccess) and not isinstance(node.offset, VariableKOffset)
return (
isinstance(node, FieldAccess)
and not isinstance(node.offset, VariableKOffset)
and not isinstance(node.offset, AbsoluteKIndex)
)


def _variablek_fieldaccess(node) -> bool:
return isinstance(node, FieldAccess) and isinstance(node.offset, VariableKOffset)
return (
isinstance(node, FieldAccess)
and isinstance(node.offset, VariableKOffset)
and not isinstance(node.offset, AbsoluteKIndex)
)


def _absolutekindex_fieldaccess(node) -> bool:
return (
isinstance(node, FieldAccess)
and isinstance(node.offset, AbsoluteKIndex)
and not isinstance(node.offset, VariableKOffset)
)


# TODO(havogt): either move to eve or will be removed in the attr-based eve if a List[Node] is represented as a CollectionNode
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/cartesian/gtc/gtir_to_oir.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def visit_FieldAccess(self, node: gtir.FieldAccess) -> oir.FieldAccess:
loc=node.loc,
)

def visit_AbsoluteKIndex(self, node: gtir.AbsoluteKIndex) -> oir.AbsoluteKIndex:
return oir.AbsoluteKIndex(k=self.visit(node.k))

def visit_VariableKOffset(self, node: gtir.VariableKOffset) -> oir.VariableKOffset:
return oir.VariableKOffset(k=self.visit(node.k))

Expand Down
Loading
Loading