Skip to content

Commit

Permalink
Allow config scopes with type annotations. (#868)
Browse files Browse the repository at this point in the history
* Allow config scopes with type annotations.

* Reorder imports.

* Fix Flake8 errors.

* Black passed.

* Fix inline definitions.
  • Loading branch information
vnmabus authored Sep 9, 2022
1 parent ff189c5 commit 940ac81
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,5 @@ com_crashlytics_export_strings.xml
# GEdit temporary files
*~

/.pytest_cache/
pip-wheel-metadata/
65 changes: 50 additions & 15 deletions sacred/config/config_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
import inspect
import io
import re
from tokenize import tokenize, TokenError, COMMENT
from copy import copy
import textwrap
import token

from copy import copy
from sacred import SETTINGS
from sacred.config.config_summary import ConfigSummary
from sacred.config.utils import dogmatize, normalize_or_die, recursive_fill_in
from sacred.config.signature import get_argspec
from sacred.utils import ConfigError
from tokenize import generate_tokens, tokenize, TokenError, COMMENT


class ConfigScope:
Expand Down Expand Up @@ -94,20 +96,53 @@ def __call__(self, fixed=None, preset=None, fallback=None):

def get_function_body(func):
func_code_lines, start_idx = inspect.getsourcelines(func)
func_code = "".join(func_code_lines)
arg = "(?:[a-zA-Z_][a-zA-Z0-9_]*)"
arguments = r"{0}(?:\s*,\s*{0})*,?".format(arg)
func_def = re.compile(
r"^[ \t]*def[ \t]*{}[ \t]*\(\s*({})?\s*\)[ \t]*:[ \t]*(?:#[^\n]*)?\n".format(
func.__name__, arguments
),
flags=re.MULTILINE,
func_code = textwrap.dedent("".join(func_code_lines))
# Lines are now dedented
func_code_lines = func_code.splitlines(True)
func_ast = ast.parse(func_code)
first_code = func_ast.body[0].body[0]
line_offset = first_code.lineno
col_offset = first_code.col_offset

# Add also previous empty / comment lines
acceptable_tokens = {
token.NEWLINE,
token.INDENT,
token.DEDENT,
token.COMMENT,
token.ENDMARKER,
}
last_token_type_acceptable = True
line_offset_fixed = line_offset
col_offset_fixed = col_offset
iterator = iter(func_code_lines)
for parsed_token in generate_tokens(lambda: next(iterator)):

token_acceptable = parsed_token.type in acceptable_tokens or (
parsed_token.type == token.NL and last_token_type_acceptable
)

# If the token ends after the start of the first code,
# we have finished
if parsed_token.end[0] > line_offset or (
parsed_token.end[0] == line_offset and parsed_token.end[1] >= col_offset
):
break

if not token_acceptable:
line_offset_fixed = parsed_token.end[0]
col_offset_fixed = parsed_token.end[1]

last_token_type_acceptable = token_acceptable

func_body = (
# First line, without first part if needed
func_code_lines[line_offset_fixed - 1][col_offset_fixed:]
# Rest of the lines
+ "".join(func_code_lines[line_offset_fixed:])
)
defs = list(re.finditer(func_def, func_code))
assert defs
line_offset = start_idx + func_code[: defs[0].end()].count("\n") - 1
func_body = func_code[defs[0].end() :]
return func_body, line_offset

return func_body, start_idx + line_offset_fixed


def is_empty_or_comment(line):
Expand Down
25 changes: 23 additions & 2 deletions tests/test_config/test_config_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ def evil_indentation_func(a,
def subfunc():
return 23

body = '''# Lets do the most evil things with indentation
body = ''' # test comment
# Lets do the most evil things with indentation
# 1
# 2
# ran
Expand All @@ -343,7 +344,8 @@ def subfunc():
return 23
'''

dedented_body = '''# Lets do the most evil things with indentation
dedented_body = '''# test comment
# Lets do the most evil things with indentation
# 1
# 2
# ran
Expand Down Expand Up @@ -382,3 +384,22 @@ def test_get_function_body():
def test_config_scope_can_deal_with_indentation_madness():
# assert_no_raise:
ConfigScope(evil_indentation_func)


def test_conf_scope_with_type_info():
@ConfigScope
def conf_scope(a: int) -> None:
answer = 2 * a

cfg = conf_scope(preset={"a": 21})
assert cfg["answer"] == 42


def test_conf_scope_in_same_line():
# fmt: off
@ConfigScope
def conf_scope(a: int) -> None: answer = 2 * a
# fmt: on

cfg = conf_scope(preset={"a": 21})
assert cfg["answer"] == 42

0 comments on commit 940ac81

Please sign in to comment.