Skip to content

Commit

Permalink
Merge pull request #8 from edopao/fix_cast
Browse files Browse the repository at this point in the history
Edit dace code (no functional change)
  • Loading branch information
edopao authored Dec 2, 2024
2 parents 6751a07 + 096c5c6 commit b13ddc0
Showing 1 changed file with 17 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from typing import Any, Callable, Sequence
from typing import Any, Callable

import numpy as np

Expand Down Expand Up @@ -120,45 +120,37 @@ class PythonCodegen(codegen.TemplatedGenerator):

Literal = as_fmt("{value}")

def _visit_deref(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str:
assert len(node.args) == 1
if isinstance(node.args[0], gtir.SymRef):
return self.visit(node.args[0], symbol_mapping=symbol_mapping)
raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")

def _visit_lambda(
self,
node: gtir.Lambda,
node_args: Sequence[gtir.Node],
symbol_mapping: dict[str, gtir.Node],
) -> str:
symbol_mapping |= {param.id: arg for param, arg in zip(node.params, node_args)}
return self.visit(node.expr, symbol_mapping=symbol_mapping)

def visit_FunCall(self, node: gtir.FunCall, symbol_mapping: dict[str, gtir.Node]) -> str:
def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str:
if isinstance(node.fun, gtir.Lambda):
return self._visit_lambda(node.fun, node.args, symbol_mapping=symbol_mapping)
# update the mapping from lambda parameters to corresponding argument expressions
args_map |= {p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)}
return self.visit(node.fun.expr, args_map=args_map)
elif cpm.is_call_to(node, "deref"):
return self._visit_deref(node, symbol_mapping=symbol_mapping)
assert len(node.args) == 1
if not isinstance(node.args[0], gtir.SymRef):
# shift expressions are not expected in this visitor context
raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")
return self.visit(node.args[0], args_map=args_map)
elif isinstance(node.fun, gtir.SymRef):
args = self.visit(node.args, symbol_mapping=symbol_mapping)
args = self.visit(node.args, args_map=args_map)
builtin_name = str(node.fun.id)
return format_builtin(builtin_name, *args)
raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).")

def visit_SymRef(self, node: gtir.SymRef, symbol_mapping: dict[str, gtir.Node]) -> str:
def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str:
symbol = str(node.id)
if symbol_mapping and symbol in symbol_mapping:
mapped_node = symbol_mapping[symbol]
return self.visit(mapped_node, symbol_mapping=symbol_mapping)
if symbol in args_map:
return self.visit(args_map[symbol], args_map=args_map)
return symbol


def get_source(node: gtir.Node) -> str:
"""
Specialized visit method for symbolic expressions.
The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions.
Returns:
A string containing the Python code corresponding to a symbolic expression
"""
return PythonCodegen.apply(node, symbol_mapping={})
return PythonCodegen.apply(node, args_map={})

0 comments on commit b13ddc0

Please sign in to comment.