Skip to content

Commit

Permalink
Merge pull request #522 from Rosalbam1/dev
Browse files Browse the repository at this point in the history
util clean
  • Loading branch information
avcopan authored Jul 8, 2024
2 parents aa4a2e0 + 371af5f commit a06ea7e
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 60 deletions.
5 changes: 2 additions & 3 deletions automol/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
""" common utilities used by automol
"""
"""Common utilities used by automol."""

from automol.util import dict_, heuristic, matrix, ring, tensor, vector, zmat_conv
from automol.util._util import (
breakby,
equivalence_partition,
flatten,
translate,
formula_from_symbols,
is_even_permutation,
is_odd_permutation,
Expand All @@ -18,6 +16,7 @@
scale_iterable,
separate_negatives,
sort_by_list,
translate,
value_similar_to,
)
from automol.util.zmat_conv import ZmatConv
Expand Down
92 changes: 36 additions & 56 deletions automol/util/_util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
""" miscellaneous utilities
"""
"""miscellaneous utilities."""

import itertools
from collections.abc import Collection, Iterable
from numbers import Number
from typing import Any, List
from typing import Any, Callable

from phydat import ptab


def partner(pair: List, item: Any) -> Any:
"""Get the partner of an item in a pair
def partner(pair: Any, item: Any) -> Any:
"""Get the partner of an item in a pair.
The two items must be distinct
Expand All @@ -23,13 +22,13 @@ def partner(pair: List, item: Any) -> Any:
return next(iter(pair - {item}))


def flatten(lst):
"""Flatten an arbitrarily nested list of lists (iterator)
def flatten(lst: list):
"""Flatten an arbitrarily nested list of lists (iterator).
Source: https://stackoverflow.com/a/2158532
"""
for elem in lst:
if isinstance(elem, Iterable) and not isinstance(elem, (str, bytes)):
if isinstance(elem, Iterable) and not isinstance(elem, str | bytes):
yield from flatten(elem)
else:
yield elem
Expand All @@ -38,17 +37,16 @@ def flatten(lst):
def translate(
seq: Collection, trans_dct: dict, drop: bool = False, item_typ: type = Number
) -> Collection:
"""Translate items in a nested sequence or collection with a dictionary
"""Translate items in a nested sequence or collection with a dictionary.
:param seq: An arbitrarily nested sequence or collection
:param trans_dct: A translation dictionary
:param drop: Drop values missing from translation dictionary?, defaults to False
:param item_typ: The type of item to translate, defaults to Number
:return: _description_
:return: Translated version of collection
"""

def transform_(seq_in):
"""Recursively convert a nested list of z-matrix keys to geometry keys"""
def transform_(seq_in: Collection) -> Collection:
"""Recursively convert a nested list of z-matrix keys to geometry keys."""
assert isinstance(seq_in, Collection), f"Cannot process non-sequence {seq_in}"
type_ = type(seq_in)

Expand All @@ -64,24 +62,22 @@ def transform_(seq_in):
return transform_(seq)


def is_odd_permutation(seq1: List, seq2: List):
def is_odd_permutation(seq1: list, seq2: list) -> bool:
"""Determine whether a permutation of a sequence is odd.
:param seq1: the first sequence
:param seq2: the second sequence, which must be a permuation of the first
:param seq1: The first sequence
:param seq2: The second sequence, which must be a permuation of the first
:returns: True if the permutation is even, False if it is odd
:rtype: bool
"""
return not is_even_permutation(seq1, seq2)


