diff --git a/automol/util/__init__.py b/automol/util/__init__.py index dcae22e3..6518d272 100644 --- a/automol/util/__init__.py +++ b/automol/util/__init__.py @@ -1,12 +1,10 @@ -""" common utilities used by automol -""" +"""Common utilities used by automol.""" from automol.util import dict_, heuristic, matrix, ring, tensor, vector, zmat_conv from automol.util._util import ( breakby, equivalence_partition, flatten, - translate, formula_from_symbols, is_even_permutation, is_odd_permutation, @@ -18,6 +16,7 @@ scale_iterable, separate_negatives, sort_by_list, + translate, value_similar_to, ) from automol.util.zmat_conv import ZmatConv diff --git a/automol/util/_util.py b/automol/util/_util.py index 351cebc9..e7bebe15 100644 --- a/automol/util/_util.py +++ b/automol/util/_util.py @@ -1,16 +1,15 @@ -""" miscellaneous utilities -""" +"""miscellaneous utilities.""" import itertools from collections.abc import Collection, Iterable from numbers import Number -from typing import Any, List +from typing import Any, Callable from phydat import ptab -def partner(pair: List, item: Any) -> Any: - """Get the partner of an item in a pair +def partner(pair: Any, item: Any) -> Any: + """Get the partner of an item in a pair. The two items must be distinct @@ -23,13 +22,13 @@ def partner(pair: List, item: Any) -> Any: return next(iter(pair - {item})) -def flatten(lst): - """Flatten an arbitrarily nested list of lists (iterator) +def flatten(lst: list): + """Flatten an arbitrarily nested list of lists (iterator). Source: https://stackoverflow.com/a/2158532 """ for elem in lst: - if isinstance(elem, Iterable) and not isinstance(elem, (str, bytes)): + if isinstance(elem, Iterable) and not isinstance(elem, str | bytes): yield from flatten(elem) else: yield elem @@ -38,17 +37,16 @@ def flatten(lst): def translate( seq: Collection, trans_dct: dict, drop: bool = False, item_typ: type = Number ) -> Collection: - """Translate items in a nested sequence or collection with a dictionary + """Translate items in a nested sequence or collection with a dictionary. :param seq: An arbitrarily nested sequence or collection :param trans_dct: A translation dictionary :param drop: Drop values missing from translation dictionary?, defaults to False - :param item_typ: The type of item to translate, defaults to Number - :return: _description_ + :return: Translated version of collection """ - def transform_(seq_in): - """Recursively convert a nested list of z-matrix keys to geometry keys""" + def transform_(seq_in: Collection) -> Collection: + """Recursively convert a nested list of z-matrix keys to geometry keys.""" assert isinstance(seq_in, Collection), f"Cannot process non-sequence {seq_in}" type_ = type(seq_in) @@ -64,24 +62,22 @@ def transform_(seq_in): return transform_(seq) -def is_odd_permutation(seq1: List, seq2: List): +def is_odd_permutation(seq1: list, seq2: list) -> bool: """Determine whether a permutation of a sequence is odd. - :param seq1: the first sequence - :param seq2: the second sequence, which must be a permuation of the first + :param seq1: The first sequence + :param seq2: The second sequence, which must be a permuation of the first :returns: True if the permutation is even, False if it is odd - :rtype: bool """ return not is_even_permutation(seq1, seq2) -def is_even_permutation(seq1: List, seq2: List, check: bool = True): - """Determine whether a permutation of a sequence is even or odd. +def is_even_permutation(seq1: list, seq2: list, check: bool = True): + """Determine whether a permutation of a sequence is even. - :param seq1: the first sequence - :param seq2: the second sequence, which must be a permuation of the first + :param seq1: The first sequence + :param seq2: The second sequence, which must be a permuation of the first :returns: True if the permutation is even, False if it is odd - :rtype: bool """ size = len(seq1) if check: @@ -103,16 +99,15 @@ def is_even_permutation(seq1: List, seq2: List, check: bool = True): return parity -def equivalence_partition(iterable, relation, perfect=False): - """Partitions a set of objects into equivalence classes +def equivalence_partition(iterable: Collection, relation: Callable [[Any,Any], bool], perfect:bool =False) -> list: + """Partitions a set of objects into equivalence classes. canned function taken from https://stackoverflow.com/a/38924631 - Args: - iterable: collection of objects to be partitioned - relation: equivalence relation. I.e. relation(o1,o2) evaluates to True - if and only if o1 and o2 are equivalent - perfect: is this a perfect equivalence relation, where a = c and b = c + :param iterable: Collection of objects to be partitioned + :param relation: equivalence relation. I.e. relation(o1,o2) evaluates to True + if and only if o1 and o2 are equivalent + :param perfect: is this a perfect equivalence relation, where a = c and b = c guarantees a = b? if not, an extra search is performed to make sure that a, b, and c still end up in the same class @@ -132,7 +127,7 @@ def equivalence_partition(iterable, relation, perfect=False): found = True break if not found: # it is in a new class - classes.append(set([obj])) + classes.append({obj}) # 2. Now, account for the possibility of 'imperfect' equivalence relations, # where the relation gives a = c and b = c, but not a = b, and yet we still @@ -156,28 +151,24 @@ def equivalence_partition(iterable, relation, perfect=False): # Useful functions on Python objects -def move_item_to_front(lst, item): +def move_item_to_front(lst: list | tuple, item) -> tuple: """Move an item to the front of a list. :param lst: the list - :type lst: list or tuple :param item: the item, which must be in `lst` :returns: the list, with the item moved to front - :rtype: tuple """ lst = list(lst) lst.insert(0, lst.pop(lst.index(item))) return tuple(lst) -def move_item_to_end(lst, item): +def move_item_to_end(lst: list | tuple, item) -> tuple: """Move an item to the end of a list. :param lst: the list - :type lst: list or tuple :param item: the item, which must be in `lst` :returns: the list, with the item moved to end - :rtype: tuple """ lst = list(lst) lst.append(lst.pop(lst.index(item))) @@ -211,8 +202,7 @@ def breakby(lst, elem): def separate_negatives(lst): - """Seperate a list of numbers into negative and nonnegative (>= 0)""" - + """Seperate a list of numbers into negative and nonnegative (>= 0).""" neg_lst = tuple(val for val in lst if val < 0) pos_lst = tuple(val for val in lst if val >= 0) @@ -220,15 +210,14 @@ def separate_negatives(lst): def value_similar_to(val, lst, thresh): - """Check if a value is close to some lst of values within some threshold""" + """Check if a value is close to some lst of values within some threshold.""" return any(abs(val - vali) < thresh for vali in lst) def scale_iterable(iterable, scale_factor): - """Scale some type of iterable of floats by a scale factor""" - + """Scale some type of iterable of floats by a scale factor.""" if isinstance(iterable, list): - iterable = list(val * scale_factor for val in iterable) + iterable = [val * scale_factor for val in iterable] elif isinstance(iterable, tuple): iterable = tuple(val * scale_factor for val in iterable) @@ -238,14 +227,14 @@ def scale_iterable(iterable, scale_factor): def remove_duplicates_with_order(lst): """Remove all duplicates of a list while not reordering the list.""" if isinstance(lst, list): - lst = list(n for i, n in enumerate(lst) if n not in lst[:i]) + lst = [n for i, n in enumerate(lst) if n not in lst[:i]] if isinstance(lst, tuple): lst = tuple(n for i, n in enumerate(lst) if n not in lst[:i]) return lst -def sort_by_list(lst, ref_lst, include_missing=True): +def sort_by_list(lst: tuple, ref_lst: tuple, include_missing=True) -> tuple: """Order the elements of the list by using the priorities given by some reference lst. @@ -258,18 +247,14 @@ def sort_by_list(lst, ref_lst, include_missing=True): dropped if the user specifies not to include it. :param lst: list to sort - :type lst: tuple :param ref_lst: list which sets the order of the previous list - :type ref_lst: tuple - :rtype: tuple """ - # Split input list by elements in and not in reference list x_in_ref = tuple(x for x in lst if x in ref_lst) x_missing = tuple(x for x in lst if x not in ref_lst) # Sorted list of elements in th reference - sort_lst = tuple(sorted(list(x_in_ref), key=lambda x: ref_lst.index(x))) + sort_lst = tuple(sorted(x_in_ref, key=lambda x: ref_lst.index(x))) # If request append the missing elements if include_missing: @@ -278,29 +263,24 @@ def sort_by_list(lst, ref_lst, include_missing=True): return sort_lst -def formula_from_symbols(symbs): +def formula_from_symbols(symbs: tuple[str]) -> str: """Build a molecular formula from a list of atomic symbols. (note: dummy atoms will be filtered out and cases will be standardized) :param symbs: atomic symbols - :type symbs: tuple(str) - :rtype: str """ - symbs = list(filter(ptab.to_number, map(ptab.to_symbol, symbs))) return _unique_item_counts(symbs) -def _unique_item_counts(iterable): +def _unique_item_counts(iterable: Iterable) -> dict[object:int]: """Build a dictionary giving the count of each unique item in a sequence. :param iterable: sequence to obtain counts for :type iterable: iterable object - :rtype: dict[obj: int] """ - items = tuple(iterable) return {item: items.count(item) for item in sorted(set(items))} diff --git a/docs/source/automol/util.md b/docs/source/automol/util.md new file mode 100644 index 00000000..56ae457c --- /dev/null +++ b/docs/source/automol/util.md @@ -0,0 +1,5 @@ +# automol.util +some words +```{eval-rst} +.. automodule:: automol.util +``` \ No newline at end of file diff --git a/docs/source/index.md b/docs/source/index.md index cbc8c21b..c3c411f9 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -5,4 +5,5 @@ Welcome to my documentation website... ```{toctree} :hidden: automol/inchi_key.md +automol/util.md ``` diff --git a/lint.sh b/lint.sh index 76d031bf..8e6f343d 100755 --- a/lint.sh +++ b/lint.sh @@ -3,7 +3,8 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) FILES=( "automol/inchi_key.py" "automol/error.py" - "automol/vmat.py" + "automol/util/_util.py" + "automol/util/__init__.py" ) ( diff --git a/pyproject.toml b/pyproject.toml index ac6e373e..63bbe338 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ extend-ignore = [ "D413", # Missing blank line after last section "D416", # Section name should end with a colon "N806", # Variable in function should be lowercase + "C901", # Too complex ] [tool.mypy]