Skip to content

Commit

Permalink
feat: Extended StringTemplate
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Nov 16, 2023
1 parent 09c3279 commit 6529e6e
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 65 deletions.
184 changes: 133 additions & 51 deletions dol/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,10 @@ def _field_names(string_template):
)


def identity(x):
    """Return ``x`` unchanged.

    Serves as a no-op converter, e.g. as the default "from string" function
    for fields that need no casting.
    """
    return x


Codec = namedtuple('Codec', 'encoder decoder')
FieldTypeNames = Literal['str', 'dict', 'tuple', 'namedtuple', 'simple_str']

Expand All @@ -1176,7 +1180,7 @@ class StringTemplate:
template: A template string with fields to be extracted or filled in.
field_patterns: A dictionary of field names and their regex patterns.
simple_str_sep: A separator string for simple strings (i.e. strings without
fields). If `None`, the template string will be used as the separator.
fields).
Examples:
>>> st = StringTemplate(
Expand Down Expand Up @@ -1209,57 +1213,103 @@ def __init__(
template: str,
*,
field_patterns: dict = None,
simple_str_sep: str = None,
to_str_funcs: dict = None,
from_str_funcs: dict = None,
simple_str_sep: str = ',',
namedtuple_type_name: str = 'NamedTuple',
dflt_pattern: str = '.*',
):
self._original_template = template
self.field_patterns = field_patterns or {}
self.simple_str_sep = simple_str_sep
self.namedtuple_type_name = namedtuple_type_name
self.dflt_pattern = dflt_pattern
self.template = self._normalize_template(template)
self.field_names = _field_names(self.template)

# Information derived from the template itself. The results are bound to
# underscore-suffixed names so they do NOT overwrite the caller-supplied
# ``to_str_funcs`` / ``field_patterns`` arguments (unpacking into
# ``to_str_funcs`` used to discard whatever the caller passed in).
(
    self.template,
    self.field_names,
    to_str_funcs_,
    field_patterns_,
) = self._extract_template_info(template)

# For each per-field mapping, later entries win:
#   blanket default  <  value embedded in the template  <  explicit argument
self.field_patterns = {
    **dict.fromkeys(self.field_names, self.dflt_pattern),
    **field_patterns_,
    **(field_patterns or {}),
}
self.to_str_funcs = {
    **dict.fromkeys(self.field_names, str),
    **to_str_funcs_,
    **(to_str_funcs or {}),
}
self.from_str_funcs = {
    **dict.fromkeys(self.field_names, identity),
    **(from_str_funcs or {}),
}
# Must come last: the regex is built from the resolved ``self.field_patterns``.
self.regex = self._compile_regex(self.template)

def _normalize_template(self, template):
r"""Normalizes the template so that each placeholder has a name and format_spec.
def _extract_template_info(self, template):
r"""Extracts information from the template. Namely:
- normalized_template: A template where each placeholder has a field name
(if not given, "_{index}" will be used)
>>> StringTemplate('{}.ext').template
'{_1:.*}.ext'
>>> StringTemplate('{name}.ext').template
'{name:.*}.ext'
>>> StringTemplate('{:\w+}.ext').template
'{_1:\w+}.ext'
>>> StringTemplate('{name:\w+}.ext').template
'{name:\w+}.ext'
- field_names: The tuple of field names in the order they appear in template
- to_str_funcs: A dict of field names and their corresponding to_str functions,
which will be used to convert the field values to strings when generating a
string.
- field_patterns_: A dict of field names and their corresponding regex patterns,
which will be used to extract the field values from a string.
These four values are used in the init to compute the parameters of the
instance.
>>> st = StringTemplate('{:02.0f}/{name::\w+}')
>>> st.template
'{_1}/{name}'
>>> st.field_names
('_1', 'name')
>>> st.field_patterns
{'_1': '.*', 'name': '\\w+'}
>>> st.regex.pattern
'(?P<_1>.*)/(?P<name>\\w+)'
>>> to_str_funcs = st.to_str_funcs
>>> to_str_funcs['_1'](3)
'03'
>>> to_str_funcs['name']('Alice')
'Alice'
"""

def get_format_spec(field_name, current_format_spec):
if current_format_spec == '':
return self.field_patterns.get(field_name, self.dflt_pattern)
return current_format_spec
field_names = []
field_patterns_ = {}
to_str_funcs = {}

def parse_and_transform():
    """Yield normalized ``(literal_text, field_name, '', conversion)`` items.

    While iterating over the parsed template, this also fills the enclosing
    ``field_names``, ``to_str_funcs`` and ``field_patterns_`` containers.
    A placeholder's spec is read as ``<format>:<regex>``; either part may be
    empty.
    """
    parsed = self._formatter.parse(template)
    for index, (literal_text, field_name, format_spec, conversion) in enumerate(
        parsed, 1
    ):
        if field_name == '':
            # Anonymous placeholder: name it after its position in the parse.
            field_name = f"_{index}"
        if field_name is not None:
            field_names.append(field_name)  # remember the field name
            # Split on the FIRST colon only, so that the regex part may itself
            # contain colons (e.g. a non-capturing group ``(?:...)``).
            # ``partition`` also covers the "no colon at all" case.
            to_str_func_format, _, pattern = format_spec.partition(':')
            if to_str_func_format:
                to_str_funcs[field_name] = (
                    '{' + f":{to_str_func_format}" + '}'
                ).format
            field_patterns_[field_name] = pattern or self.dflt_pattern
        # The normalized placeholder keeps its name but has an empty format_spec.
        yield literal_text, field_name, '', conversion

return string_unparse(parse_and_transform())
normalized_template = string_unparse(parse_and_transform())
return normalized_template, tuple(field_names), to_str_funcs, field_patterns_

# TODO: For now we have hard-coded the "interpret format_specs as field patterns",
# but we could make this more general and allow for other types of
# "interpretations" (e.g. "interpret format_specs as cast functions").
# Note: We removed the cast functions (that existed in legacy StrTupleDict class)
# to keep it simple, for now. The idea being that if you want to cast, you can
# do it yourself by adding egress/ingress to codecs.
def _compile_regex(self, template):
r"""Parses the template, generating regex for matching the template.
Essentially, it weaves together the literal text parts and the format_specs
Expand All @@ -1276,29 +1326,36 @@ def _compile_regex(self, template):
'(?P<_1>.*)\\.ext'
>>> StringTemplate('{name}.ext').regex.pattern
'(?P<name>.*)\\.ext'
>>> StringTemplate('{:\w+}.ext').regex.pattern
>>> StringTemplate('{::\w+}.ext').regex.pattern
'(?P<_1>\\w+)\\.ext'
>>> StringTemplate('{name:\w+}.ext').regex.pattern
>>> StringTemplate('{name::\w+}.ext').regex.pattern
'(?P<name>\\w+)\\.ext'
>>> StringTemplate('{:0.02f:\w+}.ext').regex.pattern
'(?P<_1>\\w+)\\.ext'
>>> StringTemplate('{name:0.02f:\w+}.ext').regex.pattern
'(?P<name>\\w+)\\.ext'
"""

def mk_named_capture_group(field_name, format_spec):
if field_name and format_spec:
return f"(?P<{field_name}>{format_spec})"
def mk_named_capture_group(field_name):
if field_name:
return f"(?P<{field_name}>{self.field_patterns[field_name]})"
else:
return ""

def generate_pattern_parts(template):
parts = self._formatter.parse(template)
for literal_text, field_name, format_spec, conversion in parts:
yield (
re.escape(literal_text)
+ mk_named_capture_group(field_name, format_spec)
)
for literal_text, field_name, _, _ in parts:
yield re.escape(literal_text) + mk_named_capture_group(field_name)

return re.compile(''.join(generate_pattern_parts(template)))

@staticmethod
def _assert_field_type(field_type: FieldTypeNames, name='field_type'):
    """Raise ``ValueError`` unless ``field_type`` is one of ``FieldTypeNames``.

    ``name`` only labels the offending argument in the error message.
    """
    allowed = FieldTypeNames.__args__
    if field_type in allowed:
        return
    raise ValueError(f"{name} must be one of {FieldTypeNames}. Was: {field_type}")

def codec(self, source: FieldTypeNames, target: FieldTypeNames):
"""Makes a ``(coder, decoder)`` pair for the given source and target types.
Expand All @@ -1312,10 +1369,19 @@ def codec(self, source: FieldTypeNames, target: FieldTypeNames):
>>> encoder({'name': 'Alice', 'age': '30'})
('Alice', '30')
"""
self._assert_field_type(target, 'target')
self._assert_field_type(source, 'source')
coder = getattr(self, f'{source}_to_{target}')
decoder = getattr(self, f'{target}_to_{source}')
return Codec(coder, decoder)

def filt_iter(self, field_type: FieldTypeNames):
    """Return ``dol.trans.filt_iter`` configured with this template's matcher.

    ``field_type`` is validated, then the bound ``match_<field_type>`` method
    is looked up and passed as the ``filt`` argument.
    """
    # Function-scope import, as in the original
    # (NOTE(review): presumably to avoid a circular import -- confirm).
    from dol.trans import filt_iter as _filt_iter

    self._assert_field_type(field_type, 'field_type')
    matcher = getattr(self, f'match_{field_type}')
    return _filt_iter(filt=matcher)

# @_return_none_if_none_input
def str_to_dict(self, s: str) -> dict:
"""Parses the input string and returns a dictionary of extracted values.
Expand All @@ -1332,7 +1398,7 @@ def str_to_dict(self, s: str) -> dict:
return None
match = self.regex.match(s)
if match:
return match.groupdict()
return {k: self.from_str_funcs[k](v) for k, v in match.groupdict().items()}
else:
raise ValueError(f"String '{s}' does not match the template.")

Expand All @@ -1350,6 +1416,7 @@ def dict_to_str(self, params: dict) -> str:
"""
if params is None:
return None
params = {k: self.to_str_funcs[k](v) for k, v in params.items()}
return self.template.format(**params)

# @_return_none_if_none_input
Expand Down Expand Up @@ -1400,7 +1467,7 @@ def str_to_tuple(self, s: str) -> tuple:
return self.dict_to_tuple(self.str_to_dict(s))

# @_return_none_if_none_input
def tuple_to_str(self, params: tuple) -> str:
def tuple_to_str(self, param_vals: tuple) -> str:
"""Generates a string from the tuple values based on the template.
>>> st = StringTemplate(
Expand All @@ -1410,9 +1477,9 @@ def tuple_to_str(self, params: tuple) -> str:
>>> st.tuple_to_str(('Alice', '30'))
'Alice is 30 years old.'
"""
if params is None:
if param_vals is None:
return None
return self.dict_to_str(self.tuple_to_dict(params))
return self.dict_to_str(self.tuple_to_dict(param_vals))

# @_return_none_if_none_input
def dict_to_namedtuple(
Expand Down Expand Up @@ -1450,25 +1517,22 @@ def namedtuple_to_dict(self, nt):
return dict(nt._asdict()) # TODO: Find way that doesn't involve private method

# @_return_none_if_none_input
def str_to_simple_str(self, s: str, sep: str):
def str_to_simple_str(self, s: str, sep: str = None):
"""Converts a string to a simple string (i.e. a simple character-delimited string).
>>> st = StringTemplate(
... "{name} is {age} years old.",
... field_patterns={"name": r"\w+", "age": r"\d+"}
... )
>>> st.str_to_simple_str("Alice is 30 years old.")
'Alice,30'
>>> st.str_to_simple_str("Alice is 30 years old.", '-')
'Alice-30'
"""
sep = sep or self.simple_str_sep
if s is None:
return None
elif sep is None:
if self.simple_str_sep is None:
raise ValueError(
'Need to specify a sep (at method call time), or a simple_str_sep '
'(at instiantiation time) to use str_to_simple_str'
)
return sep.join(self.str_to_tuple(s))
return sep.join(self.to_str_funcs[k](v) for k, v in self.str_to_dict(s).items())

# @_return_none_if_none_input
def simple_str_to_str(self, ss: str, sep: str):
Expand All @@ -1484,3 +1548,21 @@ def simple_str_to_str(self, ss: str, sep: str):
if ss is None:
return None
return self.tuple_to_str(tuple(ss.split(sep)))

def match_str(self, s: str) -> bool:
    """Tell whether ``s`` matches the template's compiled regex.

    Uses ``regex.match``, so the pattern is anchored at the start of ``s``.
    """
    found = self.regex.match(s)
    return found is not None

def match_dict(self, params: dict) -> bool:
    """Tell whether the string rendered from ``params`` matches the template."""
    # Note: an alternative would be to check each value against its own entry
    # in ``self.field_patterns`` (those are pattern strings, so they would need
    # compiling first); not obviously quicker, given the full regex is
    # already compiled.
    rendered = self.dict_to_str(params)
    return self.match_str(rendered)

def match_tuple(self, param_vals: tuple) -> bool:
    """Tell whether the string rendered from ``param_vals`` matches the template."""
    rendered = self.tuple_to_str(param_vals)
    return self.match_str(rendered)

def match_namedtuple(self, params: namedtuple) -> bool:
    """Tell whether the string rendered from the namedtuple matches the template."""
    # NOTE(review): relies on a ``namedtuple_to_str`` method -- confirm that the
    # class defines it.
    rendered = self.namedtuple_to_str(params)
    return self.match_str(rendered)

def match_simple_str(self, params: str, sep: str = None) -> bool:
    """Tell whether the string rendered from a simple string matches the template.

    Args:
        params: The separator-delimited field values (e.g. ``'Alice,30'``).
        sep: The separator used in ``params``. Falls back to
            ``self.simple_str_sep`` when not given.
    """
    # ``simple_str_to_str`` requires the separator explicitly; calling it with
    # ``params`` alone raised a TypeError.
    sep = sep or self.simple_str_sep
    return self.match_str(self.simple_str_to_str(params, sep))
53 changes: 40 additions & 13 deletions dol/tests/test_paths.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,51 @@
"""Tests for paths.py"""

from dol import StringTemplate

def test_string_template():

def test_string_template_template_construction():
    """The normalized template keeps field names only (auto-named ``_<index>``
    when anonymous); format and regex parts of a placeholder are stripped."""
    # Raw strings: ``\w`` is not a valid escape sequence in a normal literal.
    assert StringTemplate('{}.ext').template == '{_1}.ext'
    assert StringTemplate('{name}.ext').template == '{name}.ext'
    assert StringTemplate(r'{::\w+}.ext').template == '{_1}.ext'
    assert StringTemplate(r'{name::\w+}.ext').template == '{name}.ext'
    assert StringTemplate('{name:0.02f}.ext').template == '{name}.ext'
    assert StringTemplate(r'{name:0.02f:\w+}.ext').template == '{name}.ext'
    assert StringTemplate(r'{:0.02f:\w+}.ext').template == '{_1}.ext'


def test_string_template_regex():
    """The compiled regex has one named group per field, using the field's
    pattern (``.*`` when the placeholder gives none)."""
    # Raw strings: ``\w`` is not a valid escape sequence in a normal literal.
    assert StringTemplate('{}.ext').regex.pattern == '(?P<_1>.*)\\.ext'
    assert StringTemplate('{name}.ext').regex.pattern == '(?P<name>.*)\\.ext'
    assert StringTemplate(r'{::\w+}.ext').regex.pattern == '(?P<_1>\\w+)\\.ext'
    assert StringTemplate(r'{name::\w+}.ext').regex.pattern == '(?P<name>\\w+)\\.ext'
    assert StringTemplate(r'{:0.02f:\w+}.ext').regex.pattern == '(?P<_1>\\w+)\\.ext'
    assert (
        StringTemplate(r'{name:0.02f:\w+}.ext').regex.pattern
        == '(?P<name>\\w+)\\.ext'
    )


def test_string_template_simple():
from dol.paths import StringTemplate
from collections import namedtuple

st = StringTemplate(
'{name} is {age} years old.', field_patterns={'name': r'\w+', 'age': r'\d+'}
'root/{name}/v_{version}.json',
field_patterns={'name': r'\w+', 'version': r'\d+'},
from_str_funcs={'version': int},
)

assert st.str_to_dict('Alice is 30 years old.') == {'name': 'Alice', 'age': '30'}
assert st.dict_to_str({'name': 'Alice', 'age': '30'}) == 'Alice is 30 years old.'
assert st.dict_to_tuple({'name': 'Alice', 'age': '30'}) == ('Alice', '30')
assert st.tuple_to_dict(('Alice', '30')) == {'name': 'Alice', 'age': '30'}
assert st.str_to_tuple('Alice is 30 years old.') == ('Alice', '30')
assert st.tuple_to_str(('Alice', '30')) == 'Alice is 30 years old.'
assert st.str_to_dict('root/Alice/v_30.json') == {'name': 'Alice', 'version': 30}
assert st.dict_to_str({'name': 'Alice', 'version': 30}) == 'root/Alice/v_30.json'
assert st.dict_to_tuple({'name': 'Alice', 'version': 30}) == ('Alice', 30)
assert st.tuple_to_dict(('Alice', 30)) == {'name': 'Alice', 'version': 30}
assert st.str_to_tuple('root/Alice/v_30.json') == ('Alice', 30)
assert st.tuple_to_str(('Alice', 30)) == 'root/Alice/v_30.json'

VersionedFile = st.dict_to_namedtuple({'name': 'Alice', 'version': 30})

Person = st.dict_to_namedtuple({'name': 'Alice', 'age': '30'})
assert Person == namedtuple('Person', ['name', 'age'])('Alice', '30')
assert st.namedtuple_to_dict(Person) == {'name': 'Alice', 'age': '30'}
from collections import namedtuple
assert VersionedFile == namedtuple('VersionedFile', ['name', 'version'])('Alice', 30)
assert st.namedtuple_to_dict(VersionedFile) == {'name': 'Alice', 'version': 30}

assert st.str_to_simple_str('Alice is 30 years old.', '-') == 'Alice-30'
assert st.simple_str_to_str('Alice-30', '-') == 'Alice is 30 years old.'
assert st.str_to_simple_str('root/Alice/v_30.json') == 'Alice,30'
assert st.str_to_simple_str('root/Alice/v_30.json', '-') == 'Alice-30'
assert st.simple_str_to_str('Alice-30', '-') == 'root/Alice/v_30.json'
2 changes: 1 addition & 1 deletion dol/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __missing__(self, k):
user_value = value_preprocessor(user_value)
self[k] = user_value
else:
super().__missing__(k)
super(type(self), self).__missing__(k)

store.__missing__ = __missing__
return store
Expand Down

0 comments on commit 6529e6e

Please sign in to comment.