Skip to content

Commit

Permalink
feat: extend StringTemplate for unnamed fields
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Nov 16, 2023
1 parent 34b4d78 commit 09c3279
Showing 1 changed file with 138 additions and 25 deletions.
163 changes: 138 additions & 25 deletions dol/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,13 @@ def _path_get(
except caught_errors as error:
if callable(on_error):
return on_error(
dict(obj=obj, path=path, result=result, k=k, error=error,)
dict(
obj=obj,
path=path,
result=result,
k=k,
error=error,
)
)
elif isinstance(on_error, str):
# use on_error as a message, raising the same error class
Expand Down Expand Up @@ -429,8 +435,8 @@ def path_filter(
>>> vals
[42, 'meaning of life']
Note: pkv_filt is first to match the order of the arguments of the
builtin filter function.
Note: pkv_filt is first to match the order of the arguments of the
builtin filter function.
"""
_leaf_yield = partial(_path_matcher_leaf_yield, pkv_filt, None)
kwargs = dict(leaf_yield=_leaf_yield, breadth_first=breadth_first)
Expand Down Expand Up @@ -652,7 +658,11 @@ def __init__(self, _prefix=''):

@store_decorator
def mk_relative_path_store(
store_cls=None, *, name=None, with_key_validation=False, prefix_attr='_prefix',
store_cls=None,
*,
name=None,
with_key_validation=False,
prefix_attr='_prefix',
):
"""
Expand Down Expand Up @@ -1097,6 +1107,51 @@ def _func(self, *args, **kwargs):
return _func


from typing import Iterable, Tuple

string_formatter = string.Formatter()


def string_unparse(parsing_result: Iterable[Tuple[str, str, str, str]]):
"""The inverse of string.Formatter.parse
Will ravel
>>> import string
>>> formatter = string.Formatter()
>>> string_unparse(formatter.parse('literal{name!c:spec}'))
'literal{name!c:spec}'
"""
reconstructed = ''
for literal_text, field_name, format_spec, conversion in parsing_result:
reconstructed += literal_text
if field_name is not None:
field = f'{{{field_name}'
if conversion:
assert (
len(conversion) == 1
), f"conversion can only be a single character: {conversion=}"
field += f'!{conversion}'
if format_spec:
field += f':{format_spec}'
field += '}'
reconstructed += field
return reconstructed


def _field_names(string_template):
"""
Returns the field names in a string template.
>>> _field_names("{name} is {age} years old.")
('name', 'age')
"""
parsing_result = string_formatter.parse(string_template)
return tuple(
field_name for _, field_name, _, _ in parsing_result if field_name is not None
)


Codec = namedtuple('Codec', 'encoder decoder')
FieldTypeNames = Literal['str', 'dict', 'tuple', 'namedtuple', 'simple_str']

Expand Down Expand Up @@ -1147,45 +1202,102 @@ class StringTemplate:
('Alice', '30')
"""

_formatter = string_formatter

def __init__(
self,
template: str,
*,
field_patterns: dict = None,
simple_str_sep: str = None,
namedtuple_type_name: str = 'NamedTuple',
dflt_pattern: str = '.*',
):
self.template = template
self._original_template = template
self.field_patterns = field_patterns or {}
self.simple_str_sep = simple_str_sep
self.namedtuple_type_name = namedtuple_type_name
self.regex = None
self._construct_regex()
self.dflt_pattern = dflt_pattern
self.template = self._normalize_template(template)
self.field_names = _field_names(self.template)
self.regex = self._compile_regex(self.template)

def _normalize_template(self, template):
r"""Normalizes the template so that each placeholder has a name and format_spec.
>>> StringTemplate('{}.ext').template
'{_1:.*}.ext'
>>> StringTemplate('{name}.ext').template
'{name:.*}.ext'
>>> StringTemplate('{:\w+}.ext').template
'{_1:\w+}.ext'
>>> StringTemplate('{name:\w+}.ext').template
'{name:\w+}.ext'
"""

def get_format_spec(field_name, current_format_spec):
if current_format_spec == '':
return self.field_patterns.get(field_name, self.dflt_pattern)
return current_format_spec

def parse_and_transform():
for index, (literal_text, field_name, format_spec, conversion) in enumerate(
self._formatter.parse(template), 1
):
yield (
literal_text,
f"_{index}" if field_name == '' else field_name,
get_format_spec(field_name, format_spec),
conversion,
)

return string_unparse(parse_and_transform())

# TODO: For now we harded coded the "interpret format_specs as field patterns",
# but we could make this more general and allow for other types of
# "interpretations" (e.g. "interpret format_specs as cast functions").
# Note: We removed the cast functions (that existed in legacy StrTupleDict class)
# to keep it simple, for now. The idea being that if you want to cast, you can
# do it yourself by adding egress/ingress to codecs.
def _construct_regex(self):
formatter = string.Formatter()
pattern = self.template
self.field_names = []
for literal_text, field_name, format_spec, conversion in formatter.parse(
self.template
):
# Check if the field_name has either a format_spec (regex) in the template
# or a matching regex in the field_patterns dictionary before adding it
# to the field_names list.
if field_name and (format_spec or field_name in self.field_patterns):
self.field_names.append(field_name)
regex = format_spec or self.field_patterns.get(field_name, '.*?')
to_replace = (
'{' + field_name + (':' + format_spec if format_spec else '') + '}'
def _compile_regex(self, template):
r"""Parses the template, generating regex for matching the template.
Essentially, it weaves together the literal text parts and the format_specs
parts, transformed into name-caputuring regex patterns.
Note that the literal text parts are regex-escaped so that they are not
interpreted as regex. For example, if the template is "{name}.txt", the
literal text part is replaced with "\\.txt", to avoid that the "." is
interpreted as a regex wildcard. This would otherwise match any character.
Instead, the escaped dot is matched literally.
See https://docs.python.org/3/library/re.html#re.escape for more information.
>>> StringTemplate('{}.ext').regex.pattern
'(?P<_1>.*)\\.ext'
>>> StringTemplate('{name}.ext').regex.pattern
'(?P<name>.*)\\.ext'
>>> StringTemplate('{:\w+}.ext').regex.pattern
'(?P<_1>\\w+)\\.ext'
>>> StringTemplate('{name:\w+}.ext').regex.pattern
'(?P<name>\\w+)\\.ext'
"""

def mk_named_capture_group(field_name, format_spec):
if field_name and format_spec:
return f"(?P<{field_name}>{format_spec})"
else:
return ""

def generate_pattern_parts(template):
parts = self._formatter.parse(template)
for literal_text, field_name, format_spec, conversion in parts:
yield (
re.escape(literal_text)
+ mk_named_capture_group(field_name, format_spec)
)
pattern = pattern.replace(to_replace, f'(?P<{field_name}>{regex})')
self.regex = re.compile(pattern)

return re.compile(''.join(generate_pattern_parts(template)))

def codec(self, source: FieldTypeNames, target: FieldTypeNames):
"""Makes a ``(coder, decoder)`` pair for the given source and target types.
Expand Down Expand Up @@ -1304,7 +1416,8 @@ def tuple_to_str(self, params: tuple) -> str:

# @_return_none_if_none_input
def dict_to_namedtuple(
self, params: dict,
self,
params: dict,
):
"""Generates a namedtuple from the dictionary values based on the template.
Expand Down

0 comments on commit 09c3279

Please sign in to comment.