def is_even_permutation(seq1: List, seq2: List, check: bool = True):
"""Determine whether a permutation of a sequence is even or odd.
def is_even_permutation(seq1: list, seq2: list, check: bool = True):
"""Determine whether a permutation of a sequence is even.
:param seq1: the first sequence
:param seq2: the second sequence, which must be a permuation of the first
:param seq1: The first sequence
:param seq2: The second sequence, which must be a permuation of the first
:returns: True if the permutation is even, False if it is odd
:rtype: bool
"""
size = len(seq1)
if check:
Expand All @@ -103,16 +99,15 @@ def is_even_permutation(seq1: List, seq2: List, check: bool = True):
return parity


def equivalence_partition(iterable, relation, perfect=False):
"""Partitions a set of objects into equivalence classes
def equivalence_partition(iterable: Collection, relation: Callable [[Any,Any], bool], perfect:bool =False) -> list:
"""Partitions a set of objects into equivalence classes.
canned function taken from https://stackoverflow.com/a/38924631
Args:
iterable: collection of objects to be partitioned
relation: equivalence relation. I.e. relation(o1,o2) evaluates to True
if and only if o1 and o2 are equivalent
perfect: is this a perfect equivalence relation, where a = c and b = c
:param iterable: Collection of objects to be partitioned
:param relation: equivalence relation. I.e. relation(o1,o2) evaluates to True
if and only if o1 and o2 are equivalent
:param perfect: is this a perfect equivalence relation, where a = c and b = c
guarantees a = b? if not, an extra search is performed to make sure
that a, b, and c still end up in the same class
Expand All @@ -132,7 +127,7 @@ def equivalence_partition(iterable, relation, perfect=False):
found = True
break
if not found: # it is in a new class
classes.append(set([obj]))
classes.append({obj})

# 2. Now, account for the possibility of 'imperfect' equivalence relations,
# where the relation gives a = c and b = c, but not a = b, and yet we still
Expand All @@ -156,28 +151,24 @@ def equivalence_partition(iterable, relation, perfect=False):


# Useful functions on Python objects
def move_item_to_front(lst, item):
def move_item_to_front(lst: list | tuple, item) -> tuple:
"""Move an item to the front of a list.
:param lst: the list
:type lst: list or tuple
:param item: the item, which must be in `lst`
:returns: the list, with the item moved to front
:rtype: tuple
"""
lst = list(lst)
lst.insert(0, lst.pop(lst.index(item)))
return tuple(lst)


def move_item_to_end(lst, item):
def move_item_to_end(lst: list | tuple, item) -> tuple:
"""Move an item to the end of a list.
:param lst: the list
:type lst: list or tuple
:param item: the item, which must be in `lst`
:returns: the list, with the item moved to end
:rtype: tuple
"""
lst = list(lst)
lst.append(lst.pop(lst.index(item)))
Expand Down Expand Up @@ -211,24 +202,22 @@ def breakby(lst, elem):


def separate_negatives(lst):
"""Seperate a list of numbers into negative and nonnegative (>= 0)"""

"""Seperate a list of numbers into negative and nonnegative (>= 0)."""
neg_lst = tuple(val for val in lst if val < 0)
pos_lst = tuple(val for val in lst if val >= 0)

return neg_lst, pos_lst


def value_similar_to(val, lst, thresh):
"""Check if a value is close to some lst of values within some threshold"""
"""Check if a value is close to some lst of values within some threshold."""
return any(abs(val - vali) < thresh for vali in lst)


def scale_iterable(iterable, scale_factor):
"""Scale some type of iterable of floats by a scale factor"""

"""Scale some type of iterable of floats by a scale factor."""
if isinstance(iterable, list):
iterable = list(val * scale_factor for val in iterable)
iterable = [val * scale_factor for val in iterable]
elif isinstance(iterable, tuple):
iterable = tuple(val * scale_factor for val in iterable)

Expand All @@ -238,14 +227,14 @@ def scale_iterable(iterable, scale_factor):
def remove_duplicates_with_order(lst):
"""Remove all duplicates of a list while not reordering the list."""
if isinstance(lst, list):
lst = list(n for i, n in enumerate(lst) if n not in lst[:i])
lst = [n for i, n in enumerate(lst) if n not in lst[:i]]
if isinstance(lst, tuple):
lst = tuple(n for i, n in enumerate(lst) if n not in lst[:i])

return lst


def sort_by_list(lst, ref_lst, include_missing=True):
def sort_by_list(lst: tuple, ref_lst: tuple, include_missing=True) -> tuple:
"""Order the elements of the list by using the priorities given
by some reference lst.
Expand All @@ -258,18 +247,14 @@ def sort_by_list(lst, ref_lst, include_missing=True):
dropped if the user specifies not to include it.
:param lst: list to sort
:type lst: tuple
:param ref_lst: list which sets the order of the previous list
:type ref_lst: tuple
:rtype: tuple
"""

# Split input list by elements in and not in reference list
x_in_ref = tuple(x for x in lst if x in ref_lst)
x_missing = tuple(x for x in lst if x not in ref_lst)

# Sorted list of elements in th reference
sort_lst = tuple(sorted(list(x_in_ref), key=lambda x: ref_lst.index(x)))
sort_lst = tuple(sorted(x_in_ref, key=lambda x: ref_lst.index(x)))

# If request append the missing elements
if include_missing:
Expand All @@ -278,29 +263,24 @@ def sort_by_list(lst, ref_lst, include_missing=True):
return sort_lst


def formula_from_symbols(symbs):
def formula_from_symbols(symbs: tuple[str]) -> str:
"""Build a molecular formula from a list of atomic symbols.
(note: dummy atoms will be filtered out and cases will be standardized)
:param symbs: atomic symbols
:type symbs: tuple(str)
:rtype: str
"""

symbs = list(filter(ptab.to_number, map(ptab.to_symbol, symbs)))

return _unique_item_counts(symbs)


def _unique_item_counts(iterable):
def _unique_item_counts(iterable: Iterable) -> dict[object:int]:
"""Build a dictionary giving the count of each unique item in a sequence.
:param iterable: sequence to obtain counts for
:type iterable: iterable object
:rtype: dict[obj: int]
"""

items = tuple(iterable)

return {item: items.count(item) for item in sorted(set(items))}
5 changes: 5 additions & 0 deletions docs/source/automol/util.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# automol.util
some words
```{eval-rst}
.. automodule:: automol.util
```
1 change: 1 addition & 0 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Welcome to my documentation website...
```{toctree}
:hidden:
automol/inchi_key.md
automol/util.md
```
3 changes: 2 additions & 1 deletion lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
FILES=(
"automol/inchi_key.py"
"automol/error.py"
"automol/vmat.py"
"automol/util/_util.py"
"automol/util/__init__.py"
)

(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ extend-ignore = [
"D413", # Missing blank line after last section
"D416", # Section name should end with a colon
"N806", # Variable in function should be lowercase
"C901", # Too complex
]

[tool.mypy]
Expand Down

0 comments on commit a06ea7e

Please sign in to comment.