From a3fb013513de459381a6721a3d7fd43f6f4d26ed Mon Sep 17 00:00:00 2001 From: PAlena <86962990+PAlena@users.noreply.github.com> Date: Mon, 29 Apr 2024 22:12:19 +0200 Subject: [PATCH] Add files via upload --- Seminar_5/DS_case_study_solved.ipynb | 1890 ++++++++++++++++++++++++++ 1 file changed, 1890 insertions(+) create mode 100644 Seminar_5/DS_case_study_solved.ipynb diff --git a/Seminar_5/DS_case_study_solved.ipynb b/Seminar_5/DS_case_study_solved.ipynb new file mode 100644 index 0000000..c85c09f --- /dev/null +++ b/Seminar_5/DS_case_study_solved.ipynb @@ -0,0 +1,1890 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2d667870", + "metadata": {}, + "source": [ + "# Heart Attack Prediction Case Study" + ] + }, + { + "cell_type": "markdown", + "id": "be742eee", + "metadata": {}, + "source": [ + "## Problem definition: Predict whether a patient will have a heart attack or not" + ] + }, + { + "cell_type": "markdown", + "id": "6c9c4ad8", + "metadata": {}, + "source": [ + "### Features:\n", + "age : Age of the patient\n", + "\n", + "sex : Sex of the patient\n", + "\n", + "exng : exercise-induced angina (1 = yes; 0 = no)\n", + "\n", + "caa : number of major vessels (0-3)\n", + "\n", + "cp : chest pain type\n", + "\n", + "Value 1: typical angina\n", + "Value 2: atypical angina\n", + "Value 3: non-anginal pain\n", + "Value 4: asymptomatic\n", + "trtbps : resting blood pressure (in mm Hg)\n", + "\n", + "chol : cholesterol in mg/dl fetched via BMI sensor\n", + "\n", + "fbs : (fasting blood sugar > 120 mg/dl) (1 = true; 0 = false)\n", + "\n", + "restecg : resting electrocardiographic results\n", + "\n", + "Value 0: normal\n", + "Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)\n", + "Value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria\n", + "thalachh : maximum heart rate achieved\n", + "\n", + "output : 0 = less chance of heart attack, 1 = more chance of heart attack" + 
] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "25294e0a", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import sklearn\n", + "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import confusion_matrix\n" + ] + }, + { + "cell_type": "markdown", + "id": "0cd1b4db", + "metadata": {}, + "source": [ + "## Data preparation" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cad732a7", + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv(\"data/heart.csv\")\n", + "o2_saturation = pd.read_csv(\"data/o2Saturation.csv\").head(303)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "190ade33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
agesexcptrtbpscholfbsrestecgthalachhexngoldpeakslpcaathalloutput
063131452331015002.30011
137121302500118703.50021
241011302040017201.42021
356111202360117800.82021
457001203540116310.62021
\n", + "
" + ], + "text/plain": [ + " age sex cp trtbps chol fbs restecg thalachh exng oldpeak slp \\\n", + "0 63 1 3 145 233 1 0 150 0 2.3 0 \n", + "1 37 1 2 130 250 0 1 187 0 3.5 0 \n", + "2 41 0 1 130 204 0 0 172 0 1.4 2 \n", + "3 56 1 1 120 236 0 1 178 0 0.8 2 \n", + "4 57 0 0 120 354 0 1 163 1 0.6 2 \n", + "\n", + " caa thall output \n", + "0 0 1 1 \n", + "1 0 2 1 \n", + "2 0 2 1 \n", + "3 0 2 1 \n", + "4 0 2 1 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9c2711c0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 303 entries, 0 to 302\n", + "Data columns (total 14 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 age 303 non-null int64 \n", + " 1 sex 303 non-null int64 \n", + " 2 cp 303 non-null int64 \n", + " 3 trtbps 303 non-null int64 \n", + " 4 chol 303 non-null int64 \n", + " 5 fbs 303 non-null int64 \n", + " 6 restecg 303 non-null int64 \n", + " 7 thalachh 303 non-null int64 \n", + " 8 exng 303 non-null int64 \n", + " 9 oldpeak 303 non-null float64\n", + " 10 slp 303 non-null int64 \n", + " 11 caa 303 non-null int64 \n", + " 12 thall 303 non-null int64 \n", + " 13 output 303 non-null int64 \n", + "dtypes: float64(1), int64(13)\n", + "memory usage: 33.3 KB\n" + ] + } + ], + "source": [ + "data.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1489c727", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
agesexcptrtbpscholfbsrestecgthalachhexngoldpeakslpcaathalloutput
count303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000303.000000
mean54.3663370.6831680.966997131.623762246.2640260.1485150.528053149.6468650.3267331.0396041.3993400.7293732.3135310.544554
std9.0821010.4660111.03205217.53814351.8307510.3561980.52586022.9051610.4697941.1610750.6162261.0226060.6122770.498835
min29.0000000.0000000.00000094.000000126.0000000.0000000.00000071.0000000.0000000.0000000.0000000.0000000.0000000.000000
25%47.5000000.0000000.000000120.000000211.0000000.0000000.000000133.5000000.0000000.0000001.0000000.0000002.0000000.000000
50%55.0000001.0000001.000000130.000000240.0000000.0000001.000000153.0000000.0000000.8000001.0000000.0000002.0000001.000000
75%61.0000001.0000002.000000140.000000274.5000000.0000001.000000166.0000001.0000001.6000002.0000001.0000003.0000001.000000
max77.0000001.0000003.000000200.000000564.0000001.0000002.000000202.0000001.0000006.2000002.0000004.0000003.0000001.000000
\n", + "
" + ], + "text/plain": [ + " age sex cp trtbps chol fbs \\\n", + "count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \n", + "mean 54.366337 0.683168 0.966997 131.623762 246.264026 0.148515 \n", + "std 9.082101 0.466011 1.032052 17.538143 51.830751 0.356198 \n", + "min 29.000000 0.000000 0.000000 94.000000 126.000000 0.000000 \n", + "25% 47.500000 0.000000 0.000000 120.000000 211.000000 0.000000 \n", + "50% 55.000000 1.000000 1.000000 130.000000 240.000000 0.000000 \n", + "75% 61.000000 1.000000 2.000000 140.000000 274.500000 0.000000 \n", + "max 77.000000 1.000000 3.000000 200.000000 564.000000 1.000000 \n", + "\n", + " restecg thalachh exng oldpeak slp caa \\\n", + "count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \n", + "mean 0.528053 149.646865 0.326733 1.039604 1.399340 0.729373 \n", + "std 0.525860 22.905161 0.469794 1.161075 0.616226 1.022606 \n", + "min 0.000000 71.000000 0.000000 0.000000 0.000000 0.000000 \n", + "25% 0.000000 133.500000 0.000000 0.000000 1.000000 0.000000 \n", + "50% 1.000000 153.000000 0.000000 0.800000 1.000000 0.000000 \n", + "75% 1.000000 166.000000 1.000000 1.600000 2.000000 1.000000 \n", + "max 2.000000 202.000000 1.000000 6.200000 2.000000 4.000000 \n", + "\n", + " thall output \n", + "count 303.000000 303.000000 \n", + "mean 2.313531 0.544554 \n", + "std 0.612277 0.498835 \n", + "min 0.000000 0.000000 \n", + "25% 2.000000 0.000000 \n", + "50% 2.000000 1.000000 \n", + "75% 3.000000 1.000000 \n", + "max 3.000000 1.000000 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "17e69c68", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# What is the distribution of the age column?\n", + "data[\"age\"].hist(bins=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "caba3e33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Numbers if man: 0.68\n", + "Numbers if woman: 0.32\n" + ] + } + ], + "source": [ + "# What is the share of females in the sample? What is the share of males?\n", + "print(f'Numbers if man: {round(data.sex.sum()/len(data),2)}')\n", + "print(f'Numbers if woman: {round(1 - data.sex.sum()/len(data),2)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f0dd15a8-1502-40c5-8522-2b3737f89bac", + "metadata": {}, + "outputs": [], + "source": [ + "# Add the o2_saturation data to the data frame 'data'\n", + "# o2_saturation.rename(columns={'98.6':'o2_saturation'}, inplace = True)\n", + "# data = pd.concat([data, o2_saturation], axis=1)\n", + "data[\"o2_Saturation\"] = o2_saturation" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "195ae311", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "corr = data.corr()\n", + "\n", + "plt.figure(figsize=(15, 10))\n", + "sns.heatmap(corr, annot=True)" + ] + }, + { + "cell_type": "markdown", + "id": "076fb645", + "metadata": {}, + "source": [ + "## Random Forest Classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0fe76502", + "metadata": {}, + "outputs": [], + "source": [ + "# Specify X and y and split the dataset\n", + "X = data.drop(\"output\", axis=1)\n", + "y = data[\"output\"]\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state = 14)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "183fd155", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((212, 14), (91, 14), (212,), (91,))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_train.shape, X_test.shape, y_train.shape, y_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "077d86fe", + "metadata": {}, + "outputs": [], + "source": [ + "# specify model\n", + "forest = RandomForestClassifier(random_state = 14)\n", + "\n", + "# train model\n", + "forest.fit(X_train, y_train)\n", + "\n", + "# make prediction\n", + "forest_preds = forest.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "40e251ea-7fcd-4aee-b638-f29309b68b89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1,\n", + " 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,\n", + " 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,\n", + " 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0,\n", + " 0, 1, 1])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "forest_preds" + ] + }, + { + 
"cell_type": "markdown", + "id": "080c566f-6f78-43d2-8782-008a738e0697", + "metadata": {}, + "source": [ + "## Evaluate the result" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "525bd174-6241-4d14-946a-f41fe31e6c7b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "41\n", + "8\n" + ] + } + ], + "source": [ + "# Manually calculate number of True positive and False negative predictions\n", + "Y = pd.DataFrame()\n", + "Y['real'] = y_test\n", + "Y['pred'] = forest_preds\n", + "print(len(Y[(Y.real==1)&(Y.pred==1)]))\n", + "print(len(Y[(Y.real==1)&(Y.pred==0)]))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e5252037-b250-4ae4-99c5-1138bfd537e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Confusion Matrix')" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sns.heatmap(confusion_matrix(y_test, forest_preds), annot=True)\n", + "plt.xlabel(\"Predicted Labels\")\n", + "plt.ylabel(\"Actual Labels\")\n", + "plt.title(\"Confusion Matrix\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3d3f1e98-c3cd-428d-aa14-d099db19e3a6", + "metadata": {}, + "outputs": [], + "source": [ + "# Based on the confusion matrix, fill in:\n", + "true_positive = 41\n", + "true_negative = 33\n", + "false_positive = 9\n", + "false_negative = 8" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "a5f55d8d-f413-4d0a-9b7b-23efc31c257d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.82\n", + "0.82\n" + ] + } + ], + "source": [ + "# Precision: a measure of how many of the positive predictions made are correct (true positives)\n", + "print(round(true_positive/(true_positive + false_positive),2))\n", + "print(round(sklearn.metrics.precision_score(y_test, forest_preds), 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "979fb237-0771-4585-ae70-74a8b4db6077", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.84\n", + "0.84\n" + ] + } + ], + "source": [ + "# Recall: a measure of how many of the actual positive cases the classifier correctly predicted\n", + "print(round(true_positive/(true_positive + false_negative),2))\n", + "print(round(sklearn.metrics.recall_score(y_test, forest_preds), 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "008cd394-6520-4888-863f-6463fe3fa9ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.83\n", + "0.83\n" + ] + } + ], + "source": [ + "# F1-Score: a measure combining both precision and recall: 2*(prec*recall)/(prec+recall)\n", + "print(round(2*(0.82*0.84)/(0.82+0.84),2))\n", + 
"print(round(sklearn.metrics.f1_score(y_test, forest_preds), 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "c6fac62f-c5a8-4ba5-b7ee-82ec647b9ef0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.81\n", + "0.81\n" + ] + } + ], + "source": [ + "# Accuracy: describing the number of correct predictions over all predictions\n", + "print(round((true_positive+true_negative)/(true_positive+true_negative+false_positive+false_negative),2))\n", + "print(round(sklearn.metrics.accuracy_score(y_test, forest_preds), 2))" + ] + }, + { + "cell_type": "markdown", + "id": "e43acac7-9bca-461c-9e32-8542a3fd04ef", + "metadata": {}, + "source": [ + "## Save the result" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "922aecda-8067-4b9c-8f2c-ad83e345b42a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
precision
recall
f1_score
accuracy
\n", + "
" + ], + "text/plain": [ + "Empty DataFrame\n", + "Columns: []\n", + "Index: [precision, recall, f1_score, accuracy]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create empty df with indexes: ['precision','recall', 'f1_score','accuracy']\n", + "result = pd.DataFrame(index=['precision','recall', 'f1_score','accuracy'])\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "1469766a-232b-4931-a7d6-43ee8b22141c", + "metadata": {}, + "outputs": [], + "source": [ + "# write a function which will return 4 metrics:\n", + "# [round(sklearn.metrics.precision_score(y_test, forest_preds), 2),round(sklearn.metrics.recall_score(y_test, forest_preds), 2),round(sklearn.metrics.f1_score(y_test, forest_preds), 2),round(sklearn.metrics.accuracy_score(y_test, forest_preds), 2)]\n", + "\n", + "\n", + "def eval_result(x_predicted):\n", + " return [round(sklearn.metrics.precision_score(y_test, x_predicted), 2),\n", + " round(sklearn.metrics.recall_score(y_test, x_predicted), 2),\n", + " round(sklearn.metrics.f1_score(y_test, x_predicted), 2),\n", + " round(sklearn.metrics.accuracy_score(y_test, x_predicted), 2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "ce75bce6-17e4-46fc-bd25-c31bdc178d10", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
model_1
precision0.82
recall0.84
f1_score0.83
accuracy0.81
\n", + "
" + ], + "text/plain": [ + " model_1\n", + "precision 0.82\n", + "recall 0.84\n", + "f1_score 0.83\n", + "accuracy 0.81" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# add the result of model_1 into the df 'result'\n", + "result['model_1'] = eval_result(forest_preds)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "id": "e6c6db57-954c-4449-bc97-5c3ef07e2020", + "metadata": {}, + "source": [ + "### Change the threshold" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "cabae7b5-dc23-41bc-8cba-f9ba3e1df5ae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 0, 0, 1])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "forest.predict(X_test)[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "bd4b584e-9b3d-47b2-ae9e-396c6f6af482", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.82, 0.18],\n", + " [0.37, 0.63],\n", + " [0.81, 0.19],\n", + " [0.79, 0.21],\n", + " [0.33, 0.67]])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "forest.predict_proba(X_test)[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0df4b875-4107-4ce8-837a-1c9b5bfcb669", + "metadata": {}, + "outputs": [], + "source": [ + "# get the result for a threshold = 0.7" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "21754b98-1b7b-43f4-b2ad-18507ba9c86d", + "metadata": {}, + "outputs": [], + "source": [ + "# get the result for any threshold from 0.1 till 0.9 and add it into 'result' df\n", + "res_proba = forest.predict_proba(X_test)\n", + "\n", + "for t in [x/10 for x in range(1, 10)]:\n", + " res_pred = [0 if i[0]>=t else 1 for i in res_proba]\n", + " result[f'model_thr_{t}'] = eval_result(res_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 
27, + "id": "3b87475b-d862-44a7-a667-5962099483ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
precisionrecallf1_scoreaccuracy
model_thr_0.70.751.000.860.82
model_10.820.840.830.81
model_thr_0.50.820.840.830.81
model_thr_0.60.780.920.840.81
model_thr_0.40.850.690.760.77
model_thr_0.80.691.000.820.76
model_thr_0.30.900.550.680.73
model_thr_0.20.960.470.630.70
model_thr_0.90.611.000.760.66
model_thr_0.11.000.180.310.56
\n", + "
" ], + "text/plain": [ + " precision recall f1_score accuracy\n", + "model_thr_0.7 0.75 1.00 0.86 0.82\n", + "model_1 0.82 0.84 0.83 0.81\n", + "model_thr_0.5 0.82 0.84 0.83 0.81\n", + "model_thr_0.6 0.78 0.92 0.84 0.81\n", + "model_thr_0.4 0.85 0.69 0.76 0.77\n", + "model_thr_0.8 0.69 1.00 0.82 0.76\n", + "model_thr_0.3 0.90 0.55 0.68 0.73\n", + "model_thr_0.2 0.96 0.47 0.63 0.70\n", + "model_thr_0.9 0.61 1.00 0.76 0.66\n", + "model_thr_0.1 1.00 0.18 0.31 0.56" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result.T.sort_values('accuracy', ascending=False)" + ] + }, + { + "cell_type": "markdown", + "id": "45d0fcc7", + "metadata": {}, + "source": [ + "### Using RandomizedSearchCV for hyperparameter tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "7c1501e3", + "metadata": {}, + "outputs": [], + "source": [ + "estimator = RandomForestClassifier(random_state=14)  # seeded so the search is reproducible\n", + "grid = {\"n_estimators\": [80, 90, 100, 110, 120],\n", + " \"max_depth\": [5, 10, 15],\n", + " \"max_features\" : [\"sqrt\", \"log2\"],  # 'auto' is deprecated; it meant 'sqrt' for classifiers\n", + " \"min_samples_split\": [2, 3, 4],  # an integer value must be >= 2\n", + " \"min_samples_leaf\": [1, 2, 3, 4]}\n", + "\n", + "rand_search_model = RandomizedSearchCV(estimator=estimator,\n", + " param_distributions=grid, random_state=14)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "a9188e30", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning: \n", + "5 fits failed out of a total of 50.\n", + "The score on these train-test partitions for these parameters will be set to nan.\n", + "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n", + "\n", + "Below are more details about the failures:\n", + "--------------------------------------------------------------------------------\n", + "5 fits failed with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n", + " estimator.fit(X_train, y_train, **fit_params)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 476, in fit\n", + " trees = Parallel(\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 1043, in __call__\n", + " if self.dispatch_one_batch(iterator):\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 861, in dispatch_one_batch\n", + " self._dispatch(tasks)\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 779, in _dispatch\n", + " job = self._backend.apply_async(batch, callback=cb)\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 208, in apply_async\n", + " result = ImmediateResult(func)\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 572, in __init__\n", + " self.results = batch()\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in __call__\n", + " return [func(*args, **kwargs)\n", + " File 
\"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in \n", + " return [func(*args, **kwargs)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/fixes.py\", line 117, in __call__\n", + " return self.function(*args, **kwargs)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 189, in _parallel_build_trees\n", + " tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 969, in fit\n", + " super().fit(\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 265, in fit\n", + " check_scalar(\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py\", line 1480, in check_scalar\n", + " raise ValueError(\n", + "ValueError: min_samples_split == 1, must be >= 2.\n", + "\n", + " warnings.warn(some_fits_failed_message, FitFailedWarning)\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_search.py:953: UserWarning: One or more of the test scores are non-finite: [0.81550388 0.8013289 0.81572536 0.82026578 nan 0.81539313\n", + " 0.82026578 0.81085271 0.82026578 0.81096346]\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "# Retrain the model with tuned parameters and compare with model 1\n", + "rand_search_model.fit(X_train, y_train)\n", + "rand_search_pred = rand_search_model.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "acbb111b-8b61-46b2-92b3-511588758d0f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'n_estimators': 90,\n", + " 'min_samples_split': 3,\n", + " 'min_samples_leaf': 3,\n", + " 'max_features': 'log2',\n", + " 'max_depth': 10}" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rand_search_model.best_params_" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": 
"10b109c8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
precisionrecallf1_scoreaccuracy
model_thr_0.70.751.000.860.82
model_10.820.840.830.81
model_thr_0.50.820.840.830.81
model_thr_0.60.780.920.840.81
tuned_model0.820.820.820.80
model_thr_0.40.850.690.760.77
model_thr_0.80.691.000.820.76
model_thr_0.30.900.550.680.73
model_thr_0.20.960.470.630.70
model_thr_0.90.611.000.760.66
model_thr_0.11.000.180.310.56
\n", + "
" + ], + "text/plain": [ + " precision recall f1_score accuracy\n", + "model_thr_0.7 0.75 1.00 0.86 0.82\n", + "model_1 0.82 0.84 0.83 0.81\n", + "model_thr_0.5 0.82 0.84 0.83 0.81\n", + "model_thr_0.6 0.78 0.92 0.84 0.81\n", + "tuned_model 0.82 0.82 0.82 0.80\n", + "model_thr_0.4 0.85 0.69 0.76 0.77\n", + "model_thr_0.8 0.69 1.00 0.82 0.76\n", + "model_thr_0.3 0.90 0.55 0.68 0.73\n", + "model_thr_0.2 0.96 0.47 0.63 0.70\n", + "model_thr_0.9 0.61 1.00 0.76 0.66\n", + "model_thr_0.1 1.00 0.18 0.31 0.56" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result['tuned_model'] = eval_result(rand_search_pred)\n", + "result.T.sort_values('accuracy', ascending=False)" + ] + }, + { + "cell_type": "markdown", + "id": "b084c123", + "metadata": {}, + "source": [ + "## Feature importance" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "44066cb9-2bcc-4450-935b-8c5fb86d48b6", + "metadata": {}, + "outputs": [], + "source": [ + "importances = forest.feature_importances_" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "b86599aa-8ef2-49c8-8aea-1674a1a9bc5a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "forest_importances = pd.Series(importances, index = X.columns)\n", + "forest_importances = forest_importances.sort_values(ascending=False)\n", + "fig, ax = plt.subplots()\n", + "forest_importances.plot.bar(ax=ax)\n", + "ax.set_title(\"Feature importances\")\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "1526680d-a94a-4335-8644-902d6a5751b4", + "metadata": {}, + "outputs": [], + "source": [ + "# Drop the least important feature and compare the result with model 1" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "046d5bab-4f5b-47bc-b7c0-080359042bf6", + "metadata": {}, + "outputs": [], + "source": [ + "X_test.drop(['fbs'], axis=1, inplace=True)\n", + "X_train.drop(['fbs'], axis=1, inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "2f730335-849d-4fbd-b148-01fe746e2721", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. 
To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning: \n", + "5 fits failed out of a total of 50.\n", + "The score on these train-test partitions for these parameters will be set to nan.\n", + "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n", + "\n", + "Below are more details about the failures:\n", + "--------------------------------------------------------------------------------\n", + "5 fits failed with the following error:\n", + "Traceback (most recent call last):\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n", + " estimator.fit(X_train, y_train, **fit_params)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 476, in fit\n", + " trees = Parallel(\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 1043, in __call__\n", + " if self.dispatch_one_batch(iterator):\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 861, in dispatch_one_batch\n", + " self._dispatch(tasks)\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 779, in _dispatch\n", + " job = self._backend.apply_async(batch, callback=cb)\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 208, in apply_async\n", + " result = ImmediateResult(func)\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 572, in __init__\n", + " self.results = batch()\n", + " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in __call__\n", + " return [func(*args, **kwargs)\n", + " File 
\"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in \n", + " return [func(*args, **kwargs)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/fixes.py\", line 117, in __call__\n", + " return self.function(*args, **kwargs)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 189, in _parallel_build_trees\n", + " tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 969, in fit\n", + " super().fit(\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 265, in fit\n", + " check_scalar(\n", + " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py\", line 1480, in check_scalar\n", + " raise ValueError(\n", + "ValueError: min_samples_split == 1, must be >= 2.\n", + "\n", + " warnings.warn(some_fits_failed_message, FitFailedWarning)\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_search.py:953: UserWarning: One or more of the test scores are non-finite: [0.8393134 0.81085271 0.82026578 0.79667774 nan 0.82015504\n", + " 0.82513843 0.81085271 0.80586932 0.82048726]\n", + " warnings.warn(\n", + "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n", + " warn(\n" + ] + } + ], + "source": [ + "feat_model = rand_search_model.fit(X_train, y_train)\n", + "feat_pred = feat_model.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "2ff087f3-9c1d-4860-a3f1-069cc2c61dce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
precisionrecallf1_scoreaccuracy
model_thr_0.70.751.000.860.82
model_10.820.840.830.81
model_thr_0.50.820.840.830.81
model_thr_0.60.780.920.840.81
tuned_model0.820.820.820.80
features_tuned_model0.800.840.820.80
model_thr_0.40.850.690.760.77
model_thr_0.80.691.000.820.76
model_thr_0.30.900.550.680.73
model_thr_0.20.960.470.630.70
model_thr_0.90.611.000.760.66
model_thr_0.11.000.180.310.56
\n", + "
" + ], + "text/plain": [ + " precision recall f1_score accuracy\n", + "model_thr_0.7 0.75 1.00 0.86 0.82\n", + "model_1 0.82 0.84 0.83 0.81\n", + "model_thr_0.5 0.82 0.84 0.83 0.81\n", + "model_thr_0.6 0.78 0.92 0.84 0.81\n", + "tuned_model 0.82 0.82 0.82 0.80\n", + "features_tuned_model 0.80 0.84 0.82 0.80\n", + "model_thr_0.4 0.85 0.69 0.76 0.77\n", + "model_thr_0.8 0.69 1.00 0.82 0.76\n", + "model_thr_0.3 0.90 0.55 0.68 0.73\n", + "model_thr_0.2 0.96 0.47 0.63 0.70\n", + "model_thr_0.9 0.61 1.00 0.76 0.66\n", + "model_thr_0.1 1.00 0.18 0.31 0.56" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result['features_tuned_model'] = eval_result(feat_pred)\n", + "result.T.sort_values('accuracy', ascending=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f57a68eb-efd5-41e8-8348-bffa40f1c2c2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}