forked from zoogzog/chexnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ChexnetDownload.py
46 lines (33 loc) · 1.43 KB
/
ChexnetDownload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import numpy as np
import time
import sys
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as tfunc
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as func
from sklearn.metrics.ranking import roc_auc_score
from DensenetModels import DenseNet121
from DensenetModels import DenseNet169
from DensenetModels import DenseNet201
from ResnetModels import ResNet18
from ResnetModels import ResNet50
from DatasetGenerator import DatasetGenerator
#--------------------------------------------------------------------------------
def download(nnArchitecture, nnIsTrained, nnClassCount):
    """Instantiate one of the supported CheXNet backbones and return it.

    ``nnClassCount`` and ``nnIsTrained`` are forwarded, in that order, to the
    selected model constructor. The function name suggests the purpose is to
    trigger a one-time fetch of pretrained weights when ``nnIsTrained`` is
    True. NOTE(review): confirm that behaviour against
    DensenetModels / ResnetModels.

    Args:
        nnArchitecture: one of 'DENSE-NET-121', 'DENSE-NET-169',
            'DENSE-NET-201', 'RES-NET-18', 'RES-NET-50'. Matching is exact
            and case-sensitive.
        nnIsTrained: passed through as the constructor's second argument.
        nnClassCount: passed through as the constructor's first argument.
            This is presumably the number of output classes; ``__main__``
            uses 14.

    Returns:
        The constructed model wrapped in ``torch.nn.DataParallel``.

    Raises:
        ValueError: if ``nnArchitecture`` is not one of the names above.
            Previously this surfaced as an ``UnboundLocalError`` on the
            DataParallel line.
    """
    #-------------------- SETTINGS: NETWORK ARCHITECTURE
    # Plain if/elif (rather than a dict built up front) so that only the
    # requested constructor is looked up.
    if nnArchitecture == 'DENSE-NET-121':
        model = DenseNet121(nnClassCount, nnIsTrained)
    elif nnArchitecture == 'DENSE-NET-169':
        model = DenseNet169(nnClassCount, nnIsTrained)
    elif nnArchitecture == 'DENSE-NET-201':
        model = DenseNet201(nnClassCount, nnIsTrained)
    elif nnArchitecture == 'RES-NET-18':
        model = ResNet18(nnClassCount, nnIsTrained)
    elif nnArchitecture == 'RES-NET-50':
        model = ResNet50(nnClassCount, nnIsTrained)
    else:
        # Fail loudly with the accepted names instead of falling through
        # to an unbound `model`.
        raise ValueError(
            f"Unknown architecture {nnArchitecture!r}; expected one of "
            "'DENSE-NET-121', 'DENSE-NET-169', 'DENSE-NET-201', "
            "'RES-NET-18', 'RES-NET-50'"
        )

    model = torch.nn.DataParallel(model)
    return model
if __name__ == "__main__":
    # Script entry point: build the DenseNet-121 backbone with nnIsTrained
    # set and 14 classes. To target another backbone, pass a different
    # supported name instead (e.g. "RES-NET-18").
    download(nnArchitecture="DENSE-NET-121", nnIsTrained=True, nnClassCount=14)