Skip to content

Commit 7162271

Browse files
committed
refact:modules&bandit checks
1 parent 4a85621 commit 7162271

File tree

13 files changed

+250
-224
lines changed

13 files changed

+250
-224
lines changed

.bandit

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[bandit]
2+
exclude = tests
3+
skips = B103,B607,B603,B101,B404,B311

.devcontainer/Dockerfile

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ ENV https_proxy=http://a100-internal.yaoyy.moe:10089
99
ENV all_proxy=http://a100-internal.yaoyy.moe:10089
1010
ENV GITHUB_ROSETTA_TEST=YES
1111

12-
RUN apt update -y && apt install git curl wget -y
12+
RUN apt-get update -y && apt-get install git curl wget -y
1313

1414
RUN python -m pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple \
15-
&& python -m pip install --upgrade pip \
16-
&& python -m pip install 'flit>=3.8.0'
15+
&& python -m pip install --no-cache-dir --upgrade pip \
16+
&& python -m pip install --no-cache-dir 'flit>=3.8.0'
1717

1818
ENV FLIT_ROOT_INSTALL=1
1919

2020
COPY pyproject.toml .
2121
RUN touch README.md \
2222
&& mkdir -p src/RosettaPy \
23-
&& python -m flit install --only-deps --deps develop \
23+
&& python -m flit install --no-cache-dir --only-deps --deps develop \
2424
&& rm -r pyproject.toml README.md src

.github/workflows/CI.yml

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: Python CI
22
on:
33
push:
4-
branches: [ main ]
4+
branches: [main]
55
pull_request:
6-
branches: [ main ]
6+
branches: [main]
77
release:
88
types: [created]
99
workflow_dispatch:
@@ -19,16 +19,15 @@ jobs:
1919
- "3.10"
2020
- "3.11"
2121
- "3.12"
22+
- "3.13"
2223

2324
uses: YaoYinYing/action-python/.github/workflows/[email protected]
2425
with:
25-
workdir: '.'
26+
workdir: "."
2627
python-version: ${{ matrix.python }}
2728
secrets:
2829
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
2930

30-
31-
3231
publish:
3332
strategy:
3433
fail-fast: false

.github/workflows/RosettaCI.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
run: |
4545
apt update -y
4646
apt install gnupg2 git -y
47-
pip install '.[test,wrapper]' -U
47+
pip install '.[test]' -U
4848
4949
- name: Run test cases
5050
run: |

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Examples of valid binary filenames:
8383

8484
## Installation
8585

86-
Ensure you have Python 3.6 or higher installed.
86+
Ensure you have Python 3.8 or higher installed.
8787

8888
### Install via PyPI
8989

pyproject.toml

