-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
114 lines (89 loc) · 3.94 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import streamlit as st
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
import pandas as pd
import pickle
import matplotlib.pyplot as plt
# Load the trained model and the preprocessing objects saved alongside it.
# Streamlit re-executes this whole script on every widget interaction, so the
# loading is wrapped in st.cache_resource: the files are read once per server
# process and the same objects are reused on every rerun.
# show_spinner=False: this runs before st.set_page_config, and the cache
# spinner would otherwise be drawn as a Streamlit element ahead of it.
@st.cache_resource(show_spinner=False)
def _load_artifacts():
    """Load the churn model, gender encoder, geography encoder and scaler.

    Returns:
        Tuple ``(model, label_encoder_gender, onehot_encoder_geo, scaler)``
        read from ``model.h5``, ``label_encoder_gender.pkl``,
        ``onehot_encoder_geo.pkl`` and ``scaler.pkl`` in the working directory.
    """
    loaded_model = tf.keras.models.load_model('model.h5')
    # NOTE(review): pickle.load can execute arbitrary code; these .pkl files
    # are assumed to be trusted local artifacts — confirm they never come
    # from user input.
    with open('label_encoder_gender.pkl', 'rb') as file:
        gender_encoder = pickle.load(file)
    with open('onehot_encoder_geo.pkl', 'rb') as file:
        geo_encoder = pickle.load(file)
    with open('scaler.pkl', 'rb') as file:
        feature_scaler = pickle.load(file)
    return loaded_model, gender_encoder, geo_encoder, feature_scaler


model, label_encoder_gender, onehot_encoder_geo, scaler = _load_artifacts()
# Configure the page (tab title, icon, layout) and draw the app heading.
st.set_page_config(
    page_title="Churn Predictor",
    page_icon='Fardeen NB - Logo.png',
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title('Customer Churn Prediction Project')

# Two input columns separated by a narrow, empty spacer column.
col1, spacer, col2 = st.columns([1, 0.2, 1])

# Left column: categorical choices (options come from the fitted encoders),
# then age, balance and credit score.
geography = col1.selectbox('Geography', onehot_encoder_geo.categories_[0])
gender = col1.selectbox('Gender', label_encoder_gender.classes_)
age = col1.slider('Age', 18, 92)
balance = col1.number_input('Balance')
credit_score = col1.number_input('Credit Score')

# Right column: salary, tenure, product count and the two 0/1 flags.
estimated_salary = col2.number_input('Estimated Salary')
tenure = col2.slider('Tenure', 0, 10)
num_of_products = col2.slider('Number of Products', 1, 4)
has_cr_card = col2.selectbox('Has Credit Card', [0, 1])
is_active_member = col2.selectbox('Is Active Member', [0, 1])
# One-hot encode the selected geography into Geography_* columns.
geo_encoded = onehot_encoder_geo.transform([[geography]]).toarray()
geo_encoded_df = pd.DataFrame(
    geo_encoded,
    columns=onehot_encoder_geo.get_feature_names_out(['Geography']),
)

# Single-row feature set: the nine scalar fields in this exact order, followed
# by the geography columns. NOTE(review): presumably this matches the column
# order the scaler was fitted on — confirm against the training code.
base_features = {
    'CreditScore': credit_score,
    'Gender': label_encoder_gender.transform([gender])[0],
    'Age': age,
    'Tenure': tenure,
    'Balance': balance,
    'NumOfProducts': num_of_products,
    'HasCrCard': has_cr_card,
    'IsActiveMember': is_active_member,
    'EstimatedSalary': estimated_salary,
}
input_data = pd.DataFrame({column: [value] for column, value in base_features.items()})
input_data = pd.concat([input_data, geo_encoded_df], axis=1)

# Apply the same scaling used at training time before prediction.
input_data_scaled = scaler.transform(input_data)
# Probability above which the customer is reported as likely to churn.
# NOTE(review): 0.3 is below the conventional 0.5 cut-off — presumably chosen
# deliberately to flag more at-risk customers; confirm with the model owner.
CHURN_THRESHOLD = 0.3

# Predict churn probability for the single input row. verbose=0 stops Keras
# from printing a progress bar to the server log on every Streamlit rerun.
prediction = model.predict(input_data_scaled, verbose=0)
# First output of the first (only) row; assumes the model emits a single
# probability per row — confirm against the model definition.
prediction_proba = prediction[0][0]

# Display the probability as a percentage with two decimals.
st.write(f'Churn Probability: {prediction_proba * 100:.2f}%')

# Verdict based on the threshold above.
if prediction_proba > CHURN_THRESHOLD:
    st.markdown("## The customer is likely to churn.")
else:
    st.markdown("## The customer is not likely to churn.")
# Plot the churn probability as a single bar on a black background.
churn_percentage = prediction_proba * 100

fig, ax = plt.subplots()
# Fixed 0-100 y-axis so the bar height is comparable between predictions.
ax.set_ylim(0, 100)
fig.patch.set_facecolor('black')
ax.set_facecolor('black')

ax.bar(
    ['Churn Probability'],
    [churn_percentage],
    color='darkblue',
    edgecolor='darkblue',
    linewidth=0,
    width=0.5,
)

# White labels, title and ticks so they are readable on the black background.
ax.set_ylabel('Percentage (%)', color='white')
ax.set_title('Customer Churn Probability', color='white')
ax.tick_params(axis='both', colors='white')

# Value label 3 percentage points above the bar (x=0 is the only category).
# Clamped so that for probabilities near 100% the label stays inside the
# 0-100 axes instead of drifting up into the title.
label_y = min(churn_percentage + 3, 95)
ax.text(0, label_y, f"{churn_percentage:.2f}%", ha='center', color='white')

# Copyright watermark along the bottom edge of the figure, slightly translucent.
watermark_text = "© 2024 Fardeen NB. All Rights Reserved. https://fardeen.net ™"
fig.text(0.5, 0.001, watermark_text, fontsize=10, color='white', ha='center', alpha=0.8)

# Render in Streamlit, then release the figure: pyplot keeps every figure
# created by plt.subplots() alive until it is closed, and Streamlit reruns
# this script on every interaction, so unclosed figures accumulate.
st.pyplot(fig)
plt.close(fig)