diff --git a/dol/paths.py b/dol/paths.py index 3c29d3bf..55571bfd 100644 --- a/dol/paths.py +++ b/dol/paths.py @@ -1152,6 +1152,10 @@ def _field_names(string_template): ) +def identity(x): + return x + + Codec = namedtuple('Codec', 'encoder decoder') FieldTypeNames = Literal['str', 'dict', 'tuple', 'namedtuple', 'simple_str'] @@ -1176,7 +1180,7 @@ class StringTemplate: template: A template string with fields to be extracted or filled in. field_patterns: A dictionary of field names and their regex patterns. simple_str_sep: A separator string for simple strings (i.e. strings without - fields). If `None`, the template string will be used as the separator. + fields). Examples: >>> st = StringTemplate( @@ -1209,57 +1213,103 @@ def __init__( template: str, *, field_patterns: dict = None, - simple_str_sep: str = None, + to_str_funcs: dict = None, + from_str_funcs: dict = None, + simple_str_sep: str = ',', namedtuple_type_name: str = 'NamedTuple', dflt_pattern: str = '.*', ): self._original_template = template - self.field_patterns = field_patterns or {} self.simple_str_sep = simple_str_sep self.namedtuple_type_name = namedtuple_type_name self.dflt_pattern = dflt_pattern - self.template = self._normalize_template(template) - self.field_names = _field_names(self.template) + + ( + self.template, + self.field_names, + to_str_funcs, + field_patterns_, + ) = self._extract_template_info(template) + + self.field_patterns = dict( + {field: self.dflt_pattern for field in self.field_names}, + **dict(field_patterns_, **(field_patterns or {})), + ) + self.to_str_funcs = dict( + {field: str for field in self.field_names}, + **dict(to_str_funcs, **(to_str_funcs or {})), + ) + self.from_str_funcs = dict( + {field: identity for field in self.field_names}, **(from_str_funcs or {}) + ) self.regex = self._compile_regex(self.template) - def _normalize_template(self, template): - r"""Normalizes the template so that each placeholder has a name and format_spec. + def _extract_template_info(self, template): + r"""Extracts information from the template. Namely: + + - normalized_template: A template where each placeholder has a field name + (if not given, "_{index}" will be used) - >>> StringTemplate('{}.ext').template - '{_1:.*}.ext' - >>> StringTemplate('{name}.ext').template - '{name:.*}.ext' - >>> StringTemplate('{:\w+}.ext').template - '{_1:\w+}.ext' - >>> StringTemplate('{name:\w+}.ext').template - '{name:\w+}.ext' + - field_names: The tuple of field names in the order they appear in template + + - to_str_funcs: A dict of field names and their corresponding to_str functions, + which will be used to convert the field values to strings when generating a + string. + + - field_patterns_: A dict of field names and their corresponding regex patterns, + which will be used to extract the field values from a string. + + These four values are used in the init to compute the parameters of the + instance. + + >>> st = StringTemplate('{:02.0f}/{name::\w+}') + >>> st.template + '{_1}/{name}' + >>> st.field_names + ('_1', 'name') + >>> st.field_patterns + {'_1': '.*', 'name': '\\w+'} + >>> st.regex.pattern + '(?P<_1>.*)/(?P\\w+)' + >>> to_str_funcs = st.to_str_funcs + >>> to_str_funcs['_1'](3) + '03' + >>> to_str_funcs['name']('Alice') + 'Alice' """ - def get_format_spec(field_name, current_format_spec): - if current_format_spec == '': - return self.field_patterns.get(field_name, self.dflt_pattern) - return current_format_spec + field_names = [] + field_patterns_ = {} + to_str_funcs = {} def parse_and_transform(): for index, (literal_text, field_name, format_spec, conversion) in enumerate( self._formatter.parse(template), 1 ): + field_name = f"_{index}" if field_name == '' else field_name + if field_name is not None: + field_names.append(field_name) # remember the field name + # extract format and pattern information: + if ':' not in format_spec: + format_spec += ':' + to_str_func_format, pattern = format_spec.split(':') + if to_str_func_format: + to_str_funcs[field_name] = ( + '{' + f":{to_str_func_format}" + '}' + ).format + field_patterns_[field_name] = pattern or self.dflt_pattern + # At this point you should have a valid field_name and empty format_spec yield ( literal_text, - f"_{index}" if field_name == '' else field_name, - get_format_spec(field_name, format_spec), + field_name, + '', conversion, ) - return string_unparse(parse_and_transform()) + normalized_template = string_unparse(parse_and_transform()) + return normalized_template, tuple(field_names), to_str_funcs, field_patterns_ - # TODO: For now we harded coded the "interpret format_specs as field patterns", - # but we could make this more general and allow for other types of - # "interpretations" (e.g. "interpret format_specs as cast functions"). - # Note: We removed the cast functions (that existed in legacy StrTupleDict class) - # to keep it simple, for now. The idea being that if you want to cast, you can - # do it yourself by adding egress/ingress to codecs. def _compile_regex(self, template): r"""Parses the template, generating regex for matching the template. Essentially, it weaves together the literal text parts and the format_specs @@ -1276,29 +1326,36 @@ def _compile_regex(self, template): '(?P<_1>.*)\\.ext' >>> StringTemplate('{name}.ext').regex.pattern '(?P.*)\\.ext' - >>> StringTemplate('{:\w+}.ext').regex.pattern + >>> StringTemplate('{::\w+}.ext').regex.pattern '(?P<_1>\\w+)\\.ext' - >>> StringTemplate('{name:\w+}.ext').regex.pattern + >>> StringTemplate('{name::\w+}.ext').regex.pattern + '(?P\\w+)\\.ext' + >>> StringTemplate('{:0.02f:\w+}.ext').regex.pattern + '(?P<_1>\\w+)\\.ext' + >>> StringTemplate('{name:0.02f:\w+}.ext').regex.pattern '(?P\\w+)\\.ext' - """ - def mk_named_capture_group(field_name, format_spec): - if field_name and format_spec: - return f"(?P<{field_name}>{format_spec})" + def mk_named_capture_group(field_name): + if field_name: + return f"(?P<{field_name}>{self.field_patterns[field_name]})" else: return "" def generate_pattern_parts(template): parts = self._formatter.parse(template) - for literal_text, field_name, format_spec, conversion in parts: - yield ( - re.escape(literal_text) - + mk_named_capture_group(field_name, format_spec) - ) + for literal_text, field_name, _, _ in parts: + yield re.escape(literal_text) + mk_named_capture_group(field_name) return re.compile(''.join(generate_pattern_parts(template))) + @staticmethod + def _assert_field_type(field_type: FieldTypeNames, name='field_type'): + if field_type not in FieldTypeNames.__args__: + raise ValueError( + f"{name} must be one of {FieldTypeNames}. Was: {field_type}" + ) + def codec(self, source: FieldTypeNames, target: FieldTypeNames): """Makes a ``(coder, decoder)`` pair for the given source and target types. @@ -1312,10 +1369,19 @@ def codec(self, source: FieldTypeNames, target: FieldTypeNames): >>> encoder({'name': 'Alice', 'age': '30'}) ('Alice', '30') """ + self._assert_field_type(target, 'target') + self._assert_field_type(source, 'source') coder = getattr(self, f'{source}_to_{target}') decoder = getattr(self, f'{target}_to_{source}') return Codec(coder, decoder) + def filt_iter(self, field_type: FieldTypeNames): + from dol.trans import filt_iter + + self._assert_field_type(field_type, 'field_type') + filt_func = getattr(self, f'match_{field_type}') + return filt_iter(filt=filt_func) + # @_return_none_if_none_input def str_to_dict(self, s: str) -> dict: """Parses the input string and returns a dictionary of extracted values. @@ -1332,7 +1398,7 @@ def str_to_dict(self, s: str) -> dict: return None match = self.regex.match(s) if match: - return match.groupdict() + return {k: self.from_str_funcs[k](v) for k, v in match.groupdict().items()} else: raise ValueError(f"String '{s}' does not match the template.") @@ -1350,6 +1416,7 @@ def dict_to_str(self, params: dict) -> str: """ if params is None: return None + params = {k: self.to_str_funcs[k](v) for k, v in params.items()} return self.template.format(**params) # @_return_none_if_none_input @@ -1400,7 +1467,7 @@ def str_to_tuple(self, s: str) -> tuple: return self.dict_to_tuple(self.str_to_dict(s)) # @_return_none_if_none_input - def tuple_to_str(self, params: tuple) -> str: + def tuple_to_str(self, param_vals: tuple) -> str: """Generates a string from the tuple values based on the template. >>> st = StringTemplate( @@ -1410,9 +1477,9 @@ def tuple_to_str(self, params: tuple) -> str: >>> st.tuple_to_str(('Alice', '30')) 'Alice is 30 years old.' """ - if params is None: + if param_vals is None: return None - return self.dict_to_str(self.tuple_to_dict(params)) + return self.dict_to_str(self.tuple_to_dict(param_vals)) # @_return_none_if_none_input def dict_to_namedtuple( @@ -1450,25 +1517,22 @@ def namedtuple_to_dict(self, nt): return dict(nt._asdict()) # TODO: Find way that doesn't involve private method # @_return_none_if_none_input - def str_to_simple_str(self, s: str, sep: str): + def str_to_simple_str(self, s: str, sep: str = None): """Converts a string to a simple string (i.e. a simple character-delimited string). >>> st = StringTemplate( ... "{name} is {age} years old.", ... field_patterns={"name": r"\w+", "age": r"\d+"} ... ) + >>> st.str_to_simple_str("Alice is 30 years old.") + 'Alice,30' >>> st.str_to_simple_str("Alice is 30 years old.", '-') 'Alice-30' """ + sep = sep or self.simple_str_sep if s is None: return None - elif sep is None: - if self.simple_str_sep is None: - raise ValueError( - 'Need to specify a sep (at method call time), or a simple_str_sep ' - '(at instiantiation time) to use str_to_simple_str' - ) - return sep.join(self.str_to_tuple(s)) + return sep.join(self.to_str_funcs[k](v) for k, v in self.str_to_dict(s).items()) # @_return_none_if_none_input def simple_str_to_str(self, ss: str, sep: str): @@ -1484,3 +1548,21 @@ def simple_str_to_str(self, ss: str, sep: str): if ss is None: return None return self.tuple_to_str(tuple(ss.split(sep))) + + def match_str(self, s: str) -> bool: + return self.regex.match(s) is not None + + def match_dict(self, params: dict) -> bool: + return self.match_str(self.dict_to_str(params)) + # Note: Could do: + # return all(self.field_patterns[k].match(v) for k, v in params.items()) + # but not sure that's even quicker (given regex is compiled) + + def match_tuple(self, param_vals: tuple) -> bool: + return self.match_str(self.tuple_to_str(param_vals)) + + def match_namedtuple(self, params: namedtuple) -> bool: + return self.match_str(self.namedtuple_to_str(params)) + + def match_simple_str(self, params: str) -> bool: + return self.match_str(self.simple_str_to_str(params)) diff --git a/dol/tests/test_paths.py b/dol/tests/test_paths.py index c31de821..3381a64d 100644 --- a/dol/tests/test_paths.py +++ b/dol/tests/test_paths.py @@ -1,24 +1,51 @@ """Tests for paths.py""" +from dol import StringTemplate -def test_string_template(): + +def test_string_template_template_construction(): + assert StringTemplate('{}.ext').template == '{_1}.ext' + assert StringTemplate('{name}.ext').template == '{name}.ext' + assert StringTemplate('{::\w+}.ext').template == '{_1}.ext' + assert StringTemplate('{name::\w+}.ext').template == '{name}.ext' + assert StringTemplate('{name::\w+}.ext').template == '{name}.ext' + assert StringTemplate('{name:0.02f}.ext').template == '{name}.ext' + assert StringTemplate('{name:0.02f:\w+}.ext').template == '{name}.ext' + assert StringTemplate('{:0.02f:\w+}.ext').template == '{_1}.ext' + + +def test_string_template_regex(): + assert StringTemplate('{}.ext').regex.pattern == '(?P<_1>.*)\\.ext' + assert StringTemplate('{name}.ext').regex.pattern == '(?P.*)\\.ext' + assert StringTemplate('{::\w+}.ext').regex.pattern == '(?P<_1>\\w+)\\.ext' + assert StringTemplate('{name::\w+}.ext').regex.pattern == '(?P\\w+)\\.ext' + assert StringTemplate('{:0.02f:\w+}.ext').regex.pattern == '(?P<_1>\\w+)\\.ext' + assert StringTemplate('{name:0.02f:\w+}.ext').regex.pattern == '(?P\\w+)\\.ext' + + +def test_string_template_simple(): from dol.paths import StringTemplate from collections import namedtuple st = StringTemplate( - '{name} is {age} years old.', field_patterns={'name': r'\w+', 'age': r'\d+'} + 'root/{name}/v_{version}.json', + field_patterns={'name': r'\w+', 'version': r'\d+'}, + from_str_funcs={'version': int}, ) - assert st.str_to_dict('Alice is 30 years old.') == {'name': 'Alice', 'age': '30'} - assert st.dict_to_str({'name': 'Alice', 'age': '30'}) == 'Alice is 30 years old.' - assert st.dict_to_tuple({'name': 'Alice', 'age': '30'}) == ('Alice', '30') - assert st.tuple_to_dict(('Alice', '30')) == {'name': 'Alice', 'age': '30'} - assert st.str_to_tuple('Alice is 30 years old.') == ('Alice', '30') - assert st.tuple_to_str(('Alice', '30')) == 'Alice is 30 years old.' + assert st.str_to_dict('root/Alice/v_30.json') == {'name': 'Alice', 'version': 30} + assert st.dict_to_str({'name': 'Alice', 'version': 30}) == 'root/Alice/v_30.json' + assert st.dict_to_tuple({'name': 'Alice', 'version': 30}) == ('Alice', 30) + assert st.tuple_to_dict(('Alice', 30)) == {'name': 'Alice', 'version': 30} + assert st.str_to_tuple('root/Alice/v_30.json') == ('Alice', 30) + assert st.tuple_to_str(('Alice', 30)) == 'root/Alice/v_30.json' + + VersionedFile = st.dict_to_namedtuple({'name': 'Alice', 'version': 30}) - Person = st.dict_to_namedtuple({'name': 'Alice', 'age': '30'}) - assert Person == namedtuple('Person', ['name', 'age'])('Alice', '30') - assert st.namedtuple_to_dict(Person) == {'name': 'Alice', 'age': '30'} + from collections import namedtuple + assert VersionedFile == namedtuple('VersionedFile', ['name', 'version'])('Alice', 30) + assert st.namedtuple_to_dict(VersionedFile) == {'name': 'Alice', 'version': 30} - assert st.str_to_simple_str('Alice is 30 years old.', '-') == 'Alice-30' - assert st.simple_str_to_str('Alice-30', '-') == 'Alice is 30 years old.' + assert st.str_to_simple_str('root/Alice/v_30.json') == 'Alice,30' + assert st.str_to_simple_str('root/Alice/v_30.json', '-') == 'Alice-30' + assert st.simple_str_to_str('Alice-30', '-') == 'root/Alice/v_30.json' \ No newline at end of file diff --git a/dol/tools.py b/dol/tools.py index f3106957..71ffb5bf 100644 --- a/dol/tools.py +++ b/dol/tools.py @@ -125,7 +125,7 @@ def __missing__(self, k): user_value = value_preprocessor(user_value) self[k] = user_value else: - super().__missing__(k) + super(type(self), self).__missing__(k) store.__missing__ = __missing__ return store