From a3fb013513de459381a6721a3d7fd43f6f4d26ed Mon Sep 17 00:00:00 2001
From: PAlena <86962990+PAlena@users.noreply.github.com>
Date: Mon, 29 Apr 2024 22:12:19 +0200
Subject: [PATCH] Add files via upload
---
Seminar_5/DS_case_study_solved.ipynb | 1890 ++++++++++++++++++++++++++
1 file changed, 1890 insertions(+)
create mode 100644 Seminar_5/DS_case_study_solved.ipynb
diff --git a/Seminar_5/DS_case_study_solved.ipynb b/Seminar_5/DS_case_study_solved.ipynb
new file mode 100644
index 0000000..c85c09f
--- /dev/null
+++ b/Seminar_5/DS_case_study_solved.ipynb
@@ -0,0 +1,1890 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "2d667870",
+ "metadata": {},
+ "source": [
+ "# Heart Attack Predicting Case Study"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "be742eee",
+ "metadata": {},
+ "source": [
+ "## Problem definition: Predict whether a patient will have a heart attack or not"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6c9c4ad8",
+ "metadata": {},
+ "source": [
+ "### Features:\n",
+ "Age : Age of the patient\n",
+ "\n",
+ "Sex : Sex of the patient\n",
+ "\n",
+ "exng : exercise induced angina (1 = yes; 0 = no)\n",
+ "\n",
+ "caa : number of major vessels (0-3)\n",
+ "\n",
+ "cp : chest pain type\n",
+ "\n",
+ "Value 1: typical angina\n",
+ "Value 2: atypical angina\n",
+ "Value 3: non-anginal pain\n",
+ "Value 4: asymptomatic\n",
+ "trtbps : resting blood pressure (in mm Hg)\n",
+ "\n",
+ "chol : cholesterol in mg/dl fetched via BMI sensor\n",
+ "\n",
+ "fbs : (fasting blood sugar > 120 mg/dl) (1 = true; 0 = false)\n",
+ "\n",
+ "restecg : resting electrocardiographic results\n",
+ "\n",
+ "Value 0: normal\n",
+ "Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)\n",
+ "Value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria\n",
+ "thalachh : maximum heart rate achieved\n",
+ "\n",
+ "output : 0 = less chance of heart attack, 1 = more chance of heart attack"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "25294e0a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "import sklearn\n",
+ "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.metrics import confusion_matrix\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0cd1b4db",
+ "metadata": {},
+ "source": [
+ "## Data preparation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "cad732a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = pd.read_csv(\"data/heart.csv\")\n",
+ "o2_saturation = pd.read_csv(\"data/o2Saturation.csv\").head(303)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "190ade33",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " age \n",
+ " sex \n",
+ " cp \n",
+ " trtbps \n",
+ " chol \n",
+ " fbs \n",
+ " restecg \n",
+ " thalachh \n",
+ " exng \n",
+ " oldpeak \n",
+ " slp \n",
+ " caa \n",
+ " thall \n",
+ " output \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 63 \n",
+ " 1 \n",
+ " 3 \n",
+ " 145 \n",
+ " 233 \n",
+ " 1 \n",
+ " 0 \n",
+ " 150 \n",
+ " 0 \n",
+ " 2.3 \n",
+ " 0 \n",
+ " 0 \n",
+ " 1 \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 37 \n",
+ " 1 \n",
+ " 2 \n",
+ " 130 \n",
+ " 250 \n",
+ " 0 \n",
+ " 1 \n",
+ " 187 \n",
+ " 0 \n",
+ " 3.5 \n",
+ " 0 \n",
+ " 0 \n",
+ " 2 \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 41 \n",
+ " 0 \n",
+ " 1 \n",
+ " 130 \n",
+ " 204 \n",
+ " 0 \n",
+ " 0 \n",
+ " 172 \n",
+ " 0 \n",
+ " 1.4 \n",
+ " 2 \n",
+ " 0 \n",
+ " 2 \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 56 \n",
+ " 1 \n",
+ " 1 \n",
+ " 120 \n",
+ " 236 \n",
+ " 0 \n",
+ " 1 \n",
+ " 178 \n",
+ " 0 \n",
+ " 0.8 \n",
+ " 2 \n",
+ " 0 \n",
+ " 2 \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 57 \n",
+ " 0 \n",
+ " 0 \n",
+ " 120 \n",
+ " 354 \n",
+ " 0 \n",
+ " 1 \n",
+ " 163 \n",
+ " 1 \n",
+ " 0.6 \n",
+ " 2 \n",
+ " 0 \n",
+ " 2 \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age sex cp trtbps chol fbs restecg thalachh exng oldpeak slp \\\n",
+ "0 63 1 3 145 233 1 0 150 0 2.3 0 \n",
+ "1 37 1 2 130 250 0 1 187 0 3.5 0 \n",
+ "2 41 0 1 130 204 0 0 172 0 1.4 2 \n",
+ "3 56 1 1 120 236 0 1 178 0 0.8 2 \n",
+ "4 57 0 0 120 354 0 1 163 1 0.6 2 \n",
+ "\n",
+ " caa thall output \n",
+ "0 0 1 1 \n",
+ "1 0 2 1 \n",
+ "2 0 2 1 \n",
+ "3 0 2 1 \n",
+ "4 0 2 1 "
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9c2711c0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "RangeIndex: 303 entries, 0 to 302\n",
+ "Data columns (total 14 columns):\n",
+ " # Column Non-Null Count Dtype \n",
+ "--- ------ -------------- ----- \n",
+ " 0 age 303 non-null int64 \n",
+ " 1 sex 303 non-null int64 \n",
+ " 2 cp 303 non-null int64 \n",
+ " 3 trtbps 303 non-null int64 \n",
+ " 4 chol 303 non-null int64 \n",
+ " 5 fbs 303 non-null int64 \n",
+ " 6 restecg 303 non-null int64 \n",
+ " 7 thalachh 303 non-null int64 \n",
+ " 8 exng 303 non-null int64 \n",
+ " 9 oldpeak 303 non-null float64\n",
+ " 10 slp 303 non-null int64 \n",
+ " 11 caa 303 non-null int64 \n",
+ " 12 thall 303 non-null int64 \n",
+ " 13 output 303 non-null int64 \n",
+ "dtypes: float64(1), int64(13)\n",
+ "memory usage: 33.3 KB\n"
+ ]
+ }
+ ],
+ "source": [
+ "data.info()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "1489c727",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " age \n",
+ " sex \n",
+ " cp \n",
+ " trtbps \n",
+ " chol \n",
+ " fbs \n",
+ " restecg \n",
+ " thalachh \n",
+ " exng \n",
+ " oldpeak \n",
+ " slp \n",
+ " caa \n",
+ " thall \n",
+ " output \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " count \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " 303.000000 \n",
+ " \n",
+ " \n",
+ " mean \n",
+ " 54.366337 \n",
+ " 0.683168 \n",
+ " 0.966997 \n",
+ " 131.623762 \n",
+ " 246.264026 \n",
+ " 0.148515 \n",
+ " 0.528053 \n",
+ " 149.646865 \n",
+ " 0.326733 \n",
+ " 1.039604 \n",
+ " 1.399340 \n",
+ " 0.729373 \n",
+ " 2.313531 \n",
+ " 0.544554 \n",
+ " \n",
+ " \n",
+ " std \n",
+ " 9.082101 \n",
+ " 0.466011 \n",
+ " 1.032052 \n",
+ " 17.538143 \n",
+ " 51.830751 \n",
+ " 0.356198 \n",
+ " 0.525860 \n",
+ " 22.905161 \n",
+ " 0.469794 \n",
+ " 1.161075 \n",
+ " 0.616226 \n",
+ " 1.022606 \n",
+ " 0.612277 \n",
+ " 0.498835 \n",
+ " \n",
+ " \n",
+ " min \n",
+ " 29.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 94.000000 \n",
+ " 126.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 71.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " \n",
+ " \n",
+ " 25% \n",
+ " 47.500000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 120.000000 \n",
+ " 211.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 133.500000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 1.000000 \n",
+ " 0.000000 \n",
+ " 2.000000 \n",
+ " 0.000000 \n",
+ " \n",
+ " \n",
+ " 50% \n",
+ " 55.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 130.000000 \n",
+ " 240.000000 \n",
+ " 0.000000 \n",
+ " 1.000000 \n",
+ " 153.000000 \n",
+ " 0.000000 \n",
+ " 0.800000 \n",
+ " 1.000000 \n",
+ " 0.000000 \n",
+ " 2.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 75% \n",
+ " 61.000000 \n",
+ " 1.000000 \n",
+ " 2.000000 \n",
+ " 140.000000 \n",
+ " 274.500000 \n",
+ " 0.000000 \n",
+ " 1.000000 \n",
+ " 166.000000 \n",
+ " 1.000000 \n",
+ " 1.600000 \n",
+ " 2.000000 \n",
+ " 1.000000 \n",
+ " 3.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " max \n",
+ " 77.000000 \n",
+ " 1.000000 \n",
+ " 3.000000 \n",
+ " 200.000000 \n",
+ " 564.000000 \n",
+ " 1.000000 \n",
+ " 2.000000 \n",
+ " 202.000000 \n",
+ " 1.000000 \n",
+ " 6.200000 \n",
+ " 2.000000 \n",
+ " 4.000000 \n",
+ " 3.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age sex cp trtbps chol fbs \\\n",
+ "count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \n",
+ "mean 54.366337 0.683168 0.966997 131.623762 246.264026 0.148515 \n",
+ "std 9.082101 0.466011 1.032052 17.538143 51.830751 0.356198 \n",
+ "min 29.000000 0.000000 0.000000 94.000000 126.000000 0.000000 \n",
+ "25% 47.500000 0.000000 0.000000 120.000000 211.000000 0.000000 \n",
+ "50% 55.000000 1.000000 1.000000 130.000000 240.000000 0.000000 \n",
+ "75% 61.000000 1.000000 2.000000 140.000000 274.500000 0.000000 \n",
+ "max 77.000000 1.000000 3.000000 200.000000 564.000000 1.000000 \n",
+ "\n",
+ " restecg thalachh exng oldpeak slp caa \\\n",
+ "count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \n",
+ "mean 0.528053 149.646865 0.326733 1.039604 1.399340 0.729373 \n",
+ "std 0.525860 22.905161 0.469794 1.161075 0.616226 1.022606 \n",
+ "min 0.000000 71.000000 0.000000 0.000000 0.000000 0.000000 \n",
+ "25% 0.000000 133.500000 0.000000 0.000000 1.000000 0.000000 \n",
+ "50% 1.000000 153.000000 0.000000 0.800000 1.000000 0.000000 \n",
+ "75% 1.000000 166.000000 1.000000 1.600000 2.000000 1.000000 \n",
+ "max 2.000000 202.000000 1.000000 6.200000 2.000000 4.000000 \n",
+ "\n",
+ " thall output \n",
+ "count 303.000000 303.000000 \n",
+ "mean 2.313531 0.544554 \n",
+ "std 0.612277 0.498835 \n",
+ "min 0.000000 0.000000 \n",
+ "25% 2.000000 0.000000 \n",
+ "50% 2.000000 1.000000 \n",
+ "75% 3.000000 1.000000 \n",
+ "max 3.000000 1.000000 "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.describe()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "17e69c68",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# what is the distribution of age column?\n",
+ "data[\"age\"].hist(bins=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "caba3e33",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Numbers if man: 0.68\n",
+ "Numbers if woman: 0.32\n"
+ ]
+ }
+ ],
+ "source": [
+ "# What is the share of wemales in the sample? What is the share of males?\n",
+ "print(f'Numbers if man: {round(data.sex.sum()/len(data),2)}')\n",
+ "print(f'Numbers if woman: {round(1 - data.sex.sum()/len(data),2)}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "f0dd15a8-1502-40c5-8522-2b3737f89bac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Add o2_saturation data in the data frame 'data'\n",
+ "# o2_saturation.rename(columns={'98.6':'o2_saturation'}, inplace = True)\n",
+ "# data = pd.concat([data, o2_saturation], axis=1)\n",
+ "data[\"o2_Saturation\"] = o2_saturation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "195ae311",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "corr = data.corr()\n",
+ "\n",
+ "plt.figure(figsize=(15, 10))\n",
+ "sns.heatmap(corr, annot=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "076fb645",
+ "metadata": {},
+ "source": [
+ "## Random Forest Classifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "0fe76502",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Specify X and y and split the dataset\n",
+ "X = data.drop(\"output\", axis=1)\n",
+ "y = data[\"output\"]\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state = 14)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "183fd155",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((212, 14), (91, 14), (212,), (91,))"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_train.shape, X_test.shape, y_train.shape, y_test.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "077d86fe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# specify model\n",
+ "forest = RandomForestClassifier(random_state = 14)\n",
+ "\n",
+ "# train model\n",
+ "forest.fit(X_train, y_train)\n",
+ "\n",
+ "# make prediction\n",
+ "forest_preds = forest.predict(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "40e251ea-7fcd-4aee-b638-f29309b68b89",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1,\n",
+ " 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,\n",
+ " 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,\n",
+ " 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0,\n",
+ " 0, 1, 1])"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "forest_preds"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "080c566f-6f78-43d2-8782-008a738e0697",
+ "metadata": {},
+ "source": [
+ "## Evaluate the result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "525bd174-6241-4d14-946a-f41fe31e6c7b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "41\n",
+ "8\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Manually calculate number of True positive and False negative predictions\n",
+ "Y = pd.DataFrame()\n",
+ "Y['real'] = y_test\n",
+ "Y['pred'] = forest_preds\n",
+ "print(len(Y[(Y.real==1)&(Y.pred==1)]))\n",
+ "print(len(Y[(Y.real==1)&(Y.pred==0)]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "e5252037-b250-4ae4-99c5-1138bfd537e1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Text(0.5, 1.0, 'Confusion Matrix')"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sns.heatmap(confusion_matrix(y_test, forest_preds), annot=True)\n",
+ "plt.xlabel(\"Predicted Labels\")\n",
+ "plt.ylabel(\"Actual Labels\")\n",
+ "plt.title(\"Confusion Matrix\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "3d3f1e98-c3cd-428d-aa14-d099db19e3a6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Based on the confusion matrix fill:\n",
+ "true_positive = 41\n",
+ "true_negative = 33\n",
+ "false_positive = 9\n",
+ "false_negative = 8"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "a5f55d8d-f413-4d0a-9b7b-23efc31c257d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.82\n",
+ "0.82\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Precision: measure of how many of the positive predictions made are correct (true positives)\n",
+ "print(round(true_positive/(true_positive + false_positive),2))\n",
+ "print(round(sklearn.metrics.precision_score(y_test, forest_preds), 2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "979fb237-0771-4585-ae70-74a8b4db6077",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.84\n",
+ "0.84\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Recall: is a measure of how many of the positive cases the classifier correctly predicted\n",
+ "print(round(true_positive/(true_positive + false_negative),2))\n",
+ "print(round(sklearn.metrics.recall_score(y_test, forest_preds), 2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "008cd394-6520-4888-863f-6463fe3fa9ec",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.83\n",
+ "0.83\n"
+ ]
+ }
+ ],
+ "source": [
+ "# F1-Score: combines precision and recall: 2(prec*recall)/(prec+recall), which simplifies to 2TP/(2TP+FP+FN)\n",
+ "print(round(2*true_positive/(2*true_positive + false_positive + false_negative),2))\n",
+ "print(round(sklearn.metrics.f1_score(y_test, forest_preds), 2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "c6fac62f-c5a8-4ba5-b7ee-82ec647b9ef0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0.81\n",
+ "0.81\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Accuracy: describing the number of correct predictions over all predictions\n",
+ "print(round((true_positive+true_negative)/(true_positive+true_negative+false_positive+false_negative),2))\n",
+ "print(round(sklearn.metrics.accuracy_score(y_test, forest_preds), 2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e43acac7-9bca-461c-9e32-8542a3fd04ef",
+ "metadata": {},
+ "source": [
+ "## Save the result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "922aecda-8067-4b9c-8f2c-ad83e345b42a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " precision \n",
+ " \n",
+ " \n",
+ " recall \n",
+ " \n",
+ " \n",
+ " f1_score \n",
+ " \n",
+ " \n",
+ " accuracy \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "Empty DataFrame\n",
+ "Columns: []\n",
+ "Index: [precision, recall, f1_score, accuracy]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# create empty df with indexes: ['precision','recall', 'f1_score','accuracy']\n",
+ "result = pd.DataFrame(index=['precision','recall', 'f1_score','accuracy'])\n",
+ "result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "1469766a-232b-4931-a7d6-43ee8b22141c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# write a function which will return 4 metrics: [precision, recall, f1_score, accuracy]\n",
+ "def eval_result(x_predicted, y_true=None):\n",
+ "    \"\"\"Score predictions and return [precision, recall, f1, accuracy], each rounded to 2 decimals.\n",
+ "\n",
+ "    x_predicted: predicted labels; y_true: true labels (defaults to the notebook-level y_test).\n",
+ "    \"\"\"\n",
+ "    y_true = y_test if y_true is None else y_true\n",
+ "    metrics = sklearn.metrics  # reachable via `import sklearn` because sklearn.metrics is imported in the first cell\n",
+ "    return [round(score(y_true, x_predicted), 2) for score in (metrics.precision_score, metrics.recall_score, metrics.f1_score, metrics.accuracy_score)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "ce75bce6-17e4-46fc-bd25-c31bdc178d10",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " model_1 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " precision \n",
+ " 0.82 \n",
+ " \n",
+ " \n",
+ " recall \n",
+ " 0.84 \n",
+ " \n",
+ " \n",
+ " f1_score \n",
+ " 0.83 \n",
+ " \n",
+ " \n",
+ " accuracy \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " model_1\n",
+ "precision 0.82\n",
+ "recall 0.84\n",
+ "f1_score 0.83\n",
+ "accuracy 0.81"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# add the result of model_1 into the df 'result'\n",
+ "result['model_1'] = eval_result(forest_preds)\n",
+ "result"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e6c6db57-954c-4449-bc97-5c3ef07e2020",
+ "metadata": {},
+ "source": [
+ "### Change the threshold"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "cabae7b5-dc23-41bc-8cba-f9ba3e1df5ae",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0, 1, 0, 0, 1])"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "forest.predict(X_test)[:5]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "bd4b584e-9b3d-47b2-ae9e-396c6f6af482",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[0.82, 0.18],\n",
+ " [0.37, 0.63],\n",
+ " [0.81, 0.19],\n",
+ " [0.79, 0.21],\n",
+ " [0.33, 0.67]])"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "forest.predict_proba(X_test)[:5]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0df4b875-4107-4ce8-837a-1c9b5bfcb669",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# get the result for a threshold = 0.7"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "21754b98-1b7b-43f4-b2ad-18507ba9c86d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# get the result for any threshold from 0.1 till 0.9 and add it into 'result' df\n",
+ "res_proba = forest.predict_proba(X_test)  # one column per class in forest.classes_ order: [P(output=0), P(output=1)]\n",
+ "# t is the cut-off on P(output=1): predict 1 only when that probability exceeds t, so t=0.5 matches forest.predict\n",
+ "for t in [x/10 for x in range(1, 10)]:\n",
+ "    res_pred = [1 if i[1] > t else 0 for i in res_proba]\n",
+ "    result[f'model_thr_{t}'] = eval_result(res_pred)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "3b87475b-d862-44a7-a667-5962099483ee",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " precision \n",
+ " recall \n",
+ " f1_score \n",
+ " accuracy \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " model_thr_0.7 \n",
+ " 0.75 \n",
+ " 1.00 \n",
+ " 0.86 \n",
+ " 0.82 \n",
+ " \n",
+ " \n",
+ " model_1 \n",
+ " 0.82 \n",
+ " 0.84 \n",
+ " 0.83 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.5 \n",
+ " 0.82 \n",
+ " 0.84 \n",
+ " 0.83 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.6 \n",
+ " 0.78 \n",
+ " 0.92 \n",
+ " 0.84 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.4 \n",
+ " 0.85 \n",
+ " 0.69 \n",
+ " 0.76 \n",
+ " 0.77 \n",
+ " \n",
+ " \n",
+ " model_thr_0.8 \n",
+ " 0.69 \n",
+ " 1.00 \n",
+ " 0.82 \n",
+ " 0.76 \n",
+ " \n",
+ " \n",
+ " model_thr_0.3 \n",
+ " 0.90 \n",
+ " 0.55 \n",
+ " 0.68 \n",
+ " 0.73 \n",
+ " \n",
+ " \n",
+ " model_thr_0.2 \n",
+ " 0.96 \n",
+ " 0.47 \n",
+ " 0.63 \n",
+ " 0.70 \n",
+ " \n",
+ " \n",
+ " model_thr_0.9 \n",
+ " 0.61 \n",
+ " 1.00 \n",
+ " 0.76 \n",
+ " 0.66 \n",
+ " \n",
+ " \n",
+ " model_thr_0.1 \n",
+ " 1.00 \n",
+ " 0.18 \n",
+ " 0.31 \n",
+ " 0.56 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall f1_score accuracy\n",
+ "model_thr_0.7 0.75 1.00 0.86 0.82\n",
+ "model_1 0.82 0.84 0.83 0.81\n",
+ "model_thr_0.5 0.82 0.84 0.83 0.81\n",
+ "model_thr_0.6 0.78 0.92 0.84 0.81\n",
+ "model_thr_0.4 0.85 0.69 0.76 0.77\n",
+ "model_thr_0.8 0.69 1.00 0.82 0.76\n",
+ "model_thr_0.3 0.90 0.55 0.68 0.73\n",
+ "model_thr_0.2 0.96 0.47 0.63 0.70\n",
+ "model_thr_0.9 0.61 1.00 0.76 0.66\n",
+ "model_thr_0.1 1.00 0.18 0.31 0.56"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "result.T.sort_values('accuracy', ascending=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "45d0fcc7",
+ "metadata": {},
+ "source": [
+ "### Using RandomizedSearchCV for hyperparameter tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "7c1501e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "estimator = RandomForestClassifier(random_state=14)  # fixed seed, as in the baseline forest\n",
+ "grid = {\"n_estimators\": [80, 90, 100, 110, 120],\n",
+ "        \"max_depth\": [5, 10, 15],\n",
+ "        \"max_features\": [\"sqrt\", \"log2\"],  # 'auto' (an alias of 'sqrt') is deprecated since scikit-learn 1.1\n",
+ "        \"min_samples_split\": [2, 3, 4],  # an integer value must be >= 2, so 1 is not a valid candidate\n",
+ "        \"min_samples_leaf\": [1, 2, 3, 4]}\n",
+ "\n",
+ "rand_search_model = RandomizedSearchCV(estimator=estimator, param_distributions=grid,\n",
+ "                                       random_state=14)  # reproducible sampling of candidates"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "a9188e30",
+ "metadata": {
+ "collapsed": true,
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning: \n",
+ "5 fits failed out of a total of 50.\n",
+ "The score on these train-test partitions for these parameters will be set to nan.\n",
+ "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
+ "\n",
+ "Below are more details about the failures:\n",
+ "--------------------------------------------------------------------------------\n",
+ "5 fits failed with the following error:\n",
+ "Traceback (most recent call last):\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
+ " estimator.fit(X_train, y_train, **fit_params)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 476, in fit\n",
+ " trees = Parallel(\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 1043, in __call__\n",
+ " if self.dispatch_one_batch(iterator):\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 861, in dispatch_one_batch\n",
+ " self._dispatch(tasks)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 779, in _dispatch\n",
+ " job = self._backend.apply_async(batch, callback=cb)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 208, in apply_async\n",
+ " result = ImmediateResult(func)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 572, in __init__\n",
+ " self.results = batch()\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in __call__\n",
+ " return [func(*args, **kwargs)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in \n",
+ " return [func(*args, **kwargs)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/fixes.py\", line 117, in __call__\n",
+ " return self.function(*args, **kwargs)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 189, in _parallel_build_trees\n",
+ " tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 969, in fit\n",
+ " super().fit(\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 265, in fit\n",
+ " check_scalar(\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py\", line 1480, in check_scalar\n",
+ " raise ValueError(\n",
+ "ValueError: min_samples_split == 1, must be >= 2.\n",
+ "\n",
+ " warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_search.py:953: UserWarning: One or more of the test scores are non-finite: [0.81550388 0.8013289 0.81572536 0.82026578 nan 0.81539313\n",
+ " 0.82026578 0.81085271 0.82026578 0.81096346]\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Run the randomized hyper-parameter search on the training split, then score\n",
+ "# the test split with the fitted search so it can be compared with model 1\n",
+ "rand_search_pred = rand_search_model.fit(X_train, y_train).predict(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "acbb111b-8b61-46b2-92b3-511588758d0f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'n_estimators': 90,\n",
+ " 'min_samples_split': 3,\n",
+ " 'min_samples_leaf': 3,\n",
+ " 'max_features': 'log2',\n",
+ " 'max_depth': 10}"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rand_search_model.best_params_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "10b109c8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " precision \n",
+ " recall \n",
+ " f1_score \n",
+ " accuracy \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " model_thr_0.7 \n",
+ " 0.75 \n",
+ " 1.00 \n",
+ " 0.86 \n",
+ " 0.82 \n",
+ " \n",
+ " \n",
+ " model_1 \n",
+ " 0.82 \n",
+ " 0.84 \n",
+ " 0.83 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.5 \n",
+ " 0.82 \n",
+ " 0.84 \n",
+ " 0.83 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.6 \n",
+ " 0.78 \n",
+ " 0.92 \n",
+ " 0.84 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " tuned_model \n",
+ " 0.82 \n",
+ " 0.82 \n",
+ " 0.82 \n",
+ " 0.80 \n",
+ " \n",
+ " \n",
+ " model_thr_0.4 \n",
+ " 0.85 \n",
+ " 0.69 \n",
+ " 0.76 \n",
+ " 0.77 \n",
+ " \n",
+ " \n",
+ " model_thr_0.8 \n",
+ " 0.69 \n",
+ " 1.00 \n",
+ " 0.82 \n",
+ " 0.76 \n",
+ " \n",
+ " \n",
+ " model_thr_0.3 \n",
+ " 0.90 \n",
+ " 0.55 \n",
+ " 0.68 \n",
+ " 0.73 \n",
+ " \n",
+ " \n",
+ " model_thr_0.2 \n",
+ " 0.96 \n",
+ " 0.47 \n",
+ " 0.63 \n",
+ " 0.70 \n",
+ " \n",
+ " \n",
+ " model_thr_0.9 \n",
+ " 0.61 \n",
+ " 1.00 \n",
+ " 0.76 \n",
+ " 0.66 \n",
+ " \n",
+ " \n",
+ " model_thr_0.1 \n",
+ " 1.00 \n",
+ " 0.18 \n",
+ " 0.31 \n",
+ " 0.56 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall f1_score accuracy\n",
+ "model_thr_0.7 0.75 1.00 0.86 0.82\n",
+ "model_1 0.82 0.84 0.83 0.81\n",
+ "model_thr_0.5 0.82 0.84 0.83 0.81\n",
+ "model_thr_0.6 0.78 0.92 0.84 0.81\n",
+ "tuned_model 0.82 0.82 0.82 0.80\n",
+ "model_thr_0.4 0.85 0.69 0.76 0.77\n",
+ "model_thr_0.8 0.69 1.00 0.82 0.76\n",
+ "model_thr_0.3 0.90 0.55 0.68 0.73\n",
+ "model_thr_0.2 0.96 0.47 0.63 0.70\n",
+ "model_thr_0.9 0.61 1.00 0.76 0.66\n",
+ "model_thr_0.1 1.00 0.18 0.31 0.56"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "result['tuned_model'] = eval_result(rand_search_pred)  # record the tuned search's test metrics next to the earlier models\n",
+ "result.T.sort_values(by='accuracy', ascending=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b084c123",
+ "metadata": {},
+ "source": [
+ "## Feature importance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "44066cb9-2bcc-4450-935b-8c5fb86d48b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "importances = forest.feature_importances_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "b86599aa-8ef2-49c8-8aea-1674a1a9bc5a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Pair each importance with its feature name, ranked from most to least important\n",
+ "forest_importances = pd.Series(importances, index=X.columns).sort_values(ascending=False)\n",
+ "fig, ax = plt.subplots()\n",
+ "forest_importances.plot.bar(ax=ax)\n",
+ "ax.set_title(\"Feature importances\")\n",
+ "fig.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "1526680d-a94a-4335-8644-902d6a5751b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Drop the least important feature ('fbs'), refit the tuned search, and compare the result with model 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "046d5bab-4f5b-47bc-b7c0-080359042bf6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_test = X_test.drop(columns=['fbs'], errors='ignore')  # non-mutating and idempotent: re-running the cell no longer raises KeyError\n",
+ "X_train = X_train.drop(columns=['fbs'], errors='ignore')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "2f730335-849d-4fbd-b148-01fe746e2721",
+ "metadata": {
+ "collapsed": true,
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning: \n",
+ "5 fits failed out of a total of 50.\n",
+ "The score on these train-test partitions for these parameters will be set to nan.\n",
+ "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
+ "\n",
+ "Below are more details about the failures:\n",
+ "--------------------------------------------------------------------------------\n",
+ "5 fits failed with the following error:\n",
+ "Traceback (most recent call last):\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
+ " estimator.fit(X_train, y_train, **fit_params)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 476, in fit\n",
+ " trees = Parallel(\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 1043, in __call__\n",
+ " if self.dispatch_one_batch(iterator):\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 861, in dispatch_one_batch\n",
+ " self._dispatch(tasks)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 779, in _dispatch\n",
+ " job = self._backend.apply_async(batch, callback=cb)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 208, in apply_async\n",
+ " result = ImmediateResult(func)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/_parallel_backends.py\", line 572, in __init__\n",
+ " self.results = batch()\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in __call__\n",
+ " return [func(*args, **kwargs)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/joblib/parallel.py\", line 262, in \n",
+ " return [func(*args, **kwargs)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/fixes.py\", line 117, in __call__\n",
+ " return self.function(*args, **kwargs)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py\", line 189, in _parallel_build_trees\n",
+ " tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 969, in fit\n",
+ " super().fit(\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/tree/_classes.py\", line 265, in fit\n",
+ " check_scalar(\n",
+ " File \"/usr/local/lib/python3.8/dist-packages/sklearn/utils/validation.py\", line 1480, in check_scalar\n",
+ " raise ValueError(\n",
+ "ValueError: min_samples_split == 1, must be >= 2.\n",
+ "\n",
+ " warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/model_selection/_search.py:953: UserWarning: One or more of the test scores are non-finite: [0.8393134 0.81085271 0.82026578 0.79667774 nan 0.82015504\n",
+ " 0.82513843 0.81085271 0.80586932 0.82048726]\n",
+ " warnings.warn(\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/ensemble/_forest.py:427: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'` or remove this parameter as it is also the default value for RandomForestClassifiers and ExtraTreesClassifiers.\n",
+ " warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.base import clone  # fit() returns self, so refit an unfitted copy to keep rand_search_model's full-feature fit intact\n",
+ "feat_model = clone(rand_search_model).fit(X_train, y_train); feat_pred = feat_model.predict(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "2ff087f3-9c1d-4860-a3f1-069cc2c61dce",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " precision \n",
+ " recall \n",
+ " f1_score \n",
+ " accuracy \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " model_thr_0.7 \n",
+ " 0.75 \n",
+ " 1.00 \n",
+ " 0.86 \n",
+ " 0.82 \n",
+ " \n",
+ " \n",
+ " model_1 \n",
+ " 0.82 \n",
+ " 0.84 \n",
+ " 0.83 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.5 \n",
+ " 0.82 \n",
+ " 0.84 \n",
+ " 0.83 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " model_thr_0.6 \n",
+ " 0.78 \n",
+ " 0.92 \n",
+ " 0.84 \n",
+ " 0.81 \n",
+ " \n",
+ " \n",
+ " tuned_model \n",
+ " 0.82 \n",
+ " 0.82 \n",
+ " 0.82 \n",
+ " 0.80 \n",
+ " \n",
+ " \n",
+ " features_tuned_model \n",
+ " 0.80 \n",
+ " 0.84 \n",
+ " 0.82 \n",
+ " 0.80 \n",
+ " \n",
+ " \n",
+ " model_thr_0.4 \n",
+ " 0.85 \n",
+ " 0.69 \n",
+ " 0.76 \n",
+ " 0.77 \n",
+ " \n",
+ " \n",
+ " model_thr_0.8 \n",
+ " 0.69 \n",
+ " 1.00 \n",
+ " 0.82 \n",
+ " 0.76 \n",
+ " \n",
+ " \n",
+ " model_thr_0.3 \n",
+ " 0.90 \n",
+ " 0.55 \n",
+ " 0.68 \n",
+ " 0.73 \n",
+ " \n",
+ " \n",
+ " model_thr_0.2 \n",
+ " 0.96 \n",
+ " 0.47 \n",
+ " 0.63 \n",
+ " 0.70 \n",
+ " \n",
+ " \n",
+ " model_thr_0.9 \n",
+ " 0.61 \n",
+ " 1.00 \n",
+ " 0.76 \n",
+ " 0.66 \n",
+ " \n",
+ " \n",
+ " model_thr_0.1 \n",
+ " 1.00 \n",
+ " 0.18 \n",
+ " 0.31 \n",
+ " 0.56 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " precision recall f1_score accuracy\n",
+ "model_thr_0.7 0.75 1.00 0.86 0.82\n",
+ "model_1 0.82 0.84 0.83 0.81\n",
+ "model_thr_0.5 0.82 0.84 0.83 0.81\n",
+ "model_thr_0.6 0.78 0.92 0.84 0.81\n",
+ "tuned_model 0.82 0.82 0.82 0.80\n",
+ "features_tuned_model 0.80 0.84 0.82 0.80\n",
+ "model_thr_0.4 0.85 0.69 0.76 0.77\n",
+ "model_thr_0.8 0.69 1.00 0.82 0.76\n",
+ "model_thr_0.3 0.90 0.55 0.68 0.73\n",
+ "model_thr_0.2 0.96 0.47 0.63 0.70\n",
+ "model_thr_0.9 0.61 1.00 0.76 0.66\n",
+ "model_thr_0.1 1.00 0.18 0.31 0.56"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "result['features_tuned_model'] = eval_result(feat_pred)  # record the reduced-feature model's test metrics next to the earlier models\n",
+ "result.T.sort_values(by='accuracy', ascending=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f57a68eb-efd5-41e8-8348-bffa40f1c2c2",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}