diff --git a/Decision_Tree_Model_on_Titanic_Survival_Dataset.ipynb b/Decision_Tree_Model_on_Titanic_Survival_Dataset.ipynb new file mode 100644 index 0000000..5c8520b --- /dev/null +++ b/Decision_Tree_Model_on_Titanic_Survival_Dataset.ipynb @@ -0,0 +1,514 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Decision Tree Model on Titanic Survival Dataset.ipynb", + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyNv1UQt/gTxwm4uJKTgOqt4", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Scig2d2E9Tf2" + }, + "source": [ + "## **Decision Tree Model on Titanic Survival Dataset**\n", + "This is an implementation of the Decision Tree Model, a machine learning model on the Titanic survival dataset. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ohqIXtvId_p" + }, + "source": [ + "#### **Uploads**\n", + "Setting up libraries and uploading dataset files." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qWLIh6DJBb12" + }, + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from google.colab import files" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "resources": { + "http://localhost:8080/nbextensions/google.colab/files.js": { + "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgZG8gewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwoKICAgICAgbGV0IHBlcmNlbnREb25lID0gZmlsZURhdGEuYnl0ZUxlbmd0aCA9PT0gMCA/CiAgICAgICAgICAxMDAgOgogICAgICAgICAgTWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCk7CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPSBgJHtwZXJjZW50RG9uZX0lIGRvbmVgOwoKICAgIH0gd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCk7CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", + "ok": true, + "headers": [ + [ + "content-type", + "application/javascript" + ] + ], + "status": 200, + "status_text": "" + } + }, + "base_uri": "https://localhost:8080/", + "height": 128 + }, + "id": "exkGi5yhB4nk", + "outputId": "0dd8b193-ae08-44b3-f1b9-3186f4e1e99f" + }, + "source": [ + "upload_train = files.upload()\n", + "upload_test = files.upload()" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " Upload widget is only available when the cell has been executed in the\n", + " current browser session. Please rerun this cell to enable.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Saving train.csv to train.csv\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " Upload widget is only available when the cell has been executed in the\n", + " current browser session. Please rerun this cell to enable.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Saving test.csv to test.csv\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XnbdI_Grl_NZ" + }, + "source": [ + "#### **Data Preprocessing**\n", + "Making a function to make changes into the dataframe, such as deleting PassId and converting Sex to male=0, female=1 objective case; and filling in median value of Age wherever it is not available." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "NHwyMoy0yt2T" + }, + "source": [ + "def uploadto(address):\n", + " data = pd.read_csv(address)\n", + " col_names = ['PassId', 'Survived', 'PClass', 'Sex', 'Age', 'SibSp', 'ParCh', 'Fare']\n", + " data.columns = col_names\n", + " data.Sex.replace(('male', 'female'), (0, 1), inplace=True)\n", + " data.pop('PassId')\n", + " data = data.fillna(data['Age'].median())\n", + " print(data)\n", + " return data" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CXqMalUoB-Ze", + "outputId": "8fb43157-a89f-4b20-dfce-4135baa2b430" + }, + "source": [ + "train_df= uploadto('/content/train.csv')\n", + "test_df= uploadto('/content/test.csv')" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + " Survived PClass Sex Age SibSp ParCh Fare\n", + "0 0 3 0 22.0 1 0 7.2500\n", + "1 1 1 1 38.0 1 0 71.2833\n", + "2 1 3 1 26.0 0 0 7.9250\n", + "3 1 1 1 35.0 1 0 53.1000\n", + "4 0 3 0 35.0 0 0 8.0500\n", + ".. ... ... ... ... ... ... ...\n", + "615 1 2 1 24.0 1 2 65.0000\n", + "616 0 3 0 34.0 1 1 14.4000\n", + "617 0 3 1 26.0 1 0 16.1000\n", + "618 1 2 1 4.0 2 1 39.0000\n", + "619 0 2 0 26.0 0 0 10.5000\n", + "\n", + "[620 rows x 7 columns]\n", + " Survived PClass Sex Age SibSp ParCh Fare\n", + "0 0 3 0 27.0 1 0 14.4542\n", + "1 1 1 0 42.0 1 0 52.5542\n", + "2 1 3 0 20.0 1 1 15.7417\n", + "3 0 3 0 21.0 0 0 7.8542\n", + "4 0 3 0 21.0 0 0 16.1000\n", + ".. ... ... ... ... ... ... ...\n", + "266 0 2 0 27.0 0 0 13.0000\n", + "267 1 1 1 19.0 0 0 30.0000\n", + "268 0 3 1 28.0 1 2 23.4500\n", + "269 1 1 0 26.0 0 0 30.0000\n", + "270 0 3 0 32.0 0 0 7.7500\n", + "\n", + "[271 rows x 7 columns]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "immomjS2I13A" + }, + "source": [ + "#### **Decision Tree Functions**\n", + "In the next three blocks, we define:\n", + "1. Entropy Function\n", + "2. Division Algorithm\n", + "3. Information Gain Function\n", + "\n", + "All three of them are important in the study of Decision Trees." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "YoaCH_nXNEZV" + }, + "source": [ + "def entropy(target_col):\n", + " elements, counts = np.unique(target_col,return_counts = True)\n", + " sum = 0.0\n", + " n = np.sum(counts)\n", + " for i in counts:\n", + " p = i/n\n", + " sum = sum - (p * np.log2(p))\n", + " return sum" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "JMN5iiOlfEHX" + }, + "source": [ + "def division(input_data, title, mean):\n", + " right = pd.DataFrame([], columns = input_data.columns)\n", + " left = pd.DataFrame([], columns = input_data.columns)\n", + " k = input_data.shape[0]\n", + " for x in range(k):\n", + " value = input_data[title].loc[x]\n", + " if value >= mean:\n", + " right = right.append(input_data.iloc[x])\n", + " else:\n", + " left = left.append(input_data.iloc[x])\n", + " return right, left" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "OjFDInCznjsg" + }, + "source": [ + "def iGain(input_data, title, mean):\n", + " right, left = division(input_data, title, mean)\n", + " k = input_data.shape[0]\n", + " left_ratio = float(left.shape[0])/k\n", + " right_ratio = float(right.shape[0])/k\n", + " if left.shape[0] == 0 or right.shape[0] == 0:\n", + " return -99999\n", + " igain = entropy(input_data.Survived) - ( left_ratio * entropy(left.Survived) + right_ratio * entropy(right.Survived))\n", + " return igain" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uVtW3BT7JcZi" + }, + "source": [ + "#### **Modeling**\n", + "In the next block we define the decision tree model. \n", + "1. The first function inside the class initialises the model.\n", + "2. The second is the main training module.\n", + "3. The third function is the prediction module." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tSr7ZXMkg3nL" + }, + "source": [ + "class DT:\n", + " def __init__(self, depth=0, max_depth=5):\n", + " self.left = None\n", + " self.right = None\n", + " self.title_name = None\n", + " self.mean_val = None\n", + " self.depth = depth\n", + " self.max_depth = max_depth\n", + " self.target = None\n", + " # \n", + " # \n", + " def train_model(self, input_train):\n", + " features = ['PClass', 'Sex', 'Age', 'SibSp', 'ParCh', 'Fare'] \n", + " iGains = []\n", + " for col in features: \n", + " iGains.append(iGain(input_train, col, input_train[col].mean()))\n", + " # \n", + " self.title_name = features[np.argmax(iGains)] \n", + " self.mean_val = input_train[self.title_name].mean() \n", + " # \n", + " r_data, l_data = division(input_train, self.title_name, self.mean_val) \n", + " r_data = r_data.reset_index(drop=True) \n", + " l_data = l_data.reset_index(drop=True)\n", + " # \n", + " if l_data.shape[0] == 0 or r_data.shape[0] == 0: \n", + " if input_train.Survived.mean() >= 0.5: \n", + " self.target = 1 \n", + " else: \n", + " self.target = 0\n", + " return\n", + " # \n", + " if self.depth >= self.max_depth: \n", + " if input_train.Survived.mean() >= 0.5:\n", + " self.target = 1\n", + " else:\n", + " self.target = 0\n", + " return\n", + " # \n", + " self.left = DT(self.depth+1,self.max_depth) \n", + " self.left.train_model(l_data)\n", + " self.right = DT(self.depth+1,self.max_depth) \n", + " self.right.train_model(r_data)\n", + " # \n", + " if input_train.Survived.mean() >= 0.5:\n", + " self.target = 1\n", + " else:\n", + " self.target = 0\n", + " return\n", + " # \n", + " # \n", + " def predictions(self,test_df): \n", + " if test_df[self.title_name] > self.mean_val:\n", + " if self.right is None:\n", + " return self.target\n", + " return self.right.predictions(test_df)\n", + " # \n", + " if test_df[self.title_name] < self.mean_val:\n", + " if self.left is None:\n", + " return self.target\n", + " return self.left.predictions(test_df)" + ], + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "LCU43sfAr3z9" + }, + "source": [ + "model = DT()\n", + "model.train_model(train_df)" + ], + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lB1UudDq9CK7" + }, + "source": [ + "#### **Predictions and Accuracy**\n", + "In the next two blocks, we have measured the various statistical parameters of our model, such as accuracy, loss, F1 score, sensitivity and precision." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "iGnPeRiR-Duz" + }, + "source": [ + "def stats(data):\n", + " prediction = []\n", + " for i in range(data.shape[0]):\n", + " prediction.append(model.predictions(data.loc[i]))\n", + " prediction = np.array(prediction)\n", + " survive_data = np.array(data['Survived'])\n", + " # \n", + " loss = 0\n", + " f_neg = 0\n", + " f_pos = 0 \n", + " t_neg = 0\n", + " t_pos = 0\n", + " # \n", + " for i, j in zip(prediction, survive_data):\n", + " if i == 1 and j == 1:\n", + " t_pos+=1\n", + " elif i == 1 and j == 0:\n", + " f_pos+=1\n", + " loss+=1\n", + " elif i==0 and j == 1:\n", + " f_neg+=1\n", + " loss+=1\n", + " else:\n", + " t_neg+=1\n", + " # \n", + " rec = t_pos / (t_pos + f_neg)\n", + " prc = t_pos / (t_pos + f_pos)\n", + " acc = (t_pos + t_neg) / (t_pos + t_neg + f_pos + f_neg)\n", + " f1s = 2 * prc * rec / (prc + rec)\n", + " # \n", + " print(' Accuracy is {:.2f}%'.format(100*acc))\n", + " print(' Loss is',loss)\n", + " print(' F1 Score is {:.4f}'.format(f1s))\n", + " print('Sensitivity is {:.4f}'.format(rec))\n", + " print(' Precision is {:.4f}'.format(prc))" + ], + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Wt9olaxOAIhY", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c1c20aa2-74eb-4edf-b49a-d4e7819de2e2" + }, + "source": [ + "print(\"Statistics for Training dataset are:\")\n", + "print(\" \")\n", + "stats(train_df)\n", + "print(\" \")\n", + "print(\" \")\n", + "print(\"Statistics for Test dataset are:\")\n", + "print(\" \")\n", + "stats(test_df)" + ], + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Statistics for Training dataset are:\n", + " \n", + " Accuracy is 82.90%\n", + " Loss is 106\n", + " F1 Score is 0.7686\n", + "Sensitivity is 0.7273\n", + " Precision is 0.8148\n", + " \n", + " \n", + "Statistics for Test dataset are:\n", + " \n", + " Accuracy is 83.39%\n", + " Loss is 45\n", + " F1 Score is 0.7458\n", + "Sensitivity is 0.6804\n", + " Precision is 0.8250\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ot_uNe0_9ijP" + }, + "source": [ + "#### **References**\n", + "\n", + "1. Gagan Panwar's YouTube playlist over the same topic was a great help: www.youtube.com/playlist?list=PL9mhv0CavXYg3KFKct0JnslSwBCpAd_g0\n", + "2. Some Towards Data Science (TDS) articles were helpful, especially: www.towardsdatascience.com/decision-trees-for-classification-id3-algorithm-explained-89df76e72df1\n", + "3. This Exsilio blog was greatly helpful in visualing the final statistics of the model: www.blog.exsilio.com/all/accuracy-precision-recall-f1-score-interpretation-of-performance-measures/" + ] + } + ] +} \ No newline at end of file