diff --git a/miss_hit/mh_python.py b/miss_hit/mh_python.py index a78af6f..910cb5f 100644 --- a/miss_hit/mh_python.py +++ b/miss_hit/mh_python.py @@ -126,9 +126,7 @@ def dynamic_selection_visitor(self, node: Dynamic_Selection, n_parent, relation) def selection_visitor(self, node: Selection, n_parent, relation): self[node] = f"{self.pop(node.n_prefix)}.{self.pop(node.n_field)}" - def while_statement_visitor( - self, node: While_Statement, n_parent, relation - ): + def while_statement_visitor(self, node: While_Statement, n_parent, relation): self[node] = ( f"while {self.pop(node.n_guard)}:\n" f"{self.indent(self.pop(node.n_body))}\n" @@ -220,15 +218,38 @@ def reference_visitor(self, node: Reference, n_parent, relation): args = ", ".join(self.pop(i) for i in node.l_args) self[node] = f"{self.pop(node.n_ident)}{bra}{args}{ket}" + @staticmethod + def _function_root(node: Identifier): + f = node + while f is not None and not isinstance(f, Function_Definition): + f = f.n_parent + return f + def identifier_visitor(self, node: Identifier, n_parent, relation): value = node.t_ident.value # Python and Matlab keywords set are different. Luckily, Matlab forbids identifier start with `_`. - prefix = '_' if value in ( - *keyword.kwlist, # python keyword can not be overwritten - 'I', 'M', 'C', # mat2py keywords - ) else '' - self[node] = f'{prefix}{value}' + prefix = ( + "_" + if value + in ( + *keyword.kwlist, # python keyword can not be overwritten + "I", + "M", + "C", # mat2py keywords + ) + else "" + ) + + if value == "nargout": + f = self._function_root(node) + if f: + f.contains_nargout = True + if value == "nargin": + f = self._function_root(node) + if f: + f.contains_nargin = True + self[node] = f"{prefix}{value}" def number_literal_visitor(self, node: Number_Literal, n_parent, relation): self[node] = node.t_value.value.replace("i", "j") @@ -292,23 +313,38 @@ def script_file_visitor(self, node: Script_File, n_parent, relation): def sequence_of_statements_visitor( self, node: Sequence_Of_Statements, n_parent, relation ): - self[node] = "\n".join([self.pop(l) for l in node.l_statements]) if node.l_statements else "pass" + self[node] = ( + "\n".join([self.pop(l) for l in node.l_statements]) + if node.l_statements + else "pass" + ) - def function_pointer_visitor( - self, node: Function_Pointer, n_parent, relation - ): + def function_pointer_visitor(self, node: Function_Pointer, n_parent, relation): self[node] = self.pop(node.n_name) def function_definition_visitor( self, node: Function_Definition, n_parent, relation ): + contains_nargout = getattr(node, "contains_nargout", False) n_name = self.pop(node.n_sig.n_name) - n_body = self.indent(self.pop(node.n_body)) + n_body = self.pop(node.n_body) l_inputs = ", ".join([self.pop(i) for i in node.n_sig.l_inputs]) + if contains_nargout: + l_inputs += ", nargout=None" + n_outputs = len(node.n_sig.l_outputs) l_outputs = ", ".join([self.pop(i) for i in node.n_sig.l_outputs]) - if l_outputs != "": - l_outputs = self.indent("return {}".format(l_outputs)) - self[node] = f"def {n_name}({l_inputs}):\n{n_body}\n{l_outputs}\n" + if n_outputs > 1 and contains_nargout: + n_body = ( + f"{l_outputs} = (None,)*{n_outputs}\n\n" + + n_body + + "\nreturn{nargout_str}\n" + ) + elif n_outputs > 0: + n_body = n_body + "\nreturn{nargout_str}\n" + nargout_str = f"({l_outputs})[:nargout]" if contains_nargout else l_outputs + self[node] = f"def {n_name}({l_inputs}):\n{self.indent(n_body)}\n".format( + nargout_str=f" {nargout_str}" if n_outputs > 0 else "" + ) def compound_assignment_statement_visitor( self, node: Compound_Assignment_Statement, n_parent, relation @@ -407,8 +443,7 @@ def break_statement_visitor(self, node: Break_Statement, n_parent, relation): self[node] = "break" def return_statement_visitor(self, node: Return_Statement, n_parent, relation): - self[node] = "return" - # TODO: fix the return value + self[node] = "return{nargout_str}" def row_visitor(self, node: Row, n_parent, relation): if len(node.l_items) == 1: @@ -648,7 +683,10 @@ def process_one_file(path: [Path, str], options=None, mh=None): backend.process_result(backend.process_wp(wp)) for msg in chain.from_iterable(wp.mh.messages.get(wp.filename, {}).values()): - if msg.kind == "error" and "expected IDENTIFIER, reached EOF instead" == msg.message: + if ( + msg.kind == "error" + and "expected IDENTIFIER, reached EOF instead" == msg.message + ): raise EOFError(msg.message) elif msg.kind.endswith("error"): mh.emit_message(msg) @@ -656,18 +694,16 @@ def process_one_file(path: [Path, str], options=None, mh=None): def process_one_block(src: str, inline=True, format=False): - with NamedTemporaryFile('w', suffix='.m') as f: + with NamedTemporaryFile("w", suffix=".m") as f: f.write(src) f.flush() - target_path = Path(f.name).with_suffix('.py') + target_path = Path(f.name).with_suffix(".py") try: - options = parse_args([ - "mh_python", - "--single", - "--python-alongside", - f.name]) + options = parse_args( + ["mh_python", "--single", "--python-alongside", f.name] + ) if inline is True: options.inline_mode = True if format is True: