diff --git a/dol/paths.py b/dol/paths.py index 0bd64006..3c29d3bf 100644 --- a/dol/paths.py +++ b/dol/paths.py @@ -101,7 +101,13 @@ def _path_get( except caught_errors as error: if callable(on_error): return on_error( - dict(obj=obj, path=path, result=result, k=k, error=error,) + dict( + obj=obj, + path=path, + result=result, + k=k, + error=error, + ) ) elif isinstance(on_error, str): # use on_error as a message, raising the same error class @@ -429,8 +435,8 @@ def path_filter( >>> vals [42, 'meaning of life'] - Note: pkv_filt is first to match the order of the arguments of the - builtin filter function. + Note: pkv_filt is first to match the order of the arguments of the + builtin filter function. """ _leaf_yield = partial(_path_matcher_leaf_yield, pkv_filt, None) kwargs = dict(leaf_yield=_leaf_yield, breadth_first=breadth_first) @@ -652,7 +658,11 @@ def __init__(self, _prefix=''): @store_decorator def mk_relative_path_store( - store_cls=None, *, name=None, with_key_validation=False, prefix_attr='_prefix', + store_cls=None, + *, + name=None, + with_key_validation=False, + prefix_attr='_prefix', ): """ @@ -1097,6 +1107,51 @@ def _func(self, *args, **kwargs): return _func +from typing import Iterable, Tuple + +string_formatter = string.Formatter() + + +def string_unparse(parsing_result: Iterable[Tuple[str, str, str, str]]): + """The inverse of string.Formatter.parse + + Will ravel + + >>> import string + >>> formatter = string.Formatter() + >>> string_unparse(formatter.parse('literal{name!c:spec}')) + 'literal{name!c:spec}' + """ + reconstructed = '' + for literal_text, field_name, format_spec, conversion in parsing_result: + reconstructed += literal_text + if field_name is not None: + field = f'{{{field_name}' + if conversion: + assert ( + len(conversion) == 1 + ), f"conversion can only be a single character: {conversion=}" + field += f'!{conversion}' + if format_spec: + field += f':{format_spec}' + field += '}' + reconstructed += field + return reconstructed + + 
+def _field_names(string_template): + """ + Returns the field names in a string template. + + >>> _field_names("{name} is {age} years old.") + ('name', 'age') + """ + parsing_result = string_formatter.parse(string_template) + return tuple( + field_name for _, field_name, _, _ in parsing_result if field_name is not None + ) + + Codec = namedtuple('Codec', 'encoder decoder') FieldTypeNames = Literal['str', 'dict', 'tuple', 'namedtuple', 'simple_str'] @@ -1147,6 +1202,8 @@ class StringTemplate: ('Alice', '30') """ + _formatter = string_formatter + def __init__( self, template: str, @@ -1154,13 +1211,48 @@ def __init__( field_patterns: dict = None, simple_str_sep: str = None, namedtuple_type_name: str = 'NamedTuple', + dflt_pattern: str = '.*', ): - self.template = template + self._original_template = template self.field_patterns = field_patterns or {} self.simple_str_sep = simple_str_sep self.namedtuple_type_name = namedtuple_type_name - self.regex = None - self._construct_regex() + self.dflt_pattern = dflt_pattern + self.template = self._normalize_template(template) + self.field_names = _field_names(self.template) + self.regex = self._compile_regex(self.template) + + def _normalize_template(self, template): + r"""Normalizes the template so that each placeholder has a name and format_spec. 
+ + >>> StringTemplate('{}.ext').template + '{_1:.*}.ext' + >>> StringTemplate('{name}.ext').template + '{name:.*}.ext' + >>> StringTemplate('{:\w+}.ext').template + '{_1:\w+}.ext' + >>> StringTemplate('{name:\w+}.ext').template + '{name:\w+}.ext' + + """ + + def get_format_spec(field_name, current_format_spec): + if current_format_spec == '': + return self.field_patterns.get(field_name, self.dflt_pattern) + return current_format_spec + + def parse_and_transform(): + for index, (literal_text, field_name, format_spec, conversion) in enumerate( + self._formatter.parse(template), 1 + ): + yield ( + literal_text, + f"_{index}" if field_name == '' else field_name, + get_format_spec(field_name, format_spec), + conversion, + ) + + return string_unparse(parse_and_transform()) # TODO: For now we harded coded the "interpret format_specs as field patterns", # but we could make this more general and allow for other types of @@ -1168,24 +1260,44 @@ def __init__( # Note: We removed the cast functions (that existed in legacy StrTupleDict class) # to keep it simple, for now. The idea being that if you want to cast, you can # do it yourself by adding egress/ingress to codecs. - def _construct_regex(self): - formatter = string.Formatter() - pattern = self.template - self.field_names = [] - for literal_text, field_name, format_spec, conversion in formatter.parse( - self.template - ): - # Check if the field_name has either a format_spec (regex) in the template - # or a matching regex in the field_patterns dictionary before adding it - # to the field_names list. - if field_name and (format_spec or field_name in self.field_patterns): - self.field_names.append(field_name) - regex = format_spec or self.field_patterns.get(field_name, '.*?') - to_replace = ( - '{' + field_name + (':' + format_spec if format_spec else '') + '}' + def _compile_regex(self, template): + r"""Parses the template, generating regex for matching the template. 
+ Essentially, it weaves together the literal text parts and the format_specs + parts, transformed into name-capturing regex patterns. + + Note that the literal text parts are regex-escaped so that they are not + interpreted as regex. For example, if the template is "{name}.txt", the + literal text part is replaced with "\\.txt", to avoid that the "." is + interpreted as a regex wildcard. This would otherwise match any character. + Instead, the escaped dot is matched literally. + See https://docs.python.org/3/library/re.html#re.escape for more information. + + >>> StringTemplate('{}.ext').regex.pattern + '(?P<_1>.*)\\.ext' + >>> StringTemplate('{name}.ext').regex.pattern + '(?P<name>.*)\\.ext' + >>> StringTemplate('{:\w+}.ext').regex.pattern + '(?P<_1>\\w+)\\.ext' + >>> StringTemplate('{name:\w+}.ext').regex.pattern + '(?P<name>\\w+)\\.ext' + + """ + + def mk_named_capture_group(field_name, format_spec): + if field_name and format_spec: + return f"(?P<{field_name}>{format_spec})" + else: + return "" + + def generate_pattern_parts(template): + parts = self._formatter.parse(template) + for literal_text, field_name, format_spec, conversion in parts: + yield ( + re.escape(literal_text) + + mk_named_capture_group(field_name, format_spec) ) - pattern = pattern.replace(to_replace, f'(?P<{field_name}>{regex})') - self.regex = re.compile(pattern) + + return re.compile(''.join(generate_pattern_parts(template))) def codec(self, source: FieldTypeNames, target: FieldTypeNames): """Makes a ``(coder, decoder)`` pair for the given source and target types. @@ -1304,7 +1416,8 @@ def tuple_to_str(self, params: tuple) -> str: # @_return_none_if_none_input def dict_to_namedtuple( - self, params: dict, + self, + params: dict, ): """Generates a namedtuple from the dictionary values based on the template.