diff --git a/PC_model/pc_model.py b/PC_model/pc_model.py index 2343445182..e1de99538b 100644 --- a/PC_model/pc_model.py +++ b/PC_model/pc_model.py @@ -3,6 +3,7 @@ import numpy as np import sys import os +import joblib sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) from image_processing.gamma_correction import gamma_correction @@ -46,6 +47,9 @@ def train(self, train_x, train_y): def test(self, test_x): return self.xgb.predict(test_x), self.ovr.predict(test_x), self.ovo.predict(test_x), self.knn.predict(test_x), self.lr.predict(test_x), self.voting.predict(test_x), self.rfc.predict(test_x) + def save(self): + joblib.dump(value=self, filename=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_model.pkl")) + #%% train_df = pd.read_csv("/Users/ohs/Desktop/capstone/personal_color_dataset/train/new_data.csv") diff --git a/main.py b/main.py index 9f2729d1d5..3d38cc2d23 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,5 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt - -train_path = "./train" -test_path = "./test" +import joblib diff --git a/test_model.pkl b/test_model.pkl new file mode 100644 index 0000000000..d3e85df13d Binary files /dev/null and b/test_model.pkl differ