Skip to content

Commit a4af998

Browse files
committed
Add PyTorch model to fit/predict
1 parent fa73080 commit a4af998

File tree

5 files changed

+84
-8
lines changed

5 files changed

+84
-8
lines changed

.github/workflows/test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ jobs:
8080
export PATH=bin:$PATH
8181
metafx fit -f wd_unique_pca/feature_table.tsv -i wd_unique_pca/samples_categories.tsv -w wd_fit_rf
8282
metafx fit -f wd_unique_pca/feature_table.tsv -i wd_unique_pca/samples_categories.tsv -w wd_fit_xgb -e XGB
83+
metafx fit -f wd_unique_pca/feature_table.tsv -i wd_unique_pca/samples_categories.tsv -w wd_fit_torch -e Torch
8384
- name: metafx cv
8485
run: |
8586
export PATH=bin:$PATH

bin/metafx-scripts/fit.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sklearn.metrics import classification_report
99
from xgboost import XGBClassifier
1010
from sklearn import preprocessing
11+
from metafx_torch import TorchLinearModel
12+
import torch
1113

1214
if __name__ == "__main__":
1315
features = pd.read_csv(sys.argv[1], header=0, index_col=0, sep="\t")
@@ -23,19 +25,35 @@
2325
M = features.shape[0] # features count
2426
N = features.shape[1] # samples count
2527

26-
model = RandomForestClassifier(n_estimators=100) if sys.argv[4] == "RF" else XGBClassifier(n_estimators=100)
2728
X = features.T
2829
y = np.array([metadata.loc[i, 1] for i in X.index])
2930

31+
model = None
32+
if sys.argv[4] == "RF":
33+
model = RandomForestClassifier(n_estimators=100)
34+
elif sys.argv[4] == "XGB":
35+
model = XGBClassifier(n_estimators=100)
36+
else:
37+
model = TorchLinearModel(n_features=M, n_classes=len(set(y)))
38+
3039
if sys.argv[4] == "XGB":
3140
le = preprocessing.LabelEncoder()
3241
le.fit(y)
3342
y = le.transform(y)
43+
elif sys.argv[4] == "Torch":
44+
le = preprocessing.LabelEncoder()
45+
le.fit(y)
46+
y = le.transform(y)
3447

3548
model.fit(X, y)
36-
dump(model, outName + ".joblib")
3749

38-
if sys.argv[4] == "XGB":
50+
if sys.argv[4] == "RF":
51+
dump(model, outName + ".joblib")
52+
elif sys.argv[4] == "XGB":
53+
dump(model, outName + ".joblib")
54+
dump(le, outName + "_le.joblib")
55+
elif sys.argv[4] == "Torch":
56+
torch.save(model, outName + ".joblib")
3957
dump(le, outName + "_le.joblib")
4058

4159
print("Model accuracy after training:")

