Eye Diseases Classification using ResNet-18
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":4130910,"sourceType":"datasetVersion","datasetId":2440665}],"dockerImageVersionId":30498,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Task\n\nEye disease classification is a research area that focuses on developing algorithms and models to accurately classify different types of eye diseases from medical imaging data. It plays an important role in helping ophthalmologists and other healthcare professionals diagnose and treat eye diseases effectively.\n\nThe primary objective here is to leverage machine learning and computer vision techniques to analyze medical images and classify them into four classes: cataract, diabetic retinopathy, glaucoma, and normal.\n\n# About the diseases\n\n1. **Cataract**: Cataract is a common age-related eye condition characterized by the clouding of the lens, leading to blurry vision and visual impairment. It can be treated surgically by replacing the cloudy lens with an artificial one, restoring clear vision and improving quality of life.\n\n2. **Diabetic Retinopathy**: Diabetic retinopathy is a complication of diabetes that affects the blood vessels in the retina. It can cause vision loss, including blurred or distorted vision, and in severe cases lead to blindness. Early detection, regular eye exams, and proper management of diabetes are crucial for preventing and managing this condition.\n\n3. **Glaucoma**: Glaucoma is a group of eye diseases that damage the optic nerve, often due to increased fluid pressure in the eye. 
It gradually leads to vision loss, starting with peripheral vision and potentially progressing to complete blindness. Timely diagnosis, treatment, and ongoing monitoring are vital for preserving vision and preventing irreversible damage.\n\n# Use Case\n\nEye disease classification has several important use cases and applications:\n\n1. **Screening and Early Detection**: Eye disease classification algorithms can serve as screening tools to identify individuals at risk of developing eye diseases. By analyzing medical images, these models can detect early signs of diseases like diabetic retinopathy, age-related macular degeneration, glaucoma, and others. Early detection enables prompt intervention and treatment, potentially preventing vision loss.\n\n2. **Diagnosis Support**: Eye disease classification models can assist healthcare professionals, especially those with limited ophthalmic expertise, in making accurate diagnoses. By providing additional insights and suggestions based on image analysis, these models act as decision support systems, enhancing the accuracy and efficiency of diagnoses.\n\n3. **Treatment Planning and Monitoring**: Once an eye disease is diagnosed, classification algorithms can aid in treatment planning and monitoring. By analyzing sequential imaging data, these models can track disease progression, assess the effectiveness of treatments, and guide adjustments in treatment plans as required.\n\n","metadata":{}},{"cell_type":"markdown","source":"# Installing some extra libraries\n\n1. **torch-summary**: It is a library that provides a simple and convenient way to summarize the structure and number of parameters in a PyTorch model.\n\n2. 
**torchmetrics**: It is a PyTorch library that provides a collection of metric functions commonly used in machine learning and deep learning tasks.","metadata":{}},{"cell_type":"code","source":"!pip install torch-summary\n!pip install torchmetrics","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:30:48.494868Z","iopub.execute_input":"2023-08-07T12:30:48.495243Z","iopub.status.idle":"2023-08-07T12:31:12.059855Z","shell.execute_reply.started":"2023-08-07T12:30:48.495212Z","shell.execute_reply":"2023-08-07T12:31:12.058732Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Importing Libraries","metadata":{}},{"cell_type":"code","source":"import pandas as pd\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport seaborn as sea\nimport os\nfrom tqdm.notebook import tqdm\nimport cv2 as op\nimport torch\nfrom torchsummary import summary\nimport torchmetrics\n\nplt.style.use('seaborn')\nnp.__version__\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\ndevice","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2023-08-07T12:31:12.062381Z","iopub.execute_input":"2023-08-07T12:31:12.062751Z","iopub.status.idle":"2023-08-07T12:31:26.314927Z","shell.execute_reply.started":"2023-08-07T12:31:12.062711Z","shell.execute_reply":"2023-08-07T12:31:26.313142Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Loading the data","metadata":{}},{"cell_type":"code","source":"PATH = '/kaggle/input/eye-diseases-classification/dataset'\n\n# Sort the class folders so that the label ids are reproducible (os.listdir order is arbitrary)\nlabel2id = {label: i for i, label in enumerate(sorted(os.listdir(PATH)))}\nid2label = {i: label for label, i in label2id.items()}\n\nfilenames, outcome = [], []\n\nfor label in tqdm(label2id):\n    for img in os.listdir(os.path.join(PATH, label)):\n        filenames.append(os.path.join(PATH, label, img))\n        outcome.append(label2id[label])\n\n 
\ndf = pd.DataFrame({\n    \"filename\" : filenames,\n    \"outcome\" : outcome\n})\n\n# Shuffle the rows\ndf = df.sample(frac = 1)\ndf.head()","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:26.316259Z","iopub.execute_input":"2023-08-07T12:31:26.316623Z","iopub.status.idle":"2023-08-07T12:31:27.352892Z","shell.execute_reply.started":"2023-08-07T12:31:26.316589Z","shell.execute_reply":"2023-08-07T12:31:27.352044Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Plotting the class distribution\n\nThe classes are fairly balanced: each class has approximately 1000 images.","metadata":{}},{"cell_type":"code","source":"sea.countplot(x = 'outcome', data = df, palette = 'Blues_d')","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:27.35571Z","iopub.execute_input":"2023-08-07T12:31:27.356157Z","iopub.status.idle":"2023-08-07T12:31:27.615005Z","shell.execute_reply.started":"2023-08-07T12:31:27.356121Z","shell.execute_reply":"2023-08-07T12:31:27.614045Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Plotting the sample images\n\nNot all images share the same pixel range, so simply dividing every image by 255 causes problems (one likely reason is that `plt.imread` returns floats in [0, 1] for PNG files but integers, typically in [0, 255], for formats such as JPEG). 
Instead, each image was normalized with min-max scaling, which maps its values to the range [0, 1].","metadata":{}},{"cell_type":"code","source":"def load_image(path):\n    img = plt.imread(path).astype(np.float32)\n    # Min-max scaling: (x - min) / (max - min) maps the pixel values to [0, 1];\n    # the small epsilon avoids a division by zero for a constant image\n    img = (img - img.min()) / (img.max() - img.min() + 1e-8)\n    return img\n\ncounter = 0\n\nplt.figure(figsize = (10, 12))\n\nfor i in range(4):\n    for path in df[df['outcome'] == i].sample(n = 3)['filename']:\n        plt.subplot(4, 3, counter + 1)\n        img = load_image(path)\n        plt.imshow(img)\n        plt.axis('off')\n        plt.title('Class: ' + id2label[i])\n        counter += 1\n\nplt.show()","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:27.616304Z","iopub.execute_input":"2023-08-07T12:31:27.616654Z","iopub.status.idle":"2023-08-07T12:31:31.244358Z","shell.execute_reply.started":"2023-08-07T12:31:27.616622Z","shell.execute_reply":"2023-08-07T12:31:31.24331Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Building the dataset\n\n1. The dataset was built on `torch.utils.data.Dataset` for efficient, on-demand loading of the images.\n2. For data augmentation, only random horizontal and vertical flips were used. 
Adding color, brightness, and similar augmentations made training more difficult.","metadata":{}},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nfrom torch.utils.data import Dataset, DataLoader\nimport torchvision\nfrom torchvision import transforms, models\nimport torch.nn.functional as f\n\ntrain_transform = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Resize(size = (224, 224)),\n    transforms.RandomHorizontalFlip(p = 0.5),\n    transforms.RandomVerticalFlip(p = 0.5)\n])\n\nval_transform = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Resize(size = (224, 224))\n])\n\nclass EyeDataset(Dataset):\n    def __init__(self, df, n_classes, transform = None):\n        self.df = df\n        self.n_samples = len(self.df)\n        self.n_classes = n_classes\n        self.transform = transform\n\n    def __len__(self):\n        return self.n_samples\n\n    def __getitem__(self, index):\n        img = plt.imread(self.df.iloc[index, 0]).astype(np.float32)\n        label = self.df.iloc[index, 1]\n\n        # Min-max scaling: maps the pixel values to [0, 1]\n        img = (img - img.min()) / (img.max() - img.min() + 1e-8)\n\n        if self.transform:\n            img = self.transform(img)\n\n        return img.to(torch.float32), label","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:31.245356Z","iopub.execute_input":"2023-08-07T12:31:31.245658Z","iopub.status.idle":"2023-08-07T12:31:31.262978Z","shell.execute_reply.started":"2023-08-07T12:31:31.245631Z","shell.execute_reply":"2023-08-07T12:31:31.261878Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"1. 85% of the dataset was used for training and the remaining 15% was held out for validation/testing.\n2. 
A batch size of 128 was used.","metadata":{}},{"cell_type":"code","source":"from sklearn.model_selection import train_test_split\n\ndf_train, df_val = train_test_split(df, test_size = 0.15, random_state = 28)\n\ndf_train.shape, df_val.shape\n","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:31.2643Z","iopub.execute_input":"2023-08-07T12:31:31.264865Z","iopub.status.idle":"2023-08-07T12:31:31.458795Z","shell.execute_reply.started":"2023-08-07T12:31:31.264829Z","shell.execute_reply":"2023-08-07T12:31:31.45779Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"NUM_CLASSES = 4\nBATCH_SIZE = 128\n\ntrain_dataset = EyeDataset(df_train, NUM_CLASSES, train_transform)\nval_dataset = EyeDataset(df_val, NUM_CLASSES, val_transform)\n\ntrain_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)\nval_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:31.460438Z","iopub.execute_input":"2023-08-07T12:31:31.461045Z","iopub.status.idle":"2023-08-07T12:31:31.467966Z","shell.execute_reply.started":"2023-08-07T12:31:31.461007Z","shell.execute_reply":"2023-08-07T12:31:31.466734Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Sanity check: for RGB images one batch should be (BATCH_SIZE, 3, 224, 224) with BATCH_SIZE labels\na, b = next(iter(train_loader))\n\nprint(a.shape, b.shape)\ndel a, b","metadata":{"execution":{"iopub.status.busy":"2023-08-07T12:31:31.469654Z","iopub.execute_input":"2023-08-07T12:31:31.470353Z","iopub.status.idle":"2023-08-07T12:31:34.821774Z","shell.execute_reply.started":"2023-08-07T12:31:31.470318Z","shell.execute_reply":"2023-08-07T12:31:34.820672Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Model Architecture\n\n1. We used a ResNet-18 model pretrained on ImageNet, loaded from `torchvision`.\n\n2. ResNet-18 is composed of multiple residual blocks, which are designed to address the problem of vanishing gradients in deep neural networks. 
These blocks introduce skip connections, allowing information to bypass several layers and flow directly to deeper layers. This helps mitigate the degradation problem and lets the network learn effectively even when it is very deep.\n\n3. We replaced the final fully connected layer with a new head of two dense layers (512 → 128 → 4, with ReLU and dropout in between). Only the last 15 parameter tensors of ResNet-18 were left trainable and everything before them was frozen; since 2 of those 15 tensors belong to the original `fc` layer, which is replaced, the trainable part of the backbone is most of the final residual stage, `layer4`. The ResNet-18 backbone was trained with a learning rate of 3x10<sup>-5</sup> and the dense layers with a learning rate of 8x10<sup>-4</sup>.\n\n![resnet18](https://i.imgur.com/XwcnU5x.png)","metadata":{}},{"cell_type":"code","source":"class Net(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # ImageNet-pretrained ResNet-18 backbone\n        self.base = torchvision.models.resnet18(pretrained = True)\n\n        # Freeze all parameter tensors except the last 15 (most of layer4 plus the original fc)\n        for param in list(self.base.parameters())[:-15]:\n            param.requires_grad = False\n\n        # New classification head: 512 backbone features -> 128 -> 4 classes\n        self.block = nn.Sequential(\n            nn.Linear(512, 128),\n            nn.ReLU(),\n            nn.Dropout(0.2),\n            nn.Linear(128, 4),\n        )\n        # Remove the 1000-class ImageNet layer so the backbone returns 512-d features\n        self.base.fc = nn.Identity()\n\n    def get_optimizer(self):\n        # Separate learning rates for the backbone and the new head\n        return torch.optim.AdamW([\n            {'params' : self.base.parameters(), 'lr': 3e-5},\n            {'params' : self.block.parameters(), 'lr': 8e-4}\n        ])\n\n    def forward(self, x):\n        x = self.base(x)\n        x = self.block(x)\n        return x\n\n# 👁️ PyTorch: Eye Disease Classification| 92.7%\n\n# A plain class: subclassing nn.Module would let train(epochs) shadow nn.Module.train(mode)\nclass Trainer:\n    def __init__(self, train_loader, val_loader, device):\n        self.train_loader = train_loader\n        self.val_loader = val_loader\n        self.device = device\n\n        self.model = Net().to(self.device)\n        self.optimizer = self.model.get_optimizer()\n        self.loss_fxn = nn.CrossEntropyLoss()\n        self.accuracy = torchmetrics.Accuracy(task = \"multiclass\", num_classes = NUM_CLASSES).to(self.device)\n\n        self.history = {'train_loss' : [], 'val_loss': [], 'train_acc': [], 'val_acc': []}\n\n    def training_step(self, x, y):\n        pred = self.model(x)\n        loss = self.loss_fxn(pred, y)\n        acc = self.accuracy(pred, y)\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        return loss, acc\n\n    def val_step(self, x, y):\n        with torch.no_grad():\n            pred = self.model(x)\n            loss = self.loss_fxn(pred, y)\n            acc = self.accuracy(pred, y)\n\n        return loss, acc\n\n    def step_fxn(self, loader, step):\n        loss, acc = 0, 0\n\n        for X, y in tqdm(loader):\n            X, y = X.to(self.device), y.to(self.device)\n            l, a = step(X, y)\n            loss, acc = loss + l.item(), acc + a.item()\n\n        # Average of the per-batch loss and accuracy\n        return loss/len(loader), acc/len(loader)\n\n    def train(self, epochs):\n        for epoch in tqdm(range(epochs)):\n            # Training mode: dropout active, batch-norm statistics updated\n            self.model.train()\n            train_loss, train_acc = self.step_fxn(self.train_loader, self.training_step)\n\n            # Evaluation mode: dropout off, batch norm uses its running statistics\n            self.model.eval()\n            val_loss, val_acc = self.step_fxn(self.val_loader, self.val_step)\n\n            for item, value in zip(self.history.keys(), [train_loss, val_loss, train_acc, val_acc]):\n                self.history[item].append(value)\n\n            print(\"[Epoch: {}] Train: [loss: {:.3f} acc: {:.3f}] Val: [loss: {:.3f} acc: {:.3f}]\".format(epoch + 1, train_loss, train_acc, val_loss, val_acc))\n","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:08:58.561004Z","iopub.execute_input":"2023-08-07T13:08:58.561637Z","iopub.status.idle":"2023-08-07T13:08:58.595254Z","shell.execute_reply.started":"2023-08-07T13:08:58.561595Z","shell.execute_reply":"2023-08-07T13:08:58.594344Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"trainer = Trainer(train_loader, val_loader, device)","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:08:59.627203Z","iopub.execute_input":"2023-08-07T13:08:59.627625Z","iopub.status.idle":"2023-08-07T13:08:59.851428Z","shell.execute_reply.started":"2023-08-07T13:08:59.627593Z","shell.execute_reply":"2023-08-07T13:08:59.850488Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Summary of the model","metadata":{}},{"cell_type":"code","source":"summary(trainer.model.base, (3, 224, 
224))","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:09:00.900827Z","iopub.execute_input":"2023-08-07T13:09:00.901218Z","iopub.status.idle":"2023-08-07T13:09:00.930381Z","shell.execute_reply.started":"2023-08-07T13:09:00.901182Z","shell.execute_reply":"2023-08-07T13:09:00.929436Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Training the model","metadata":{}},{"cell_type":"code","source":"trainer.train(epochs = 2)","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:09:06.938572Z","iopub.execute_input":"2023-08-07T13:09:06.938937Z","iopub.status.idle":"2023-08-07T13:11:37.921705Z","shell.execute_reply.started":"2023-08-07T13:09:06.938906Z","shell.execute_reply":"2023-08-07T13:11:37.920666Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Plotting Model Results","metadata":{}},{"cell_type":"code","source":"plt.figure(figsize = (15, 4))\n\nplt.subplot(1,2,1)\nplt.title('Loss')\nplt.plot(trainer.history['train_loss'], label = 'Training')\nplt.plot(trainer.history['val_loss'], label = 'Validation')\nplt.legend()\n\nplt.subplot(1,2,2)\nplt.title('Accuracy')\nplt.plot(trainer.history['train_acc'], label = 'Training')\nplt.plot(trainer.history['val_acc'], label = 'Validation')\nplt.legend()\n\n","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:12:24.029676Z","iopub.execute_input":"2023-08-07T13:12:24.030065Z","iopub.status.idle":"2023-08-07T13:12:24.649558Z","shell.execute_reply.started":"2023-08-07T13:12:24.030033Z","shell.execute_reply":"2023-08-07T13:12:24.648626Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Model Predictions","metadata":{}},{"cell_type":"code","source":"preds, true = [], []\n\n# Inference mode: disables dropout and makes batch norm use its running statistics\ntrainer.model.eval()\n\nwith torch.no_grad():\n    for x, y in tqdm(val_loader):\n        pred = torch.argmax(trainer.model(x.to(device)), dim = 1).cpu().numpy()\n        preds.extend(pred)\n        true.extend(y.numpy())\n\nlen(preds), 
len(true)","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:12:41.826816Z","iopub.execute_input":"2023-08-07T13:12:41.827215Z","iopub.status.idle":"2023-08-07T13:12:54.061765Z","shell.execute_reply.started":"2023-08-07T13:12:41.827182Z","shell.execute_reply":"2023-08-07T13:12:54.060816Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from sklearn.metrics import confusion_matrix\n\nclass_names = list(label2id.keys())\n\ncm = confusion_matrix(true, preds)\nsea.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True, xticklabels=class_names, yticklabels=class_names)\n\nplt.xlabel('Predicted labels')\nplt.ylabel('True labels')\nplt.title('Confusion Matrix')","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:12:58.916144Z","iopub.execute_input":"2023-08-07T13:12:58.916532Z","iopub.status.idle":"2023-08-07T13:12:59.303983Z","shell.execute_reply.started":"2023-08-07T13:12:58.916501Z","shell.execute_reply":"2023-08-07T13:12:59.302996Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from sklearn.metrics import classification_report\n\nprint(classification_report(true, preds, target_names = list(label2id.keys())))","metadata":{"execution":{"iopub.status.busy":"2023-08-07T13:13:03.746566Z","iopub.execute_input":"2023-08-07T13:13:03.746932Z","iopub.status.idle":"2023-08-07T13:13:03.775793Z","shell.execute_reply.started":"2023-08-07T13:13:03.7469Z","shell.execute_reply":"2023-08-07T13:13:03.774649Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Conclusion\n\n1. The \"glaucoma\" class has a precision of 0.90, recall of 0.83, and F1-score of 0.86: when the model predicts glaucoma it is usually right, but it misses about 17% of the glaucoma images (false negatives). The \"normal\" class has a precision of 0.85, recall of 0.90, and F1-score of 0.88, so about 15% of the images predicted as normal actually belong to another class (false positives). The \"diabetic_retinopathy\" class has a precision, recall, and F1-score of 0.99. 
This indicates that the model identifies diabetic retinopathy very reliably. The \"cataract\" class also has precision, recall, and F1-score of 0.95 or above, indicating accurate identification of cataract cases.\n\n2. The overall accuracy of the model is 0.92, i.e. 92% of the validation images were classified correctly.\n\n3. In summary, the model performs strongly on diabetic retinopathy and cataract, while glaucoma (recall 0.83) and normal (precision 0.85) are the weaker classes.","metadata":{}}]}
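The F1-scores quoted in the conclusion are the harmonic mean of precision and recall, F1 = 2PR / (P + R). As a quick sanity check, here is a minimal Python sketch that recomputes the glaucoma F1-score from the reported precision (0.90) and recall (0.83). The helper name `f1_from_precision_recall` is hypothetical and written only for illustration; it is not part of the notebook or of scikit-learn.

```python
def f1_from_precision_recall(precision: float, recall: float) -> float:
    """Return the F1-score, i.e. the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0  # convention: F1 is 0 when precision and recall are both 0
    return 2 * precision * recall / (precision + recall)


# Precision and recall reported for the "glaucoma" class in the conclusion
glaucoma_f1 = f1_from_precision_recall(0.90, 0.83)
print(f"glaucoma F1 = {glaucoma_f1:.2f}")  # prints "glaucoma F1 = 0.86"
```

Because `classification_report` rounds precision and recall to two decimals, recomputing F1 from the rounded values can differ from the reported F1 in the last digit: for the normal class, 0.85 and 0.90 give about 0.87 rather than the reported 0.88.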