Skip to content

Commit

Permalink
Tasklet codegen: type hints and early returns
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Cattaneo committed Nov 5, 2024
1 parent d9ca5d3 commit ef7b47c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 33 deletions.
1 change: 0 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def visit_Tasklet(
node,
read_memlets=node.read_memlets,
write_memlets=node.write_memlets,
sdfg_ctx=sdfg_ctx,
symtable=symtable,
)

Expand Down
64 changes: 32 additions & 32 deletions src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _visit_offset(
*,
access_info: dcir.FieldAccessInfo,
decl: dcir.FieldDecl,
**kwargs,
**kwargs: Any,
) -> str:
int_sizes: List[Optional[int]] = []
for i, axis in enumerate(access_info.axes()):
Expand Down Expand Up @@ -60,27 +60,27 @@ def _visit_offset(
res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1])
return str(res)

def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs):
def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs: Any) -> str:
return self._visit_offset(node, **kwargs)

def visit_VariableKOffset(self, node: common.CartesianOffset, **kwargs):
def visit_VariableKOffset(self, node: dcir.VariableKOffset, **kwargs: Any) -> str:
return self._visit_offset(node, **kwargs)

def visit_IndexAccess(
self,
node: dcir.IndexAccess,
*,
is_target,
sdfg_ctx,
is_target: bool,
symtable: ChainMap[eve.SymbolRef, dcir.Decl],
**kwargs,
):
**kwargs: Any,
) -> str:
if is_target:
memlets = kwargs["write_memlets"]
else:
# if this node is not a target, it will still use the symbol of the write memlet if the
# field was previously written in the same memlet.
memlets = kwargs["read_memlets"] + kwargs["write_memlets"]

try:
memlet = next(mem for mem in memlets if mem.connector == node.name)
except StopIteration:
Expand All @@ -101,12 +101,12 @@ def visit_IndexAccess(
)
)
index_strs.extend(
self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs)
for idx in node.data_index
self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index
)
return f"{node.name}[{','.join(index_strs)}]"

def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs):
def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs: Any) -> str:
# Visiting order matters because targets must not contain the target symbols from the left visit
right = self.visit(node.right, is_target=False, **kwargs)
left = self.visit(node.left, is_target=True, **kwargs)
return f"{left} = {right}"
Expand All @@ -120,18 +120,18 @@ def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs):
def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str:
if builtin == common.BuiltInLiteral.TRUE:
return "True"
elif builtin == common.BuiltInLiteral.FALSE:
if builtin == common.BuiltInLiteral.FALSE:
return "False"
raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs):
def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs: Any) -> str:
value = self.visit(literal.value, in_idx=in_idx, **kwargs)
if in_idx:
return str(value)
else:
return "{dtype}({value})".format(
dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value
)

return "{dtype}({value})".format(
dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value
)

Cast = as_fmt("{dtype}({expr})")

Expand Down Expand Up @@ -178,26 +178,26 @@ def visit_NativeFuncCall(self, call: common.NativeFuncCall, **kwargs: Any) -> st
def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str:
if dtype == common.DataType.BOOL:
return "dace.bool_"
elif dtype == common.DataType.INT8:
if dtype == common.DataType.INT8:
return "dace.int8"
elif dtype == common.DataType.INT16:
if dtype == common.DataType.INT16:
return "dace.int16"
elif dtype == common.DataType.INT32:
if dtype == common.DataType.INT32:
return "dace.int32"
elif dtype == common.DataType.INT64:
if dtype == common.DataType.INT64:
return "dace.int64"
elif dtype == common.DataType.FLOAT32:
if dtype == common.DataType.FLOAT32:
return "dace.float32"
elif dtype == common.DataType.FLOAT64:
if dtype == common.DataType.FLOAT64:
return "dace.float64"
raise NotImplementedError("Not implemented DataType encountered.")

def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str:
if op == common.UnaryOperator.NOT:
return " not "
elif op == common.UnaryOperator.NEG:
if op == common.UnaryOperator.NEG:
return "-"
elif op == common.UnaryOperator.POS:
if op == common.UnaryOperator.POS:
return "+"
raise NotImplementedError("Not implemented UnaryOperator encountered.")

Expand All @@ -207,16 +207,16 @@ def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str:

LocalScalarDecl = as_fmt("{name}: {dtype}")

def visit_Tasklet(self, node: dcir.Tasklet, **kwargs):
def visit_Tasklet(self, node: dcir.Tasklet, **kwargs: Any) -> str:
return "\n".join(self.visit(node.decls, **kwargs) + self.visit(node.stmts, **kwargs))

def _visit_conditional(
self,
cond: Optional[Union[dcir.Expr, common.HorizontalMask]],
body: List[dcir.Stmt],
keyword,
**kwargs,
):
keyword: str,
**kwargs: Any,
) -> str:
mask_str = ""
indent = ""
if cond is not None and (cond_str := self.visit(cond, is_target=False, **kwargs)):
Expand All @@ -226,16 +226,16 @@ def _visit_conditional(
body_code = [indent + b for b in body_code]
return "\n".join([mask_str, *body_code])

def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs):
def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs: Any) -> str:
return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs)

def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs):
def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs: Any) -> str:
return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs)

def visit_While(self, node: dcir.While, **kwargs):
def visit_While(self, node: dcir.While, **kwargs: Any) -> Any:
return self._visit_conditional(cond=node.cond, body=node.body, keyword="while", **kwargs)

def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs):
def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs: Any) -> str:
clauses: List[str] = []

for axis, interval in zip(dcir.Axis.dims_horizontal(), node.intervals):
Expand Down

0 comments on commit ef7b47c

Please sign in to comment.