-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
96 lines (78 loc) · 4.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import tensorflow as tf
import warnings
import argparse
import matplotlib.pyplot as plt
from utils.data_loader import load_data
from utils.data_preprocessing import visualize_data_distribution, plot_sample_images
from models.vgg16_feature_extractor import extract_vgg16_features_from_generator
from models.autoencoder import apply_autoencoder
from sklearn.metrics import classification_report, confusion_matrix
from models.knn_svm_classifier import KNNSVMClassifier
import joblib # To save the SVM model
# Quiet down TensorFlow / CUDA console noise for the whole process.
# NOTE(review): the environment variable below is assigned after
# `import tensorflow` has already executed at the top of this file; TensorFlow's
# native logger may need it set *before* the import to take full effect -
# confirm and move the assignment above the import if so.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress INFO and WARNING messages
tf.get_logger().setLevel('ERROR') # Python-side TensorFlow logger: only ERROR and above
warnings.filterwarnings("ignore", category=UserWarning, message=".*?CUDA.*?") # Ignore UserWarnings whose message matches the CUDA pattern
def check_gpu():
    """Print whether TensorFlow reports at least one physical GPU device."""
    has_gpu = bool(tf.config.list_physical_devices('GPU'))
    status = (
        "GPU is available and will be used."
        if has_gpu
        else "No GPU detected. Using CPU."
    )
    print(status)
# Setup MirroredStrategy for multiple GPUs
def setup_strategy():
    """Pick a distribution strategy based on how many GPUs are visible.

    Returns:
        A ``tf.distribute.MirroredStrategy`` when more than one physical GPU
        is listed by TensorFlow, otherwise ``None``.
    """
    gpu_count = len(tf.config.list_physical_devices('GPU'))

    # Guard clause: zero or one GPU needs no distribution strategy.
    if gpu_count <= 1:
        print("Single GPU or CPU is being used.")
        return None

    mirrored = tf.distribute.MirroredStrategy()
    print("Using MirroredStrategy for multi-GPU training.")
    return mirrored
def main(args):
    """Run the end-to-end pipeline: load, extract, reduce, classify, evaluate.

    Steps, in order:
      1. Report GPU availability and choose a distribution strategy.
      2. Load train/test generators from ``args.data_dir``.
      3. Extract VGG16 features for both splits.
      4. Reduce feature dimensionality with the autoencoder.
      5. Fit the kNN-SVM classifier and persist it (joblib) together with the
         autoencoder (HDF5) under ``saved_model/``.
      6. Predict on the test split and print a classification report and a
         confusion matrix.

    Args:
        args: Parsed command-line namespace. This function reads
            ``data_dir``, ``img_size``, ``batch_size``, ``epochs`` and
            ``k_neighbors``.

    NOTE(review): the command-line parser in this file also defines
    ``--svm_c`` and ``--eda``; neither value is read here, so they currently
    have no effect. Wire them in once the classifier / EDA helper signatures
    are confirmed.
    """
    # Local import keeps this fix self-contained (no file-level import change).
    from contextlib import nullcontext

    # Check GPU availability
    check_gpu()

    # Setup strategy
    strategy = setup_strategy()

    # With no multi-GPU strategy, fall back to a no-op context so TensorFlow's
    # default device placement applies. Pinning to '/CPU:0' here would force
    # a single-GPU machine onto the CPU, contradicting the messages printed
    # by check_gpu() and setup_strategy().
    run_scope = strategy.scope() if strategy is not None else nullcontext()

    with run_scope:
        # Load data
        print("################ Loading data using ImageDataGenerator... ########################")
        train_generator, test_generator = load_data(
            args.data_dir,
            img_size=(args.img_size, args.img_size),
            batch_size=args.batch_size,
        )

        # Feature extraction
        print("################ Extracting features using VGG16... ################")
        X_train_features, y_train = extract_vgg16_features_from_generator(
            train_generator, batch_size=args.batch_size
        )
        X_test_features, y_test = extract_vgg16_features_from_generator(
            test_generator, batch_size=args.batch_size
        )

        # Dimensionality reduction with Autoencoder
        print("################ Reducing dimensionality using Autoencoder... ################")
        X_train_reduced, X_test_reduced, autoencoder = apply_autoencoder(
            X_train_features,
            X_test_features,
            batch_size=args.batch_size,
            epochs=args.epochs,
        )

        # Classification with kNN-SVM
        print("################ Training kNN-SVM classifier... ################")
        classifier = KNNSVMClassifier(k=args.k_neighbors, kernel='linear')
        classifier.fit(X_train_reduced, y_train)

        # Persist both fitted models so they can be reloaded without retraining.
        os.makedirs('saved_model', exist_ok=True)
        joblib.dump(classifier, 'saved_model/knn_svm_classifier.joblib')
        autoencoder.save('saved_model/autoencoder.h5')

        # Prediction
        print("################ Predicting on test data... ################")
        y_pred = classifier.predict(X_test_reduced)

        # Evaluation
        print("################ Classification Report: ################")
        print(classification_report(y_test, y_pred))
        print("################ Confusion Matrix: ################")
        print(confusion_matrix(y_test, y_pred))
def _build_arg_parser():
    """Create the command-line parser for this script.

    Options are registered from a table, in the same order and with the same
    types, defaults and help text as before.
    """
    cli = argparse.ArgumentParser(
        description="################ kNN-SVM with VGG16 Features for COVID-19 Pneumonia Detection ################"
    )
    option_table = (
        # Data parameters
        ('--data_dir', {'type': str, 'default': 'data/',
                        'help': 'Directory for data (with train/test folders)'}),
        ('--img_size', {'type': int, 'default': 224,
                        'help': 'Image size for VGG16 input'}),
        # Model parameters
        ('--batch_size', {'type': int, 'default': 32,
                          'help': 'Batch size for feature extraction and training'}),
        ('--epochs', {'type': int, 'default': 20,
                      'help': 'Number of epochs for autoencoder training'}),
        ('--k_neighbors', {'type': int, 'default': 5,
                           'help': 'Number of neighbors for kNN'}),
        ('--svm_c', {'type': float, 'default': 1.0,
                     'help': 'C parameter for SVM'}),
        # Miscellaneous
        ('--eda', {'action': 'store_true',
                   'help': 'Perform exploratory data analysis'}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli


if __name__ == '__main__':
    # Parse the command line and hand the namespace to the pipeline.
    main(_build_arg_parser().parse_args())