diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8314796..75d936b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index acff300..00caf34 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ corresponding $\LaTeX$ expression: 1. *Which Python versions are supported?* - Syntaxes on **Pythons 3.7 to 3.11** are officially supported, or will be supported. + Syntaxes on **Pythons 3.9 to 3.13** are officially supported, or will be supported. 2. *Which technique is used?* diff --git a/pyproject.toml b/pyproject.toml index 1576dce..1d08669 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" name = "latexify-py" description = "Generates LaTeX math description from Python functions." readme = "README.md" -requires-python = ">=3.7, <3.12" +requires-python = ">=3.9, <3.14" license = {text = "Apache Software License 2.0"} authors = [ {name = "Yusuke Oda", email = "odashi@inspiredco.ai"} @@ -24,11 +24,11 @@ classifiers = [ "Framework :: IPython", "Framework :: Jupyter", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Software Development :: Code Generators", "Topic :: Text Processing :: Markup :: LaTeX", @@ -43,17 +43,17 @@ dynamic = [ [project.optional-dependencies] dev = [ "build>=0.8", - "black>=22.10", - "flake8>=5.0", + "black>=24.3", + "flake8>=6.0", "isort>=5.10", - "mypy>=0.991", + "mypy>=1.9", "notebook>=6.5.1", - "pyproject-flake8>=5.0", + "pyproject-flake8>=6.0", "pytest>=7.1", "twine>=4.0", ] mypy = [ - "mypy>=0.991", + "mypy>=1.9", "pytest>=7.1", ] diff --git a/src/latexify/analyzers_test.py b/src/latexify/analyzers_test.py index 52802d7..314ce2f 100644 --- a/src/latexify/analyzers_test.py +++ b/src/latexify/analyzers_test.py @@ -9,7 +9,6 @@ from latexify import analyzers, ast_utils, exceptions, test_utils -@test_utils.require_at_least(8) @pytest.mark.parametrize( "code,start,stop,step,start_int,stop_int,step_int", [ diff --git a/src/latexify/ast_utils.py b/src/latexify/ast_utils.py index d9da518..efac1c3 100644 --- a/src/latexify/ast_utils.py +++ b/src/latexify/ast_utils.py @@ -56,24 +56,12 @@ def make_constant(value: Any) -> ast.expr: Raises: ValueError: Unsupported value type. """ - if sys.version_info.minor < 8: - if value is None or value is False or value is True: - return ast.NameConstant(value=value) - if value is ...: - return ast.Ellipsis() - if isinstance(value, (int, float, complex)): - return ast.Num(n=value) - if isinstance(value, str): - return ast.Str(s=value) - if isinstance(value, bytes): - return ast.Bytes(s=value) - else: - if ( - value is None - or value is ... - or isinstance(value, (bool, int, float, complex, str, bytes)) - ): - return ast.Constant(value=value) + if ( + value is None + or value is ... + or isinstance(value, (bool, int, float, complex, str, bytes)) + ): + return ast.Constant(value=value) raise ValueError(f"Unsupported type to generate Constant: {type(value).__name__}") @@ -87,13 +75,7 @@ def is_constant(node: ast.AST) -> bool: Returns: True if the node is a constant, False otherwise. """ - if sys.version_info.minor < 8: - return isinstance( - node, - (ast.Bytes, ast.Constant, ast.Ellipsis, ast.NameConstant, ast.Num, ast.Str), - ) - else: - return isinstance(node, ast.Constant) + return isinstance(node, ast.Constant) def is_str(node: ast.AST) -> bool: @@ -120,20 +102,12 @@ def extract_int_or_none(node: ast.expr) -> int | None: Returns: Extracted int value, or None if extraction failed. """ - if sys.version_info.minor < 8: - if ( - isinstance(node, ast.Num) - and isinstance(node.n, int) - and not isinstance(node.n, bool) - ): - return node.n - else: - if ( - isinstance(node, ast.Constant) - and isinstance(node.value, int) - and not isinstance(node.n, bool) - ): - return node.value + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, int) + and not isinstance(node.value, bool) + ): + return node.value return None @@ -173,3 +147,65 @@ def extract_function_name_or_none(node: ast.Call) -> str | None: return node.func.attr return None + + +def create_function_def( + name, + args, + body, + decorator_list, + returns=None, + type_comment=None, + type_params=None, + lineno=None, + col_offset=None, + end_lineno=None, + end_col_offset=None, +) -> ast.FunctionDef: + """Creates a FunctionDef node. + + This function generates an `ast.FunctionDef` node, optionally removing + the `type_params` keyword argument for Python versions below 3.12. + + Args: + name: Name of the function. + args: Arguments of the function. + body: Body of the function. + decorator_list: List of decorators. + returns: Return type of the function. + type_comment: Type comment of the function. + type_params: Type parameters of the function. + lineno: Line number of the function definition. + col_offset: Column offset of the function definition. + end_lineno: End line number of the function definition. + end_col_offset: End column offset of the function definition. + + Returns: + ast.FunctionDef: The generated FunctionDef node. + """ + if sys.version_info.minor < 12: + return ast.FunctionDef( + name=name, + args=args, + body=body, + decorator_list=decorator_list, + returns=returns, + type_comment=type_comment, + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + ) # type: ignore + return ast.FunctionDef( + name=name, + args=args, + body=body, + decorator_list=decorator_list, + returns=returns, + type_comment=type_comment, + type_params=type_params, + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + ) # type: ignore diff --git a/src/latexify/ast_utils_test.py b/src/latexify/ast_utils_test.py index 0e9cfa2..2b0ce71 100644 --- a/src/latexify/ast_utils_test.py +++ b/src/latexify/ast_utils_test.py @@ -3,6 +3,7 @@ from __future__ import annotations import ast +import sys from typing import Any import pytest @@ -34,29 +35,6 @@ def test_make_attribute() -> None: ) -@test_utils.require_at_most(7) -@pytest.mark.parametrize( - "value,expected", - [ - (None, ast.NameConstant(value=None)), - (False, ast.NameConstant(value=False)), - (True, ast.NameConstant(value=True)), - (..., ast.Ellipsis()), - (123, ast.Num(n=123)), - (4.5, ast.Num(n=4.5)), - (6 + 7j, ast.Num(n=6 + 7j)), - ("foo", ast.Str(s="foo")), - (b"bar", ast.Bytes(s=b"bar")), - ], -) -def test_make_constant_legacy(value: Any, expected: ast.Constant) -> None: - test_utils.assert_ast_equal( - observed=ast_utils.make_constant(value), - expected=expected, - ) - - -@test_utils.require_at_least(8) @pytest.mark.parametrize( "value,expected", [ @@ -83,17 +61,16 @@ def test_make_constant_invalid() -> None: ast_utils.make_constant(object()) -@test_utils.require_at_most(7) @pytest.mark.parametrize( "value,expected", [ - (ast.Bytes(s=b"foo"), True), - (ast.Constant("bar"), True), - (ast.Ellipsis(), True), - (ast.NameConstant(value=None), True), - (ast.Num(n=123), True), - (ast.Str(s="baz"), True), - (ast.Expr(value=ast.Num(456)), False), + (ast.Constant(value=b"foo"), True), + (ast.Constant(value="bar"), True), + (ast.Constant(value=...), True), + (ast.Constant(value=None), True), + (ast.Constant(value=123), True), + (ast.Constant(value="baz"), True), + (ast.Expr(value=ast.Constant(value=456)), False), (ast.Global(names=["qux"]), False), ], ) @@ -101,7 +78,6 @@ def test_is_constant_legacy(value: ast.AST, expected: bool) -> None: assert ast_utils.is_constant(value) is expected -@test_utils.require_at_least(8) @pytest.mark.parametrize( "value,expected", [ @@ -114,17 +90,16 @@ def test_is_constant(value: ast.AST, expected: bool) -> None: assert ast_utils.is_constant(value) is expected -@test_utils.require_at_most(7) @pytest.mark.parametrize( "value,expected", [ - (ast.Bytes(s=b"foo"), False), - (ast.Constant("bar"), True), - (ast.Ellipsis(), False), - (ast.NameConstant(value=None), False), - (ast.Num(n=123), False), - (ast.Str(s="baz"), True), - (ast.Expr(value=ast.Num(456)), False), + (ast.Constant(value=b"foo"), False), + (ast.Constant(value="bar"), True), + (ast.Constant(value=...), False), + (ast.Constant(value=None), False), + (ast.Constant(value=123), False), + (ast.Constant(value="baz"), True), + (ast.Expr(value=ast.Constant(value=456)), False), (ast.Global(names=["qux"]), False), ], ) @@ -132,7 +107,6 @@ def test_is_str_legacy(value: ast.AST, expected: bool) -> None: assert ast_utils.is_str(value) is expected -@test_utils.require_at_least(8) @pytest.mark.parametrize( "value,expected", [ @@ -194,6 +168,7 @@ def test_extract_int_invalid() -> None: ast.Call( func=ast.Name(id="hypot", ctx=ast.Load()), args=[], + keywords=[], ), "hypot", ), @@ -205,13 +180,17 @@ def test_extract_int_invalid() -> None: ctx=ast.Load(), ), args=[], + keywords=[], ), "hypot", ), ( ast.Call( - func=ast.Call(func=ast.Name(id="foo", ctx=ast.Load()), args=[]), + func=ast.Call( + func=ast.Name(id="foo", ctx=ast.Load()), args=[], keywords=[] + ), args=[], + keywords=[], ), None, ), @@ -219,3 +198,40 @@ def test_extract_int_invalid() -> None: ) def test_extract_function_name_or_none(value: ast.Call, expected: str | None) -> None: assert ast_utils.extract_function_name_or_none(value) == expected + + +def test_create_function_def() -> None: + expected_args = ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="x")], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + ) + + kwargs = { + "name": "test_func", + "args": expected_args, + "body": [ast.Return(value=ast.Name(id="x", ctx=ast.Load()))], + "decorator_list": [], + "returns": None, + "type_comment": None, + "lineno": 1, + "col_offset": 0, + "end_lineno": 2, + "end_col_offset": 0, + } + if sys.version_info.minor >= 12: + kwargs["type_params"] = [] + + func_def = ast_utils.create_function_def(**kwargs) + assert isinstance(func_def, ast.FunctionDef) + assert func_def.name == "test_func" + + assert func_def.args.posonlyargs == expected_args.posonlyargs + assert func_def.args.args == expected_args.args + assert func_def.args.kwonlyargs == expected_args.kwonlyargs + assert func_def.args.kw_defaults == expected_args.kw_defaults + assert func_def.args.defaults == expected_args.defaults diff --git a/src/latexify/codegen/expression_codegen_test.py b/src/latexify/codegen/expression_codegen_test.py index e869777..abfc82f 100644 --- a/src/latexify/codegen/expression_codegen_test.py +++ b/src/latexify/codegen/expression_codegen_test.py @@ -6,7 +6,7 @@ import pytest -from latexify import ast_utils, exceptions, test_utils +from latexify import ast_utils, exceptions from latexify.codegen import expression_codegen @@ -792,7 +792,6 @@ def test_visit_boolop(code: str, latex: str) -> None: assert expression_codegen.ExpressionCodegen().visit(tree) == latex -@test_utils.require_at_most(7) @pytest.mark.parametrize( "code,cls,latex", [ @@ -817,7 +816,6 @@ def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> No assert expression_codegen.ExpressionCodegen().visit(tree) == latex -@test_utils.require_at_least(8) @pytest.mark.parametrize( "code,latex", [ diff --git a/src/latexify/codegen/expression_rules_test.py b/src/latexify/codegen/expression_rules_test.py index d2c7dfe..394e72d 100644 --- a/src/latexify/codegen/expression_rules_test.py +++ b/src/latexify/codegen/expression_rules_test.py @@ -12,13 +12,41 @@ @pytest.mark.parametrize( "node,precedence", [ - (ast.Call(), expression_rules._CALL_PRECEDENCE), - (ast.BinOp(op=ast.Add()), expression_rules._PRECEDENCES[ast.Add]), - (ast.UnaryOp(op=ast.UAdd()), expression_rules._PRECEDENCES[ast.UAdd]), - (ast.BoolOp(op=ast.And()), expression_rules._PRECEDENCES[ast.And]), - (ast.Compare(ops=[ast.Eq()]), expression_rules._PRECEDENCES[ast.Eq]), - (ast.Name(), expression_rules._INF_PRECEDENCE), - (ast.Attribute(), expression_rules._INF_PRECEDENCE), + ( + ast.Call(func=ast.Name(id="func", ctx=ast.Load()), args=[], keywords=[]), + expression_rules._CALL_PRECEDENCE, + ), + ( + ast.BinOp( + left=ast.Name(id="left", ctx=ast.Load()), + op=ast.Add(), + right=ast.Name(id="right", ctx=ast.Load()), + ), + expression_rules._PRECEDENCES[ast.Add], + ), + ( + ast.UnaryOp(op=ast.UAdd(), operand=ast.Name(id="operand", ctx=ast.Load())), + expression_rules._PRECEDENCES[ast.UAdd], + ), + ( + ast.BoolOp(op=ast.And(), values=[ast.Name(id="value", ctx=ast.Load())]), + expression_rules._PRECEDENCES[ast.And], + ), + ( + ast.Compare( + left=ast.Name(id="left", ctx=ast.Load()), + ops=[ast.Eq()], + comparators=[ast.Name(id="right", ctx=ast.Load())], + ), + expression_rules._PRECEDENCES[ast.Eq], + ), + (ast.Name(id="name", ctx=ast.Load()), expression_rules._INF_PRECEDENCE), + ( + ast.Attribute( + value=ast.Name(id="value", ctx=ast.Load()), attr="attr", ctx=ast.Load() + ), + expression_rules._INF_PRECEDENCE, + ), ], ) def test_get_precedence(node: ast.AST, precedence: int) -> None: diff --git a/src/latexify/parser_test.py b/src/latexify/parser_test.py index 5a3afcb..a047503 100644 --- a/src/latexify/parser_test.py +++ b/src/latexify/parser_test.py @@ -6,7 +6,7 @@ import pytest -from latexify import exceptions, parser, test_utils +from latexify import ast_utils, exceptions, parser, test_utils def test_parse_function_with_posonlyargs() -> None: @@ -15,14 +15,29 @@ def f(x): expected = ast.Module( body=[ - ast.FunctionDef( + ast_utils.create_function_def( name="f", args=ast.arguments( + posonlyargs=[], args=[ast.arg(arg="x")], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], ), body=[ast.Return(value=ast.Name(id="x", ctx=ast.Load()))], + decorator_list=[], + returns=None, + type_comment=None, + type_params=[], + lineno=1, + col_offset=0, + end_lineno=2, + end_col_offset=0, ) ], + type_ignores=[], ) obtained = parser.parse_function(f) diff --git a/src/latexify/test_utils.py b/src/latexify/test_utils.py index e143a64..8f82594 100644 --- a/src/latexify/test_utils.py +++ b/src/latexify/test_utils.py @@ -71,11 +71,15 @@ def ast_equal(observed: ast.AST, expected: ast.AST) -> bool: Returns: True if observed and expected represent the same AST, False otherwise. """ + ignore_keys = {"lineno", "col_offset", "end_lineno", "end_col_offset", "kind"} + if sys.version_info.minor <= 12: + ignore_keys.add("type_params") + try: assert type(observed) is type(expected) for k, ve in vars(expected).items(): - if k in {"col_offset", "end_col_offset", "end_lineno", "kind", "lineno"}: + if k in ignore_keys: continue vo = getattr(observed, k) # May cause AttributeError. @@ -94,7 +98,7 @@ def ast_equal(observed: ast.AST, expected: ast.AST) -> bool: assert vo == ve except (AssertionError, AttributeError): - return False + raise # raise to debug easier. return True @@ -109,19 +113,10 @@ def assert_ast_equal(observed: ast.AST, expected: ast.AST) -> None: Raises: AssertionError: observed and expected represent different ASTs. """ - if sys.version_info.minor >= 9: - assert ast_equal( - observed, expected - ), f"""\ -AST does not match. -observed={ast.dump(observed, indent=4)} -expected={ast.dump(expected, indent=4)} -""" - else: - assert ast_equal( - observed, expected - ), f"""\ -AST does not match. -observed={ast.dump(observed)} -expected={ast.dump(expected)} -""" + assert ast_equal( + observed, expected + ), f"""\ + AST does not match. + observed={ast.dump(observed, indent=4)} + expected={ast.dump(expected, indent=4)} + """ diff --git a/src/latexify/transformers/assignment_reducer.py b/src/latexify/transformers/assignment_reducer.py index ef35af2..8094839 100644 --- a/src/latexify/transformers/assignment_reducer.py +++ b/src/latexify/transformers/assignment_reducer.py @@ -5,7 +5,7 @@ import ast from typing import Any -from latexify import exceptions +from latexify import ast_utils, exceptions class AssignmentReducer(ast.NodeTransformer): @@ -66,12 +66,14 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # Pop stack self._assignments = parent_assignments - - return ast.FunctionDef( + type_params = getattr(node, "type_params", []) + return ast_utils.create_function_def( name=node.name, args=node.args, body=[return_transformed], decorator_list=node.decorator_list, + returns=node.returns, + type_params=type_params, ) def visit_Name(self, node: ast.Name) -> Any: diff --git a/src/latexify/transformers/assignment_reducer_test.py b/src/latexify/transformers/assignment_reducer_test.py index 437fc7d..816de17 100644 --- a/src/latexify/transformers/assignment_reducer_test.py +++ b/src/latexify/transformers/assignment_reducer_test.py @@ -19,18 +19,21 @@ def _make_ast(body: list[ast.stmt]) -> ast.Module: """ return ast.Module( body=[ - ast.FunctionDef( + ast_utils.create_function_def( name="f", args=ast.arguments( args=[ast.arg(arg="x")], kwonlyargs=[], kw_defaults=[], defaults=[], + posonlyargs=[], ), body=body, decorator_list=[], + type_params=[], ) ], + type_ignores=[], ) diff --git a/src/latexify/transformers/docstring_remover_test.py b/src/latexify/transformers/docstring_remover_test.py index d65c524..7a41146 100644 --- a/src/latexify/transformers/docstring_remover_test.py +++ b/src/latexify/transformers/docstring_remover_test.py @@ -17,16 +17,31 @@ def f(): tree = parser.parse_function(f).body[0] assert isinstance(tree, ast.FunctionDef) - expected = ast.FunctionDef( + expected = ast_utils.create_function_def( name="f", body=[ ast.Assign( targets=[ast.Name(id="x", ctx=ast.Store())], value=ast_utils.make_constant(42), ), - ast.Expr(value=ast.Call(func=ast.Name(id="f", ctx=ast.Load()))), + ast.Expr( + value=ast.Call( + func=ast.Name(id="f", ctx=ast.Load()), args=[], keywords=[] + ) + ), ast.Return(value=ast.Name(id="x", ctx=ast.Load())), ], + args=ast.arguments( + posonlyargs=[], + args=[], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + ), + decorator_list=[], + type_params=[], ) transformed = DocstringRemover().visit(tree) test_utils.assert_ast_equal(transformed, expected) diff --git a/src/latexify/transformers/function_expander.py b/src/latexify/transformers/function_expander.py index 529de43..806345c 100644 --- a/src/latexify/transformers/function_expander.py +++ b/src/latexify/transformers/function_expander.py @@ -60,6 +60,7 @@ def _atan2_expander(function_expander: FunctionExpander, node: ast.Call) -> ast. right=function_expander.visit(node.args[1]), ) ], + keywords=[], ) @@ -88,6 +89,7 @@ def _expm1_expander(function_expander: FunctionExpander, node: ast.Call) -> ast. ast.Call( func=ast.Name(id="exp", ctx=ast.Load()), args=[node.args[0]], + keywords=[], ) ), op=ast.Sub(), @@ -114,6 +116,7 @@ def _hypot_expander(function_expander: FunctionExpander, node: ast.Call) -> ast. return ast.Call( func=ast.Name(id="sqrt", ctx=ast.Load()), args=[args_reduced], + keywords=[], ) @@ -128,6 +131,7 @@ def _log1p_expander(function_expander: FunctionExpander, node: ast.Call) -> ast. right=function_expander.visit(node.args[0]), ) ], + keywords=[], ) diff --git a/src/latexify/transformers/function_expander_test.py b/src/latexify/transformers/function_expander_test.py index a193245..a1d5d48 100644 --- a/src/latexify/transformers/function_expander_test.py +++ b/src/latexify/transformers/function_expander_test.py @@ -27,6 +27,7 @@ def test_exp() -> None: tree = ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.BinOp( left=ast_utils.make_name("e"), @@ -41,10 +42,12 @@ def test_exp_unchanged() -> None: tree = ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], + keywords=[], ) transformed = FunctionExpander(set()).visit(tree) test_utils.assert_ast_equal(transformed, expected) @@ -54,6 +57,7 @@ def test_exp_with_attribute() -> None: tree = ast.Call( func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.BinOp( left=ast_utils.make_name("e"), @@ -68,10 +72,12 @@ def test_exp_unchanged_with_attribute() -> None: tree = ast.Call( func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.Call( func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), args=[ast_utils.make_name("x")], + keywords=[], ) transformed = FunctionExpander(set()).visit(tree) test_utils.assert_ast_equal(transformed, expected) @@ -84,8 +90,10 @@ def test_exp_nested1() -> None: ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], + keywords=[], ) ], + keywords=[], ) expected = ast.BinOp( left=ast_utils.make_name("e"), @@ -107,8 +115,10 @@ def test_exp_nested2() -> None: ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], + keywords=[], ) ], + keywords=[], ) expected = ast.Call( func=ast_utils.make_name("f"), @@ -119,6 +129,7 @@ def test_exp_nested2() -> None: right=ast_utils.make_name("x"), ) ], + keywords=[], ) transformed = FunctionExpander({"exp"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) @@ -128,6 +139,7 @@ def test_atan2() -> None: tree = ast.Call( func=ast_utils.make_name("atan2"), args=[ast_utils.make_name("y"), ast_utils.make_name("x")], + keywords=[], ) expected = ast.Call( func=ast_utils.make_name("atan"), @@ -138,6 +150,7 @@ def test_atan2() -> None: right=ast_utils.make_name("x"), ) ], + keywords=[], ) transformed = FunctionExpander({"atan2"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) @@ -147,6 +160,7 @@ def test_exp2() -> None: tree = ast.Call( func=ast_utils.make_name("exp2"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.BinOp( left=ast_utils.make_constant(2), @@ -161,11 +175,13 @@ def test_expm1() -> None: tree = ast.Call( func=ast_utils.make_name("expm1"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.BinOp( left=ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], + keywords=[], ), op=ast.Sub(), right=ast_utils.make_constant(1), @@ -178,6 +194,7 @@ def test_hypot() -> None: tree = ast.Call( func=ast_utils.make_name("hypot"), args=[ast_utils.make_name("x"), ast_utils.make_name("y")], + keywords=[], ) expected = ast.Call( func=ast_utils.make_name("sqrt"), @@ -196,13 +213,14 @@ def test_hypot() -> None: ), ) ], + keywords=[], ) transformed = FunctionExpander({"hypot"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) def test_hypot_no_args() -> None: - tree = ast.Call(func=ast_utils.make_name("hypot"), args=[]) + tree = ast.Call(func=ast_utils.make_name("hypot"), args=[], keywords=[]) expected = ast_utils.make_constant(0) transformed = FunctionExpander({"hypot"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) @@ -212,6 +230,7 @@ def test_log1p() -> None: tree = ast.Call( func=ast_utils.make_name("log1p"), args=[ast_utils.make_name("x")], + keywords=[], ) expected = ast.Call( func=ast_utils.make_name("log"), @@ -222,6 +241,7 @@ def test_log1p() -> None: right=ast_utils.make_name("x"), ) ], + keywords=[], ) transformed = FunctionExpander({"log1p"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) @@ -231,6 +251,7 @@ def test_pow() -> None: tree = ast.Call( func=ast_utils.make_name("pow"), args=[ast_utils.make_name("x"), ast_utils.make_name("y")], + keywords=[], ) expected = ast.BinOp( left=ast_utils.make_name("x"), diff --git a/src/latexify/transformers/identifier_replacer.py b/src/latexify/transformers/identifier_replacer.py index 2c3320c..aa0296e 100644 --- a/src/latexify/transformers/identifier_replacer.py +++ b/src/latexify/transformers/identifier_replacer.py @@ -4,9 +4,10 @@ import ast import keyword -import sys from typing import cast +from latexify import ast_utils + class IdentifierReplacer(ast.NodeTransformer): """NodeTransformer to replace identifier names. @@ -49,27 +50,23 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: """Visit a FunctionDef node.""" visited = cast(ast.FunctionDef, super().generic_visit(node)) - if sys.version_info.minor < 8: - args = ast.arguments( - args=self._replace_args(visited.args.args), - kwonlyargs=self._replace_args(visited.args.kwonlyargs), - kw_defaults=visited.args.kw_defaults, - defaults=visited.args.defaults, - ) - else: - args = ast.arguments( - posonlyargs=self._replace_args(visited.args.posonlyargs), # from 3.8 - args=self._replace_args(visited.args.args), - kwonlyargs=self._replace_args(visited.args.kwonlyargs), - kw_defaults=visited.args.kw_defaults, - defaults=visited.args.defaults, - ) - - return ast.FunctionDef( + args = ast.arguments( + posonlyargs=self._replace_args(visited.args.posonlyargs), + args=self._replace_args(visited.args.args), + vararg=visited.args.vararg, + kwonlyargs=self._replace_args(visited.args.kwonlyargs), + kw_defaults=visited.args.kw_defaults, + kwarg=visited.args.kwarg, + defaults=visited.args.defaults, + ) + type_params = getattr(visited, "type_params", []) + return ast_utils.create_function_def( name=self._mapping.get(visited.name, visited.name), args=args, body=visited.body, decorator_list=visited.decorator_list, + returns=visited.returns, + type_params=type_params, ) def visit_Name(self, node: ast.Name) -> ast.Name: diff --git a/src/latexify/transformers/identifier_replacer_test.py b/src/latexify/transformers/identifier_replacer_test.py index def9052..16a5ac4 100644 --- a/src/latexify/transformers/identifier_replacer_test.py +++ b/src/latexify/transformers/identifier_replacer_test.py @@ -6,7 +6,7 @@ import pytest -from latexify import test_utils +from latexify import ast_utils, test_utils from latexify.transformers.identifier_replacer import IdentifierReplacer @@ -35,54 +35,12 @@ def test_name_not_replaced() -> None: test_utils.assert_ast_equal(transformed, expected) -@test_utils.require_at_most(7) -def test_functiondef() -> None: - # Subtree of: - # @d - # def f(y=b, *, z=c): - # pass - source = ast.FunctionDef( - name="f", - args=ast.arguments( - args=[ast.arg(arg="y")], - kwonlyargs=[ast.arg(arg="z")], - kw_defaults=[ast.Name(id="c", ctx=ast.Load())], - defaults=[ - ast.Name(id="a", ctx=ast.Load()), - ast.Name(id="b", ctx=ast.Load()), - ], - ), - body=[ast.Pass()], - decorator_list=[ast.Name(id="d", ctx=ast.Load())], - ) - - expected = ast.FunctionDef( - name="F", - args=ast.arguments( - args=[ast.arg(arg="Y")], - kwonlyargs=[ast.arg(arg="Z")], - kw_defaults=[ast.Name(id="C", ctx=ast.Load())], - defaults=[ - ast.Name(id="A", ctx=ast.Load()), - ast.Name(id="B", ctx=ast.Load()), - ], - ), - body=[ast.Pass()], - decorator_list=[ast.Name(id="D", ctx=ast.Load())], - ) - - mapping = {x: x.upper() for x in "abcdfyz"} - transformed = IdentifierReplacer(mapping).visit(source) - test_utils.assert_ast_equal(transformed, expected) - - -@test_utils.require_at_least(8) def test_functiondef_with_posonlyargs() -> None: # Subtree of: # @d # def f(x=a, /, y=b, *, z=c): # pass - source = ast.FunctionDef( + source = ast_utils.create_function_def( name="f", args=ast.arguments( posonlyargs=[ast.arg(arg="x")], @@ -96,9 +54,12 @@ def test_functiondef_with_posonlyargs() -> None: ), body=[ast.Pass()], decorator_list=[ast.Name(id="d", ctx=ast.Load())], + returns=None, + type_comment=None, + type_params=[], ) - expected = ast.FunctionDef( + expected = ast_utils.create_function_def( name="F", args=ast.arguments( posonlyargs=[ast.arg(arg="X")], @@ -112,6 +73,9 @@ def test_functiondef_with_posonlyargs() -> None: ), body=[ast.Pass()], decorator_list=[ast.Name(id="D", ctx=ast.Load())], + returns=None, + type_comment=None, + type_params=[], ) mapping = {x: x.upper() for x in "abcdfxyz"}