bin/metafx-scripts/metafx_torch.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python
2+
# PyTorch Linear Classification Model
3+
import torch
4+
from torch import nn, optim
5+
import numpy as np
6+
7+
8+
class TorchLinearModel():
    """Small feed-forward PyTorch classifier with a scikit-learn-like API.

    The network is ``Linear(n_features, 32) -> Sigmoid -> Linear(32, n_classes)``
    and is trained full-batch with SGD (lr=0.001, momentum=0.9) on
    ``nn.CrossEntropyLoss``.

    The output layer emits raw logits: ``nn.CrossEntropyLoss`` applies
    log-softmax internally, so no activation is placed after the last
    ``Linear`` layer.

    Args:
        n_features: Number of input features (width of each sample).
        n_classes: Number of target classes.
        n_epochs: Number of full-batch optimisation steps run by ``fit``.
    """

    def __init__(self, n_features, n_classes, n_epochs=1000):
        self.n_features = n_features
        self.n_classes = n_classes
        self.n_epochs = n_epochs
        self.model = nn.Sequential(
            nn.Linear(self.n_features, 32),
            nn.Sigmoid(),
            # No trailing activation: CrossEntropyLoss expects unnormalised logits.
            nn.Linear(32, self.n_classes),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    @staticmethod
    def _to_float_tensor(X):
        """Convert a DataFrame or array-like into a float32 tensor."""
        # np.array copies, so the tensor never aliases a read-only pandas buffer.
        return torch.from_numpy(np.array(X, dtype=np.float32))

    def fit(self, X, y):
        """Train the network on samples ``X`` with integer labels ``y``.

        Args:
            X: pandas DataFrame or array-like of shape (n_samples, n_features).
            y: Sequence of integer class indices in ``[0, n_classes)``,
                one per row of ``X``.

        Returns:
            The fitted model (``self``).

        Raises:
            ValueError: If ``X`` and ``y`` contain a different number of samples.
        """
        inputs = self._to_float_tensor(X)
        # Class-index targets give the same loss as one-hot rows while
        # avoiding a float64/float32 dtype mismatch and the extra matrix.
        targets = torch.from_numpy(np.array(y, dtype=np.int64))
        if inputs.shape[0] != targets.shape[0]:
            raise ValueError(
                "X and y must contain the same number of samples: "
                f"{inputs.shape[0]} != {targets.shape[0]}"
            )

        self.model.train()
        for epoch in range(self.n_epochs):
            self.optimizer.zero_grad()

            logits = self.model(inputs)
            loss = self.criterion(logits, targets)
            loss.backward()
            self.optimizer.step()

            if (epoch+1) % 100 == 0:
                print("Epoch", epoch+1, "/", self.n_epochs, ":", round(loss.item(), 5), "loss", flush=True)
        return self

    def predict(self, X):
        """Return the predicted class index for every row of ``X``.

        Args:
            X: pandas DataFrame or array-like of shape (n_samples, n_features).

        Returns:
            1-D numpy array of integer class indices (argmax over the logits).
        """
        self.model.eval()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(self._to_float_tensor(X)).cpu().numpy()
        return np.argmax(logits, axis=1)

    def get_model(self):
        """Return the underlying ``nn.Sequential`` network."""
        return self.model

bin/metafx-scripts/predict.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,24 @@
44
import pandas as pd
55
from joblib import load
66
from sklearn.metrics import classification_report
7+
import torch
78

89

910
if __name__ == "__main__":
1011
features = pd.read_csv(sys.argv[1], header=0, index_col=0, sep="\t")
1112
outName = sys.argv[2]
12-
model = load(sys.argv[3])
13-
metadata = None
1413
model_type = sys.argv[4]
15-
if model_type == "XGB":
14+
15+
if model_type == "RF":
16+
model = load(sys.argv[3])
17+
elif model_type == "XGB":
18+
model = load(sys.argv[3])
19+
le = load(sys.argv[3][:-7] + "_le.joblib")
20+
elif model_type == "Torch":
21+
model = torch.load(sys.argv[3])
1622
le = load(sys.argv[3][:-7] + "_le.joblib")
23+
24+
metadata = None
1725
if len(sys.argv) == 6:
1826
metadata = pd.read_csv(sys.argv[5], sep="\t", header=None, index_col=0, dtype=str)
1927
metadata.index = metadata.index.astype(str)
@@ -24,7 +32,7 @@
2432
X = features.T
2533
y_pred = model.predict(X)
2634

27-
if model_type == "XGB":
35+
if model_type == "XGB" or model_type == "Torch":
2836
y_pred = le.inverse_transform(y_pred)
2937

3038
outFile = open(outName + ".tsv", "w")

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ scikit-learn==1.3.0
44
matplotlib==3.8.2
55
joblib==1.2.0
66
ete3==3.1.3
7-
xgboost==2.0.3
7+
xgboost==2.0.3
8+
torch==2.2.0

0 commit comments

Comments
 (0)