-
Notifications
You must be signed in to change notification settings - Fork 0
/
39 Diseases Classif with Pytorch & EfficientNet
1 lines (1 loc) · 17.7 KB
/
39 Diseases Classif with Pytorch & EfficientNet
1
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"}},"nbformat_minor":4,"nbformat":4,"cells":[
{"cell_type":"code","source":"import numpy as np\nimport pandas as pd\nimport os\nimport glob\nfrom tqdm.notebook import tqdm\nfrom PIL import Image","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2021-11-05T06:06:16.07647Z","iopub.execute_input":"2021-11-05T06:06:16.076918Z","iopub.status.idle":"2021-11-05T06:06:16.086258Z","shell.execute_reply.started":"2021-11-05T06:06:16.076877Z","shell.execute_reply":"2021-11-05T06:06:16.085394Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"os.listdir('/kaggle/input/fundusimage1000/1000images/1000images')","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:04:27.978666Z","iopub.execute_input":"2021-11-05T06:04:27.979046Z","iopub.status.idle":"2021-11-05T06:04:27.99825Z","shell.execute_reply.started":"2021-11-05T06:04:27.979011Z","shell.execute_reply":"2021-11-05T06:04:27.997214Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"import torch\nimport torchvision\nfrom torchvision import transforms\nfrom torch.utils.data import DataLoader,Dataset\nimport torch.nn as nn\nimport torch.optim as optim\n\ndevice='cuda' if torch.cuda.is_available() else 'cpu'\ndevice","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:18.684038Z","iopub.execute_input":"2021-11-05T06:12:18.684853Z","iopub.status.idle":"2021-11-05T06:12:18.692349Z","shell.execute_reply.started":"2021-11-05T06:12:18.684804Z","shell.execute_reply":"2021-11-05T06:12:18.691491Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"class custom_dataset(Dataset):\n    \"\"\"Image-folder dataset: every file under root_dir/CLASS_NAME/ is one sample.\n\n    root_dir: directory holding one sub-directory per class.\n    transform: optional callable applied to the PIL image in __getitem__.\n    \"\"\"\n\n    def __init__(self,root_dir,transform=None):\n        self.data=[]\n        self.transform=transform\n\n        # Sorted so the sample order (and hence the seeded random_split) is reproducible.\n        for img_path in tqdm(sorted(glob.glob(root_dir+\"/*/**\"))):\n            if not os.path.isfile(img_path):\n                continue\n            class_name=img_path.split(\"/\")[-2]\n            self.data.append([img_path,class_name])\n\n        # Only class directories, sorted, so the name-to-index mapping is stable across runs.\n        class_names=sorted(\n            name for name in os.listdir(root_dir)\n            if os.path.isdir(os.path.join(root_dir,name))\n        )\n        self.class_map={name:index for index,name in enumerate(class_names)}\n        print(f\"Total Classes:{len(self.class_map)}\")\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self,idx):\n        img_path,class_name=self.data[idx]\n        # Force 3 channels: grayscale/RGBA files would otherwise break batching and Normalize.\n        img=Image.open(img_path).convert(\"RGB\")\n        class_id=torch.tensor(self.class_map[class_name])\n\n        if self.transform:\n            img=self.transform(img)\n\n        return img,class_id","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:19.447105Z","iopub.execute_input":"2021-11-05T06:12:19.447627Z","iopub.status.idle":"2021-11-05T06:12:19.456966Z","shell.execute_reply.started":"2021-11-05T06:12:19.447586Z","shell.execute_reply":"2021-11-05T06:12:19.456248Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"root_dir=r'/kaggle/input/fundusimage1000/1000images/1000images'","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:20.561075Z","iopub.execute_input":"2021-11-05T06:12:20.561834Z","iopub.status.idle":"2021-11-05T06:12:20.5659Z","shell.execute_reply.started":"2021-11-05T06:12:20.561783Z","shell.execute_reply":"2021-11-05T06:12:20.56501Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def create_transforms(normalize=False,mean=[0,0,0],std=[1,1,1]):\n    \"\"\"Return the torchvision preprocessing pipeline.\n\n    normalize=True: Resize to 224x224, random horizontal flip, ToTensor, Normalize(mean, std).\n    normalize=False: Resize to 512x512, random horizontal flip, ToTensor (no normalisation).\n    \"\"\"\n    if normalize:\n        my_transforms=transforms.Compose([\n            transforms.Resize((224,224)),\n            # transforms.ColorJitter(brightness=0.3,saturation=0.5,contrast=0.7,),\n            # transforms.RandomRotation(degrees=33),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=mean,std=std)\n        ])\n    else:\n        my_transforms=transforms.Compose([\n            transforms.Resize((512,512)),\n            # transforms.ColorJitter(brightness=0.3,saturation=0.5,contrast=0.7,p=0.57),\n            # transforms.RandomRotation(degrees=33),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor()])\n\n    return my_transforms","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:48.414071Z","iopub.execute_input":"2021-11-05T06:12:48.41434Z","iopub.status.idle":"2021-11-05T06:12:48.420981Z","shell.execute_reply.started":"2021-11-05T06:12:48.414311Z","shell.execute_reply":"2021-11-05T06:12:48.420133Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"BS=8\nnum_classes=39\nVAL_FRACTION=0.2  # share of the images held out for validation\nSEED=7            # seed for the train/validation split","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:52.039427Z","iopub.execute_input":"2021-11-05T06:12:52.039983Z","iopub.status.idle":"2021-11-05T06:12:52.045058Z","shell.execute_reply.started":"2021-11-05T06:12:52.039942Z","shell.execute_reply":"2021-11-05T06:12:52.043323Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def build_loaders(transform,batch_size=None,val_fraction=None,seed=None):\n    \"\"\"Create the dataset with `transform` and return (dataset, train_loader, val_loader).\n\n    Split sizes are derived from len(dataset) rather than hard-coded, so the notebook\n    keeps working if the number of images changes. The seeded generator keeps the\n    train/validation membership identical between calls.\n    Unset arguments fall back to BS, VAL_FRACTION and SEED.\n    \"\"\"\n    batch_size=BS if batch_size is None else batch_size\n    val_fraction=VAL_FRACTION if val_fraction is None else val_fraction\n    seed=SEED if seed is None else seed\n\n    dataset=custom_dataset(root_dir,transform)\n    n_val=int(round(len(dataset)*val_fraction))\n    n_train=len(dataset)-n_val\n    train_set,val_set=torch.utils.data.random_split(\n        dataset,[n_train,n_val],generator=torch.Generator().manual_seed(seed))\n\n    train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True)\n    # Evaluation gains nothing from shuffling.\n    val_loader=DataLoader(val_set,batch_size=batch_size,shuffle=False)\n    return dataset,train_loader,val_loader","metadata":{},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"my_transforms=create_transforms(normalize=False)\ndataset,train_loader,val_loader=build_loaders(my_transforms)\nprint(len(dataset))","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:54.156484Z","iopub.execute_input":"2021-11-05T06:12:54.157157Z","iopub.status.idle":"2021-11-05T06:12:54.239993Z","shell.execute_reply.started":"2021-11-05T06:12:54.157122Z","shell.execute_reply":"2021-11-05T06:12:54.239053Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def get_mean_std(loader):\n    \"\"\"Return the per-channel (mean, std) of the image batches yielded by `loader`.\n\n    Uses var = E[x^2] - (E[x])^2. Each batch is weighted by its number of images so a\n    shorter final batch is not over-represented.\n    \"\"\"\n    channels_sum,channels_squared_sum,num_images=0,0,0\n    for data,_ in tqdm(loader):\n        batch_images=data.size(0)\n        # Reduce over batch, height and width: we want one statistic per channel,\n        # not a single mean for all 3 channels (in case of RGB).\n        channels_sum+=torch.mean(data,dim=[0,2,3])*batch_images\n        channels_squared_sum+=torch.mean(data**2,dim=[0,2,3])*batch_images\n        num_images+=batch_images\n\n    if num_images==0:\n        raise ValueError(\"loader yielded no batches\")\n\n    mean=channels_sum/num_images\n    # Clamp: rounding error can push the variance of a near-constant channel below zero.\n    std=(channels_squared_sum/num_images-mean**2).clamp(min=0)**0.5\n\n    return mean, std","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:12:56.703348Z","iopub.execute_input":"2021-11-05T06:12:56.703854Z","iopub.status.idle":"2021-11-05T06:12:56.709906Z","shell.execute_reply.started":"2021-11-05T06:12:56.703814Z","shell.execute_reply":"2021-11-05T06:12:56.708781Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"mean,std=get_mean_std(train_loader)\nprint(mean, std)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:13:03.699286Z","iopub.execute_input":"2021-11-05T06:13:03.699836Z","iopub.status.idle":"2021-11-05T06:14:25.072606Z","shell.execute_reply.started":"2021-11-05T06:13:03.699796Z","shell.execute_reply":"2021-11-05T06:14:25.07167Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"# Since these are medical images (different from ImageNet data), use the mean/std calculated above.\nmy_transforms=create_transforms(normalize=True,mean=mean,std=std)\ndataset,train_loader,val_loader=build_loaders(my_transforms)\nprint(len(dataset))","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:14:25.074747Z","iopub.execute_input":"2021-11-05T06:14:25.075019Z","iopub.status.idle":"2021-11-05T06:14:25.13569Z","shell.execute_reply.started":"2021-11-05T06:14:25.074983Z","shell.execute_reply":"2021-11-05T06:14:25.13494Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"import matplotlib.pyplot as plt\n\ndataiter = iter(train_loader)\n# Use the builtin next(): DataLoader iterators do not provide a .next() method on current PyTorch.\nimages, labels = next(dataiter)\n\nprint(images.shape)\nprint(labels.shape)\n\n# Undo Normalize so imshow receives values in [0, 1] instead of clipping them.\nsample=images[0]*torch.as_tensor(std).view(-1,1,1)+torch.as_tensor(mean).view(-1,1,1)\nplt.imshow(sample.clamp(0,1).permute(1, 2, 0));","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:14:25.136976Z","iopub.execute_input":"2021-11-05T06:14:25.137391Z","iopub.status.idle":"2021-11-05T06:14:26.016073Z","shell.execute_reply.started":"2021-11-05T06:14:25.137354Z","shell.execute_reply":"2021-11-05T06:14:26.015401Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"## Modelling","metadata":{}},
{"cell_type":"markdown","source":"#### VGG16\nfrozen inner layers, top trained","metadata":{}},
{"cell_type":"code","source":"vgg_model=torchvision.models.vgg16(pretrained=True)\nprint(vgg_model)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:14:44.986915Z","iopub.execute_input":"2021-11-05T06:14:44.987593Z","iopub.status.idle":"2021-11-05T06:14:48.942615Z","shell.execute_reply.started":"2021-11-05T06:14:44.987533Z","shell.execute_reply":"2021-11-05T06:14:48.941759Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"vgg_model=torchvision.models.vgg16(pretrained=True)\n\n# Freeze the pretrained weights; only the replacement classifier below is trained.\nfor param in vgg_model.parameters():\n    param.requires_grad=False\n\nclass Identity(nn.Module):\n    \"\"\"Pass-through module that returns its input unchanged.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self,x):\n        return x\n\n# vgg_model.avgpool=Identity()\nvgg_model.classifier=nn.Sequential(\n    nn.Linear(25088,2048),  # 25088 = 512*7*7, the flattened output of VGG16's avgpool\n    nn.ReLU(),\n    nn.Dropout(p=0.37),\n    nn.Linear(2048,1024),\n    nn.ReLU(),\n    nn.Dropout(p=0.5),\n    nn.Linear(1024,num_classes)\n)\n\nvgg_model.to(device)\n\n# model.features[30]=nn.AdaptiveAvgPool2d((16,16))\n\n# print(model.features)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:18:15.882038Z","iopub.execute_input":"2021-11-05T06:18:15.882336Z","iopub.status.idle":"2021-11-05T06:18:17.763373Z","shell.execute_reply.started":"2021-11-05T06:18:15.882305Z","shell.execute_reply":"2021-11-05T06:18:17.762548Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"EPOCHS=5\nLR=1e-3","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:18:20.085249Z","iopub.execute_input":"2021-11-05T06:18:20.085508Z","iopub.status.idle":"2021-11-05T06:18:20.088997Z","shell.execute_reply.started":"2021-11-05T06:18:20.085479Z","shell.execute_reply":"2021-11-05T06:18:20.088319Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def train_model(model,loader=None,epochs=None,lr=None):\n    \"\"\"Train `model` with cross-entropy loss, AdamW and a StepLR schedule.\n\n    model: network already moved to `device`.\n    loader: training DataLoader; defaults to the notebook-level train_loader.\n    epochs: number of epochs; defaults to EPOCHS.\n    lr: initial learning rate; defaults to LR.\n    \"\"\"\n    loader=train_loader if loader is None else loader\n    epochs=EPOCHS if epochs is None else epochs\n    lr=LR if lr is None else lr\n\n    criterion = nn.CrossEntropyLoss()\n    # Hand only trainable tensors to the optimizer; frozen layers never receive gradients.\n    trainable_params=[p for p in model.parameters() if p.requires_grad]\n    optimizer = optim.AdamW(trainable_params, lr=lr)\n    # The learning rate is printed below instead of using the scheduler's deprecated verbose flag.\n    scheduler=optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)\n\n    # Make sure Dropout/BatchNorm are in training mode whatever ran before this call.\n    model.train()\n\n    for epoch in range(epochs):\n        losses=[]\n        print(f\"Epoch {epoch+1}/{epochs}:\")\n        loop=tqdm(enumerate(loader),total=len(loader))\n        for batch_idx,(data,targets) in loop:\n            data=data.to(device)\n            targets=targets.to(device)\n\n            #forward\n            scores=model(data)\n            loss=criterion(scores,targets)\n\n            losses.append(loss.item())\n\n            #backward\n            optimizer.zero_grad()\n            loss.backward()\n\n            #gradient descent/adam step\n            optimizer.step()\n\n        scheduler.step()\n        # An empty loader would otherwise leave mean_loss undefined.\n        mean_loss=sum(losses)/len(losses) if losses else float('nan')\n\n        print(f\"Loss at Epoch {epoch+1}:\\t{mean_loss:.5f}\\n\")\n        print(f\"Learning rate for next epoch: {scheduler.get_last_lr()[0]:.6f}\")","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:18:20.837144Z","iopub.execute_input":"2021-11-05T06:18:20.838175Z","iopub.status.idle":"2021-11-05T06:18:20.846407Z","shell.execute_reply.started":"2021-11-05T06:18:20.838127Z","shell.execute_reply":"2021-11-05T06:18:20.845613Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def check_accuracy(loader, model):\n    \"\"\"Print and return the top-1 accuracy (in percent) of `model` on `loader`.\n\n    The model is switched to eval mode for the pass and then put back into\n    whichever mode it was in when the function was called.\n    \"\"\"\n    num_correct = 0\n    num_samples = 0\n    was_training = model.training\n    model.eval()\n\n    with torch.no_grad():\n        for x, y in tqdm(loader):\n            x = x.to(device=device)\n            y = y.to(device=device)\n\n            scores = model(x)\n            _, predictions = scores.max(1)\n            num_correct += (predictions == y).sum().item()\n            num_samples += predictions.size(0)\n\n    # Guard the division so an empty loader reports 0 instead of raising.\n    accuracy = float(num_correct)/float(num_samples)*100 if num_samples else 0.0\n    print(\n        f\"Got {num_correct} / {num_samples} with accuracy {accuracy:.2f}\"\n    )\n\n    model.train(was_training)\n    return accuracy","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:18:23.377167Z","iopub.execute_input":"2021-11-05T06:18:23.377964Z","iopub.status.idle":"2021-11-05T06:18:23.384823Z","shell.execute_reply.started":"2021-11-05T06:18:23.377911Z","shell.execute_reply":"2021-11-05T06:18:23.384059Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"train_model(vgg_model)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:18:27.501953Z","iopub.execute_input":"2021-11-05T06:18:27.502704Z","iopub.status.idle":"2021-11-05T06:31:03.737677Z","shell.execute_reply.started":"2021-11-05T06:18:27.502665Z","shell.execute_reply":"2021-11-05T06:31:03.736942Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"print(\"Training accuracy:\",end='\\t')\ncheck_accuracy(train_loader, vgg_model)\nprint(\"Validation accuracy:\",end='\\t')\ncheck_accuracy(val_loader, vgg_model);","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:31:03.739517Z","iopub.execute_input":"2021-11-05T06:31:03.740026Z","iopub.status.idle":"2021-11-05T06:32:35.97291Z","shell.execute_reply.started":"2021-11-05T06:31:03.739977Z","shell.execute_reply":"2021-11-05T06:32:35.972159Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"#### EfficientNet B3\n\nfull training","metadata":{}},
{"cell_type":"code","source":"%pip install -q efficientnet_pytorch==0.7.1","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:32:35.974331Z","iopub.execute_input":"2021-11-05T06:32:35.974619Z","iopub.status.idle":"2021-11-05T06:32:43.230226Z","shell.execute_reply.started":"2021-11-05T06:32:35.974582Z","shell.execute_reply":"2021-11-05T06:32:43.229313Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"from efficientnet_pytorch import EfficientNet","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:32:43.232788Z","iopub.execute_input":"2021-11-05T06:32:43.233109Z","iopub.status.idle":"2021-11-05T06:32:43.366257Z","shell.execute_reply.started":"2021-11-05T06:32:43.233067Z","shell.execute_reply":"2021-11-05T06:32:43.365606Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def create_eff_net(version='b1',trainable=False,pretrained=True):\n    \"\"\"Build an EfficientNet with a fresh classification head for num_classes outputs.\n\n    version: EfficientNet variant suffix, e.g. 'b1' or 'b3'.\n    trainable: whether the backbone weights receive gradients; the new head is always trainable.\n    pretrained: load pretrained weights (default) instead of a random initialisation.\n        Freezing a randomly initialised backbone would give the head meaningless features.\n    \"\"\"\n    model_name=f'efficientnet-{version}'\n    if pretrained:\n        eff_model = EfficientNet.from_pretrained(model_name)\n    else:\n        eff_model = EfficientNet.from_name(model_name)\n\n    # Set the backbone's gradient flag before the head is replaced, so the new head\n    # keeps the default requires_grad=True.\n    for param in eff_model.parameters():\n        param.requires_grad = trainable\n\n    num_ftrs = eff_model._fc.in_features\n\n    eff_model._fc = nn.Sequential(\n        nn.Linear(num_ftrs,1024),\n        nn.ReLU(),\n        nn.Dropout(p=0.37),\n        nn.Linear(1024,num_classes)\n    )\n\n    eff_model.to(device)\n\n    return eff_model","metadata":{"execution":{"iopub.status.busy":"2021-11-05T07:20:46.914344Z","iopub.execute_input":"2021-11-05T07:20:46.914625Z","iopub.status.idle":"2021-11-05T07:20:46.920971Z","shell.execute_reply.started":"2021-11-05T07:20:46.914591Z","shell.execute_reply":"2021-11-05T07:20:46.919942Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"eff_model=create_eff_net(version='b3',trainable=True)","metadata":{},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"train_model(eff_model)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:32:43.444988Z","iopub.execute_input":"2021-11-05T06:32:43.446719Z","iopub.status.idle":"2021-11-05T06:46:54.154238Z","shell.execute_reply.started":"2021-11-05T06:32:43.446682Z","shell.execute_reply":"2021-11-05T06:46:54.15352Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"print(\"Training accuracy:\",end='\\t')\ncheck_accuracy(train_loader, eff_model)\nprint(\"Validation accuracy:\",end='\\t')\ncheck_accuracy(val_loader, eff_model);","metadata":{"execution":{"iopub.status.busy":"2021-11-05T06:47:37.610294Z","iopub.execute_input":"2021-11-05T06:47:37.610581Z","iopub.status.idle":"2021-11-05T06:49:12.015605Z","shell.execute_reply.started":"2021-11-05T06:47:37.610536Z","shell.execute_reply":"2021-11-05T06:49:12.014778Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"Freeze inner layers, train top now","metadata":{}},
{"cell_type":"code","source":"eff_model2=create_eff_net(version='b3',trainable=False)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T07:21:19.829385Z","iopub.execute_input":"2021-11-05T07:21:19.830158Z","iopub.status.idle":"2021-11-05T07:21:19.984332Z","shell.execute_reply.started":"2021-11-05T07:21:19.830109Z","shell.execute_reply":"2021-11-05T07:21:19.983596Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"train_model(eff_model2)","metadata":{"execution":{"iopub.status.busy":"2021-11-05T07:21:20.982228Z","iopub.execute_input":"2021-11-05T07:21:20.982522Z","iopub.status.idle":"2021-11-05T07:34:01.408204Z","shell.execute_reply.started":"2021-11-05T07:21:20.982488Z","shell.execute_reply":"2021-11-05T07:34:01.407501Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"print(\"Training accuracy:\",end='\\t')\ncheck_accuracy(train_loader, eff_model2)\nprint(\"Validation accuracy:\",end='\\t')\ncheck_accuracy(val_loader, eff_model2);","metadata":{"execution":{"iopub.status.busy":"2021-11-05T07:34:01.410895Z","iopub.execute_input":"2021-11-05T07:34:01.411703Z","iopub.status.idle":"2021-11-05T07:35:35.513267Z","shell.execute_reply.started":"2021-11-05T07:34:01.411663Z","shell.execute_reply":"2021-11-05T07:35:35.512577Z"},"trusted":true},"execution_count":null,"outputs":[]}
]}