diff --git a/logistic_regression_class/logistic_visualize.py b/logistic_regression_class/logistic_visualize.py index 65289535..ba84ffe7 100644 --- a/logistic_regression_class/logistic_visualize.py +++ b/logistic_regression_class/logistic_visualize.py @@ -44,7 +44,10 @@ def sigmoid(z): z = Xb.dot(w) Y = sigmoid(z) -plt.scatter(X[:,0], X[:,1], c=T, s=100, alpha=0.5) +# make colors more visible +plt_colors = [f'C{i+1}' for i in T] + +plt.scatter(X[:,0], X[:,1], c=plt_colors, s=100, alpha=0.5) x_axis = np.linspace(-6, 6, 100) y_axis = -x_axis