Skip to content

Commit

Permalink
handle nargout;
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoqing committed May 24, 2022
1 parent 06a54d3 commit 3bf89c2
Showing 1 changed file with 62 additions and 26 deletions.
88 changes: 62 additions & 26 deletions mh_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def dynamic_selection_visitor(self, node: Dynamic_Selection, n_parent, relation)
def selection_visitor(self, node: Selection, n_parent, relation):
self[node] = f"{self.pop(node.n_prefix)}.{self.pop(node.n_field)}"

def while_statement_visitor(
self, node: While_Statement, n_parent, relation
):
def while_statement_visitor(self, node: While_Statement, n_parent, relation):
self[node] = (
f"while {self.pop(node.n_guard)}:\n"
f"{self.indent(self.pop(node.n_body))}\n"
Expand Down Expand Up @@ -220,15 +218,38 @@ def reference_visitor(self, node: Reference, n_parent, relation):
args = ", ".join(self.pop(i) for i in node.l_args)
self[node] = f"{self.pop(node.n_ident)}{bra}{args}{ket}"

@staticmethod
def _function_root(node: Identifier):
f = node
while f is not None and not isinstance(f, Function_Definition):
f = f.n_parent
return f

def identifier_visitor(self, node: Identifier, n_parent, relation):
value = node.t_ident.value

# Python and Matlab keywords set are different. Luckily, Matlab forbids identifier start with `_`.
prefix = '_' if value in (
*keyword.kwlist, # python keyword can not be overwritten
'I', 'M', 'C', # mat2py keywords
) else ''
self[node] = f'{prefix}{value}'
prefix = (
"_"
if value
in (
*keyword.kwlist, # python keyword can not be overwritten
"I",
"M",
"C", # mat2py keywords
)
else ""
)

if value == "nargout":
f = self._function_root(node)
if f:
f.contains_nargout = True
if value == "nargin":
f = self._function_root(node)
if f:
f.contains_nargin = True
self[node] = f"{prefix}{value}"

def number_literal_visitor(self, node: Number_Literal, n_parent, relation):
self[node] = node.t_value.value.replace("i", "j")
Expand Down Expand Up @@ -292,23 +313,38 @@ def script_file_visitor(self, node: Script_File, n_parent, relation):
def sequence_of_statements_visitor(
self, node: Sequence_Of_Statements, n_parent, relation
):
self[node] = "\n".join([self.pop(l) for l in node.l_statements]) if node.l_statements else "pass"
self[node] = (
"\n".join([self.pop(l) for l in node.l_statements])
if node.l_statements
else "pass"
)

def function_pointer_visitor(
self, node: Function_Pointer, n_parent, relation
):
def function_pointer_visitor(self, node: Function_Pointer, n_parent, relation):
self[node] = self.pop(node.n_name)

def function_definition_visitor(
self, node: Function_Definition, n_parent, relation
):
contains_nargout = getattr(node, "contains_nargout", False)
n_name = self.pop(node.n_sig.n_name)
n_body = self.indent(self.pop(node.n_body))
n_body = self.pop(node.n_body)
l_inputs = ", ".join([self.pop(i) for i in node.n_sig.l_inputs])
if contains_nargout:
l_inputs += ", nargout=None"
n_outputs = len(node.n_sig.l_outputs)
l_outputs = ", ".join([self.pop(i) for i in node.n_sig.l_outputs])
if l_outputs != "":
l_outputs = self.indent("return {}".format(l_outputs))
self[node] = f"def {n_name}({l_inputs}):\n{n_body}\n{l_outputs}\n"
if n_outputs > 1 and contains_nargout:
n_body = (
f"{l_outputs} = (None,)*{n_outputs}\n\n"
+ n_body
+ "\nreturn{nargout_str}\n"
)
elif n_outputs > 0:
n_body = n_body + "\nreturn{nargout_str}\n"
nargout_str = f"({l_outputs})[:nargout]" if contains_nargout else l_outputs
self[node] = f"def {n_name}({l_inputs}):\n{self.indent(n_body)}\n".format(
nargout_str=f" {nargout_str}" if n_outputs > 0 else ""
)

def compound_assignment_statement_visitor(
self, node: Compound_Assignment_Statement, n_parent, relation
Expand Down Expand Up @@ -407,8 +443,7 @@ def break_statement_visitor(self, node: Break_Statement, n_parent, relation):
self[node] = "break"

def return_statement_visitor(self, node: Return_Statement, n_parent, relation):
self[node] = "return"
# TODO: fix the return value
self[node] = "return{nargout_str}"

def row_visitor(self, node: Row, n_parent, relation):
if len(node.l_items) == 1:
Expand Down Expand Up @@ -648,26 +683,27 @@ def process_one_file(path: [Path, str], options=None, mh=None):
backend.process_result(backend.process_wp(wp))

for msg in chain.from_iterable(wp.mh.messages.get(wp.filename, {}).values()):
if msg.kind == "error" and "expected IDENTIFIER, reached EOF instead" == msg.message:
if (
msg.kind == "error"
and "expected IDENTIFIER, reached EOF instead" == msg.message
):
raise EOFError(msg.message)
elif msg.kind.endswith("error"):
mh.emit_message(msg)
raise SyntaxError(msg.message)


def process_one_block(src: str, inline=True, format=False):
with NamedTemporaryFile('w', suffix='.m') as f:
with NamedTemporaryFile("w", suffix=".m") as f:
f.write(src)
f.flush()

target_path = Path(f.name).with_suffix('.py')
target_path = Path(f.name).with_suffix(".py")

try:
options = parse_args([
"mh_python",
"--single",
"--python-alongside",
f.name])
options = parse_args(
["mh_python", "--single", "--python-alongside", f.name]
)
if inline is True:
options.inline_mode = True
if format is True:
Expand Down

0 comments on commit 3bf89c2

Please sign in to comment.