+10-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "RosettaPy"
77
authors = [
88
{name = "Yinying Yao", email = "[email protected]"},
99
]
10-
description = "Searching for Rosetta Binaries."
10+
description = "A Python utility for wrapping Rosetta command line tools."
1111
readme = "README.md"
1212
classifiers = [
1313
"Development Status :: 6 - Mature",
@@ -20,23 +20,26 @@ classifiers = [
2020
"Programming Language :: Python :: 3.9",
2121
"Programming Language :: Python :: 3.10",
2222
"Programming Language :: Python :: 3.11",
23-
"Programming Language :: Python :: 3.12"
23+
"Programming Language :: Python :: 3.12",
24+
"Programming Language :: Python :: 3.13"
2425
]
2526
requires-python = ">=3.8"
2627
dynamic = ["version"]
2728

28-
[project.optional-dependencies]
29-
spark = [
30-
"pyspark>=3.0.0"
31-
]
32-
wrapper = [
29+
dependencies = [
3330
"joblib",
3431
"absl-py",
3532
"pandas",
3633
"biopython",
3734
"rdkit",
3835
"numpy>=1.20.3,<3"
3936
]
37+
38+
[project.optional-dependencies]
39+
spark = [
40+
"pyspark>=3.0.0"
41+
]
42+
4043
test = [
4144
"bandit[toml]==1.7.10",
4245
"black==24.8.0",

src/RosettaPy/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
from .rosetta_finder import RosettaBinary, RosettaFinder, main
3-
from .rosetta import Rosetta, RosettaScriptsVariableGroup, MPI_node, RosettaEnergyUnitAnalyser
3+
from .rosetta import Rosetta, RosettaScriptsVariableGroup, MPI_node
4+
from .analyser import RosettaEnergyUnitAnalyser
45
from .utils import timing, isolate
56

67
__all__ = [
@@ -15,4 +16,4 @@
1516
"RosettaEnergyUnitAnalyser",
1617
]
1718

18-
__version__ = "0.1.1"
19+
__version__ = "0.1.2"

src/RosettaPy/analyser/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .reu import RosettaEnergyUnitAnalyser
2+
3+
# Names exported by `from RosettaPy.analyser import *`.
__all__ = ["RosettaEnergyUnitAnalyser"]

src/RosettaPy/analyser/reu.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from dataclasses import dataclass
2+
import os
3+
from typing import Dict, Literal, Optional, Tuple, Union
4+
import warnings
5+
6+
import pandas as pd
7+
8+
9+
@dataclass
class RosettaEnergyUnitAnalyser:
    """
    A tool class for analyzing Rosetta energy calculation results.

    Parameters:
    - score_file (str): The path to the score file or directory containing score files.
    - score_term (str, optional): The column name in the score file to use as the score. Defaults to "total_score".
    - job_id (Optional[str], optional): An identifier for the job. Defaults to None.

    Attributes set on construction:
    - df (pd.DataFrame): The parsed score table (all `*.sc` files concatenated when
      `score_file` is a directory).

    Raises (on construction):
    - FileNotFoundError: `score_file` is neither an existing file nor a directory.
    - ValueError: `score_file` is a directory without any `*.sc` file, or
      `score_term` is not a column of the parsed table.
    """

    score_file: str
    score_term: str = "total_score"

    job_id: Optional[str] = None

    @staticmethod
    def scorefile2df(score_file: str) -> pd.DataFrame:
        """
        Converts a score file into a pandas DataFrame.

        Parameters:
        - score_file (str): Path to the score file.

        Returns:
        - pd.DataFrame: DataFrame containing the data from the score file.
        """
        # The first line of the file is skipped and the next one becomes the header
        # (presumably the first line is Rosetta's "SEQUENCE:" line -- confirm).
        df = pd.read_fwf(score_file, skiprows=1)

        # The leading "SCORE:" tag of every row is parsed as a column of its own; drop it.
        if "SCORE:" in df.columns:
            df = df.drop(columns="SCORE:")

        return df

    def __post_init__(self):
        """
        Initializes the DataFrame based on the provided score file or directory.
        """
        if os.path.isfile(self.score_file):
            self.df = self.scorefile2df(self.score_file)
        elif os.path.isdir(self.score_file):
            # os.listdir() order is unspecified; sort so the row order is reproducible.
            score_files = sorted(f for f in os.listdir(self.score_file) if f.endswith(".sc"))
            if not score_files:
                # pd.concat([]) would raise an opaque "No objects to concatenate" ValueError.
                raise ValueError(f'No score files (*.sc) found in directory "{self.score_file}".')

            dfs = [self.scorefile2df(os.path.join(self.score_file, f)) for f in score_files]
            warnings.warn(UserWarning(f"Concatenate {len(dfs)} score files"))
            self.df = pd.concat(dfs, axis=0, ignore_index=True)
        else:
            raise FileNotFoundError(f"Score file {self.score_file} not found.")

        if self.score_term not in self.df.columns:
            raise ValueError(f'Score term "{self.score_term}" not found in score file.')

    @staticmethod
    def df2dict(
        dfs: pd.DataFrame, k: str = "total_score"
    ) -> Tuple[Dict[Literal["score", "decoy"], Union[str, float]], ...]:
        """
        Converts a DataFrame into a tuple of dictionaries with scores and decoys.

        Parameters:
        - dfs (pd.DataFrame): DataFrame containing the scores. Must have a "description" column.
        - k (str, optional): Column name to use as the score. Defaults to "total_score".

        Returns:
        - Tuple[Dict[Literal["score", "decoy"], Union[str, float]], ...]: One dictionary per row,
          in row order, containing the score and the decoy name.
        """
        # Walk the two columns in lockstep instead of filtering the frame by index label
        # for every row: linear time, and still correct when index labels are duplicated.
        return tuple(
            {"score": float(score), "decoy": str(decoy)}
            for score, decoy in zip(dfs[k], dfs["description"])
        )

    @property
    def best_decoy(self) -> Dict[Literal["score", "decoy"], Union[str, float]]:
        """
        Returns the best decoy based on the score term.

        Returns:
        - Dict[Literal["score", "decoy"], Union[str, float]]: Dictionary containing the score and decoy
          of the lowest-scoring entry, or an empty dictionary when the table has no rows.
        """
        if self.df.empty:
            return {}
        return self.top(1)[0]

    def top(
        self, rank: int = 1, score_term: Optional[str] = None
    ) -> Tuple[Dict[Literal["score", "decoy"], Union[str, float]], ...]:
        """
        Returns the top `rank` decoys based on the specified score term.

        Rows are sorted by the score term in ascending order, so "top" means lowest score.

        Parameters:
        - rank (int, optional): The number of top entries to return. Defaults to 1.
        - score_term (Optional[str], optional): The column name to use as the score. Falls back to the
          instance attribute `score_term` when omitted or when it is not a column of the table.

        Returns:
        - Tuple[Dict[Literal["score", "decoy"], Union[str, float]], ...]: Tuple of dictionaries containing
          scores and decoys of the top entries.

        Raises:
        - ValueError: `rank` is not greater than 0.
        """
        if rank <= 0:
            raise ValueError("Rank must be greater than 0")

        # Resolve the effective score column once; unknown names silently fall back.
        if score_term is None or score_term not in self.df.columns:
            score_term = self.score_term

        df = self.df.sort_values(by=score_term).head(rank)

        return self.df2dict(dfs=df, k=score_term)

src/RosettaPy/app/utils/smiles2param.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def convert(self, ligands: Dict[str, str]):
160160
try:
161161
cs = Chem.CanonSmiles(ds)
162162
c_smiles.append(cs)
163-
except:
163+
except Exception:
164164
print('Invalid SMILES: %s\n%s' % (i, ds))
165165
print(c_smiles)
166166

src/RosettaPy/node/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .mpi import MPI_node
2+
3+
# Names exported by `from RosettaPy.node import *`.
__all__ = ['MPI_node']

src/RosettaPy/node/mpi.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
2+
import contextlib
3+
import copy
4+
from dataclasses import dataclass
5+
import os
6+
import random
7+
import shutil
8+
import subprocess
9+
from typing import Dict, List, Optional
10+
import warnings
11+
12+
13+
class MPI_IncompatibleInputWarning(RuntimeWarning):
    """Warning category (a `RuntimeWarning` subclass) for MPI-incompatible input.

    Adds no behaviour of its own; it only provides a distinct category that
    can be filtered with the `warnings` machinery.
    """
14+
15+
16+
@dataclass
class MPI_node:
    """
    Builds the MPI launcher prefix for a command line.

    Parameters:
    - nproc (int, optional): Number of processes for a local run. Overwritten with the
      sum of `node_matrix` values when a node matrix is given. Defaults to 0.
    - node_matrix (Optional[Dict[str, int]], optional): Mapping of node ID to the number
      of slots on that node. When given, a hostfile is written to the current working
      directory and the launcher is pointed at it. Defaults to None.

    Attributes set on construction:
    - mpi_excutable (Optional[str]): Full path of the first launcher found on PATH, or
      None when none is installed.
    - node_file (str): Per-instance hostfile name.
    """

    nproc: int = 0
    node_matrix: Optional[Dict[str, int]] = None  # Node ID: nproc

    # The two names below are deliberately left un-annotated so that they stay plain
    # class attributes and do not become dataclass fields / __init__ parameters.

    # Class-level default kept for backward compatibility; every instance gets its own
    # name in __post_init__ so that instances never share (and delete) one hostfile.
    node_file = f"nodefile_{random.randint(1,9_999_999_999)}.txt"

    # os.getuid() does not exist on every platform; -1 never equals the root uid 0.
    user = os.getuid() if hasattr(os, "getuid") else -1

    def __post_init__(self):
        """Locate the MPI launcher and, if a node matrix is given, write the hostfile."""
        self.node_file = f"nodefile_{random.randint(1, 9_999_999_999)}.txt"

        # Only real launchers are candidates: `mpicc` is the compiler wrapper and cannot
        # start a job, and a literal `...` in the candidate list makes shutil.which()
        # raise TypeError whenever nothing is found on PATH.
        self.mpi_excutable = None
        for mpi_exec in ("mpirun", "mpiexec"):
            self.mpi_excutable = shutil.which(mpi_exec)
            if self.mpi_excutable is not None:
                break

        if not isinstance(self.node_matrix, dict):
            return

        self._write_node_file()
        self.nproc = sum(self.node_matrix.values())  # fix nproc to real node matrix

    def _write_node_file(self) -> None:
        """Write one `<node> slots=<nproc>` line per node_matrix entry to `node_file`."""
        with open(self.node_file, "w") as f:
            for node, nproc in self.node_matrix.items():
                f.write(f"{node} slots={nproc}\n")

    @property
    def local(self) -> List[str]:
        """Launcher prefix for a run with `nproc` processes and no hostfile (fresh list each access)."""
        return [self.mpi_excutable, "--use-hwthread-cpus", "-np", str(self.nproc)]

    @property
    def host_file(self) -> List[str]:
        """Launcher prefix for a run driven by the hostfile (fresh list each access)."""
        return [self.mpi_excutable, "--hostfile", self.node_file]

    @contextlib.contextmanager
    def apply(self, cmd: List[str]):
        """
        Context manager yielding `cmd` prefixed with the MPI launcher arguments.

        The hostfile prefix is used when `node_matrix` is non-empty, the local prefix
        otherwise. `--allow-run-as-root` is appended to the prefix (with a UserWarning)
        when the current uid is 0. The hostfile is removed on exit, including when the
        with-body raises. `cmd` itself is not modified.

        Parameters:
        - cmd (List[str]): The command line to run under MPI.

        Yields:
        - List[str]: Launcher prefix followed by a copy of `cmd`.
        """
        cmd_copy = copy.copy(cmd)

        if self.node_matrix:
            # A previous apply() removed the hostfile on exit; recreate it so that the
            # same instance can be applied more than once.
            if not os.path.exists(self.node_file):
                self._write_node_file()
            m = self.host_file
        else:
            m = self.local

        if self.user == 0:
            m.append("--allow-run-as-root")
            warnings.warn(UserWarning("Running Rosetta with MPI as Root User"))

        try:
            yield m + cmd_copy
        finally:
            # Clean up even when the caller's block raised.
            with contextlib.suppress(FileNotFoundError):
                os.remove(self.node_file)

    @classmethod
    def from_slurm(cls) -> "MPI_node":
        """
        Build an instance from the environment of a SLURM job.

        Node names come from `scontrol show hostnames $SLURM_JOB_NODELIST`; every node
        gets `SLURM_NTASKS_PER_NODE * SLURM_CPUS_PER_TASK` slots. Both variables default
        to 1 when unset and are raised to 1 when smaller.

        Returns:
        - MPI_node: Instance whose node matrix covers all nodes of the job.

        Raises:
        - RuntimeError: `SLURM_JOB_NODELIST` is not set, `scontrol` exits with an error,
          or `scontrol` reports no node at all.
        """
        try:
            raw = subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
        except KeyError as e:
            raise RuntimeError(f"Environment variable {e} not set") from None
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to get node list: {e.output}") from None

        # split() without arguments drops blank lines, so empty output yields no nodes
        # instead of a single node named ''.
        nodes = raw.decode().split()
        if not nodes:
            raise RuntimeError("Failed to get node list: scontrol returned no hostnames")

        slurm_cpus_per_task = int(os.environ.get("SLURM_CPUS_PER_TASK", "1"))
        slurm_ntasks_per_node = int(os.environ.get("SLURM_NTASKS_PER_NODE", "1"))

        if slurm_cpus_per_task < 1:
            print(f"Fixing $SLURM_CPUS_PER_TASK from {slurm_cpus_per_task} to 1.")
            slurm_cpus_per_task = 1

        if slurm_ntasks_per_node < 1:
            print(f"Fixing $SLURM_NTASKS_PER_NODE from {slurm_ntasks_per_node} to 1.")
            slurm_ntasks_per_node = 1

        slots = slurm_ntasks_per_node * slurm_cpus_per_task
        node_dict = {node: slots for node in nodes}

        total_nproc = sum(node_dict.values())
        return cls(total_nproc, node_dict)

0 commit comments

Comments
 (0)