8  |  8 | from sklearn.metrics import classification_report
9  |  9 | from xgboost import XGBClassifier
10 | 10 | from sklearn import preprocessing
   | 11 | +from metafx_torch import TorchLinearModel
   | 12 | +import torch
11 | 13 |
12 | 14 | if __name__ == "__main__":
13 | 15 |     features = pd.read_csv(sys.argv[1], header=0, index_col=0, sep="\t")

23 | 25 |     M = features.shape[0]  # features count
24 | 26 |     N = features.shape[1]  # samples count
25 | 27 |
26 |    | -    model = RandomForestClassifier(n_estimators=100) if sys.argv[4] == "RF" else XGBClassifier(n_estimators=100)
27 | 28 |     X = features.T
28 | 29 |     y = np.array([metadata.loc[i, 1] for i in X.index])
29 | 30 |
   | 31 | +    model = None
   | 32 | +    if sys.argv[4] == "RF":
   | 33 | +        model = RandomForestClassifier(n_estimators=100)
   | 34 | +    elif sys.argv[4] == "XGB":
   | 35 | +        model = XGBClassifier(n_estimators=100)
   | 36 | +    else:
   | 37 | +        model = TorchLinearModel(n_features=M, n_classes=len(set(y)))
   | 38 | +
30 | 39 |     if sys.argv[4] == "XGB":
31 | 40 |         le = preprocessing.LabelEncoder()
32 | 41 |         le.fit(y)
33 | 42 |         y = le.transform(y)
   | 43 | +    elif sys.argv[4] == "Torch":
   | 44 | +        le = preprocessing.LabelEncoder()
   | 45 | +        le.fit(y)
   | 46 | +        y = le.transform(y)
34 | 47 |
35 | 48 |     model.fit(X, y)
36 |    | -    dump(model, outName + ".joblib")
37 | 49 |
38 |    | -    if sys.argv[4] == "XGB":
   | 50 | +    if sys.argv[4] == "RF":
   | 51 | +        dump(model, outName + ".joblib")
   | 52 | +    elif sys.argv[4] == "XGB":
   | 53 | +        dump(model, outName + ".joblib")
   | 54 | +        dump(le, outName + "_le.joblib")
   | 55 | +    elif sys.argv[4] == "Torch":
   | 56 | +        torch.save(model, outName + ".joblib")
39 | 57 |         dump(le, outName + "_le.joblib")
40 | 58 |
41 | 59 |     print("Model accuracy after training:")
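
Note on the new dependency: metafx_torch.TorchLinearModel is not shown in this diff, but the script constructs it with n_features and n_classes and then calls model.fit(X, y) on it exactly as it does on RandomForestClassifier and XGBClassifier, so it must expose a scikit-learn-style interface. The sketch below is only an assumption of what such a wrapper could look like (a single linear layer trained with cross-entropy behind fit()/predict()); the optimizer, epoch count, and learning rate are illustrative and not taken from the real module.

# Hypothetical sketch only; everything beyond the constructor arguments
# (n_features, n_classes) seen in the diff is assumed.
import numpy as np
import torch
import torch.nn as nn

class TorchLinearModel(nn.Module):
    # Single linear layer trained with cross-entropy, wrapped in a
    # scikit-learn-style fit()/predict() interface so the training script
    # can treat it like RandomForestClassifier or XGBClassifier.
    def __init__(self, n_features, n_classes, epochs=100, lr=1e-3):
        super().__init__()
        self.linear = nn.Linear(n_features, n_classes)
        self.epochs = epochs
        self.lr = lr

    def forward(self, x):
        return self.linear(x)

    def fit(self, X, y):
        # X may be a pandas DataFrame (features.T in the script); y is the
        # integer-encoded label vector produced by LabelEncoder.
        X_t = torch.as_tensor(np.asarray(X, dtype=np.float32))
        y_t = torch.as_tensor(np.asarray(y, dtype=np.int64))
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(self.epochs):
            optimizer.zero_grad()
            loss = loss_fn(self(X_t), y_t)
            loss.backward()
            optimizer.step()
        return self

    def predict(self, X):
        X_t = torch.as_tensor(np.asarray(X, dtype=np.float32))
        with torch.no_grad():
            return self(X_t).argmax(dim=1).numpy()

Because the wrapper subclasses nn.Module, the torch.save(model, ...) call in the diff pickles the whole module object, so the class definition must be importable under the same name when the file is loaded back.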
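A related caveat: the Torch branch writes the model with torch.save but reuses the .joblib extension, so it has to be read back with torch.load rather than joblib.load. A rough, assumed reload sequence follows; the "model" prefix stands in for whatever outName was at training time, and predict() comes from the sketch above.

# Assumed reload of the artifacts written by the script.
from joblib import load
import numpy as np
import torch

out_name = "model"                      # stand-in for the outName prefix
le = load(out_name + "_le.joblib")      # fitted LabelEncoder, written with dump()
# The model file is a full module pickle despite the .joblib suffix, so it
# goes through torch.load; newer PyTorch defaults to weights_only=True,
# hence the explicit weights_only=False.
clf = torch.load(out_name + ".joblib", weights_only=False)

X_new = np.random.rand(5, clf.linear.in_features).astype(np.float32)
print(le.inverse_transform(clf.predict(X_new)))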