Skip to content

Commit

Permalink
add ruff changes
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjonesBSU committed May 8, 2024
1 parent 0245f75 commit 6ca7d50
Show file tree
Hide file tree
Showing 30 changed files with 193 additions and 540 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name: pytest

on:
push:
# Action will run when any changes to these paths are pushed or pr'ed to master
branches: [ main ]
paths:
- flowermd/**
Expand All @@ -26,6 +25,7 @@ on:

jobs:
pytest:
if: github.event.pull_request.draft == false
strategy:
fail-fast: false
matrix:
Expand Down
29 changes: 6 additions & 23 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
ci:
autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
skip: [ ]
submodules: false

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2 # Ruff version
hooks:
- id: ruff
args: [--line-length=80, --fix, --extend-ignore=E203]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: 'flowermd/tests/assets/.* | flowermd/assets/.*'
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
args: [ --line-length=80 ]
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
Expand All @@ -30,19 +29,3 @@ repos:
args:
[ --profile=black, --line-length=80 ]
exclude: 'flowermd/tests/assets/.* '

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
args:
- --max-line-length=80
- --extend-ignore=E203
exclude: '__init__.py'

- repo: https://github.com/pycqa/pydocstyle
rev: '6.3.0'
hooks:
- id: pydocstyle
exclude: ^(flowermd/tests/|flowermd/internal/|flowermd/utils|setup.py|flowermd/__version__.py|docs/)
args: [ --convention=numpy ]
4 changes: 1 addition & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import sys

project = "flowerMD"
copyright = (
"2023, Chris Jones, Marjan Albooyeh, Rainier Barrett, Eric Jankowski"
)
copyright = "2023, Chris Jones, Marjan Albooyeh, Rainier Barrett, Eric Jankowski"
author = "Chris Jones, Marjan Albooyeh, Rainier Barrett, Eric Jankowski"

sys.path.insert(0, os.path.abspath("../.."))
Expand Down
8 changes: 2 additions & 6 deletions flowermd/base/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ def __init__(self, forcefield_files=None, name=None):
super(BaseXMLForcefield, self).__init__(
forcefield_files=forcefield_files, name=name
)
self.gmso_ff = (
ffutils.FoyerFFs().load(forcefield_files or name).to_gmso_ff()
)
self.gmso_ff = ffutils.FoyerFFs().load(forcefield_files or name).to_gmso_ff()


class BaseHOOMDForcefield:
Expand All @@ -22,8 +20,6 @@ class BaseHOOMDForcefield:
def __init__(self, hoomd_forces):
self.hoomd_forces = hoomd_forces
if hoomd_forces is None:
raise NotImplementedError(
"`hoomd_forces` must be defined in the subclass."
)
raise NotImplementedError("`hoomd_forces` must be defined in the subclass.")
if not isinstance(hoomd_forces, list):
raise TypeError("`hoomd_forces` must be a list.")
29 changes: 9 additions & 20 deletions flowermd/base/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ def _load(self):
return mb.load(self.smiles, smiles=True)
else:
raise MoleculeLoadError(
msg=f"Unable to load the molecule from smiles "
f"{self.smiles}."
msg=f"Unable to load the molecule from smiles " f"{self.smiles}."
)

def _align_backbones_z_axis(self, heavy_atoms_only=False):
Expand All @@ -214,11 +213,7 @@ def _align_backbones_z_axis(self, heavy_atoms_only=False):
if heavy_atoms_only:
try:
positions = np.array(
[
p.xyz[0]
for p in mol.particles()
if p.element.symbol != "H"
]
[p.xyz[0] for p in mol.particles() if p.element.symbol != "H"]
)
except AttributeError:
positions = mol.xyz
Expand Down Expand Up @@ -277,9 +272,7 @@ def _identify_particle_information(self, gmso_molecule):
):
self.hydrogen_types.append(p_name)
self.particle_typeid.append(self.particle_types.index(p_name))
self.particle_charge.append(
site.charge.to_value() if site.charge else 0
)
self.particle_charge.append(site.charge.to_value() if site.charge else 0)

def _identify_pairs(self, particle_types):
"""Identify all unique particle pairs from the particle types.
Expand All @@ -290,9 +283,7 @@ def _identify_pairs(self, particle_types):
List of all particle types.
"""
self.pairs = set(
itertools.combinations_with_replacement(particle_types, 2)
)
self.pairs = set(itertools.combinations_with_replacement(particle_types, 2))

def _identify_bond_types(self, gmso_molecule):
"""Identify all unique bond types from the GMSO topology.
Expand All @@ -314,7 +305,7 @@ def _identify_bond_types(self, gmso_molecule):
or bond.connection_members[1].name
)
bond_connections = [p1_name, p2_name]
if not tuple(bond_connections[::-1]) in self.bond_types:
if tuple(bond_connections[::-1]) not in self.bond_types:
self.bond_types.add(tuple(bond_connections))

def _identify_angle_types(self, gmso_molecule):
Expand All @@ -341,7 +332,7 @@ def _identify_angle_types(self, gmso_molecule):
or angle.connection_members[2].name
)
angle_connections = [p1_name, p2_name, p3_name]
if not tuple(angle_connections[::-1]) in self.angle_types:
if tuple(angle_connections[::-1]) not in self.angle_types:
self.angle_types.add(tuple(angle_connections))

def _identify_dihedral_types(self, gmso_molecule):
Expand Down Expand Up @@ -372,7 +363,7 @@ def _identify_dihedral_types(self, gmso_molecule):
or dihedral.connection_members[3].name
)
dihedral_connections = [p1_name, p2_name, p3_name, p4_name]
if not tuple(dihedral_connections[::-1]) in self.dihedral_types:
if tuple(dihedral_connections[::-1]) not in self.dihedral_types:
self.dihedral_types.add(tuple(dihedral_connections))

def _identify_improper_types(self, gmso_molecule):
Expand Down Expand Up @@ -403,7 +394,7 @@ def _identify_improper_types(self, gmso_molecule):
or improper.connection_members[3].name
)
improper_connections = [p1_name, p2_name, p3_name, p4_name]
if not tuple(improper_connections[::-1]) in self.improper_types:
if tuple(improper_connections[::-1]) not in self.improper_types:
self.improper_types.add(tuple(improper_connections))

def _identify_topology_information(self, gmso_molecule):
Expand Down Expand Up @@ -435,9 +426,7 @@ def _validate_force_field(self):
# Update topology information from typed gmso after applying ff.
self._identify_topology_information(self.gmso_molecule)
elif isinstance(self.force_field, BaseHOOMDForcefield):
_validate_hoomd_ff(
self.force_field.hoomd_forces, self.topology_information
)
_validate_hoomd_ff(self.force_field.hoomd_forces, self.topology_information)
elif isinstance(self.force_field, List):
_validate_hoomd_ff(self.force_field, self.topology_information)
else:
Expand Down
47 changes: 12 additions & 35 deletions flowermd/base/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def __init__(
):
if not isinstance(forcefield, Iterable) or isinstance(forcefield, str):
raise ValueError(
"forcefield must be a sequence of "
"hoomd.md.force.Force objects."
"forcefield must be a sequence of " "hoomd.md.force.Force objects."
)
else:
for obj in forcefield:
Expand All @@ -87,9 +86,7 @@ def __init__(
self.gsd_write_freq = int(gsd_write_freq)
self.maximum_write_buffer_size = gsd_max_buffer_size
self.log_write_freq = int(log_write_freq)
self._std_out_freq = int(
(self.gsd_write_freq + self.log_write_freq) / 2
)
self._std_out_freq = int((self.gsd_write_freq + self.log_write_freq) / 2)
self.gsd_file_name = gsd_file_name
self.log_file_name = log_file_name
self.log_quantities = [
Expand Down Expand Up @@ -436,9 +433,7 @@ def thermostat(self, thermostat):
thermostat : flowermd.utils.HOOMDThermostats, required
The type of thermostat to use.
"""
if not issubclass(
self._thermostat, hoomd.md.methods.thermostats.Thermostat
):
if not issubclass(self._thermostat, hoomd.md.methods.thermostats.Thermostat):
raise ValueError(
f"Invalid thermostat. Please choose from: {HOOMDThermostats}"
)
Expand Down Expand Up @@ -527,9 +522,7 @@ def _initialize_thermostat(self, thermostat_kwargs):
required_thermostat_kwargs = {}
for k in inspect.signature(self.thermostat).parameters:
if k not in thermostat_kwargs.keys():
raise ValueError(
f"Missing required parameter {k} for thermostat."
)
raise ValueError(f"Missing required parameter {k} for thermostat.")
required_thermostat_kwargs[k] = thermostat_kwargs[k]
return self.thermostat(**required_thermostat_kwargs)

Expand All @@ -551,9 +544,7 @@ def set_integrator_method(self, integrator_method, method_kwargs):
if not self.integrator: # Integrator and method not yet created
self.integrator = hoomd.md.Integrator(
dt=self.dt,
integrate_rotational_dof=(
True if self._rigid_constraint else False
),
integrate_rotational_dof=(True if self._rigid_constraint else False),
)
if self._rigid_constraint:
self.integrator.rigid = self._rigid_constraint
Expand Down Expand Up @@ -715,9 +706,7 @@ def run_update_volume(
self.set_integrator_method(
integrator_method=hoomd.md.methods.ConstantVolume,
method_kwargs={
"thermostat": self._initialize_thermostat(
{"kT": kT, "tau": tau_kt}
),
"thermostat": self._initialize_thermostat({"kT": kT, "tau": tau_kt}),
"filter": self.integrate_group,
},
)
Expand Down Expand Up @@ -849,9 +838,7 @@ def run_NPT(
"rescale_all": rescale_all,
"gamma": gamma,
"filter": self.integrate_group,
"thermostat": self._initialize_thermostat(
{"kT": kT, "tau": tau_kt}
),
"thermostat": self._initialize_thermostat({"kT": kT, "tau": tau_kt}),
},
)
if thermalize_particles:
Expand Down Expand Up @@ -894,9 +881,7 @@ def run_NVT(
self.set_integrator_method(
integrator_method=hoomd.md.methods.ConstantVolume,
method_kwargs={
"thermostat": self._initialize_thermostat(
{"kT": kT, "tau": tau_kt}
),
"thermostat": self._initialize_thermostat({"kT": kT, "tau": tau_kt}),
"filter": self.integrate_group,
},
)
Expand Down Expand Up @@ -1105,17 +1090,13 @@ def _thermalize_system(self, kT):
filter=self.integrate_group, kT=kT.range[0]
)
else:
self.state.thermalize_particle_momenta(
filter=self.integrate_group, kT=kT
)
self.state.thermalize_particle_momenta(filter=self.integrate_group, kT=kT)

def _lj_force(self):
"""Return the Lennard-Jones pair force."""
if not self.integrator:
lj_force = [
f
for f in self._forcefield
if isinstance(f, hoomd.md.pair.pair.LJ)
f for f in self._forcefield if isinstance(f, hoomd.md.pair.pair.LJ)
][0]
else:
lj_force = [
Expand All @@ -1141,19 +1122,15 @@ def _create_state(self, initial_state):
print("Initializing simulation state from a GSD file.")
self.create_state_from_gsd(initial_state)
elif isinstance(initial_state, hoomd.snapshot.Snapshot):
print(
"Initializing simulation state from a hoomd.snapshot.Snapshot"
)
print("Initializing simulation state from a hoomd.snapshot.Snapshot")
self.create_state_from_snapshot(initial_state)
elif isinstance(initial_state, gsd.hoomd.Frame):
print("Initializing simulation state from a gsd.hoomd.Frame.")
self.create_state_from_snapshot(initial_state)

def _add_hoomd_writers(self):
"""Create gsd and log writers."""
gsd_logger = hoomd.logging.Logger(
categories=["scalar", "string", "sequence"]
)
gsd_logger = hoomd.logging.Logger(categories=["scalar", "string", "sequence"])
logger = hoomd.logging.Logger(categories=["scalar", "string"])
gsd_logger.add(self, quantities=["timestep", "tps"])
logger.add(self, quantities=["timestep", "tps"])
Expand Down
Loading

0 comments on commit 6ca7d50

Please sign in to comment.