# Image Segmentation.py — 62 lines (56 loc), 2.38 KB
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 30 02:12:29 2019
@author: Mithilesh
"""
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
# Use Matplotlib's built-in "default" style for every figure this script draws.
plt.style.use("default")
class Segmentation:
    """Colour-quantise an image with k-means and visualise the result.

    The constructor clusters every pixel of the image (in RGB order) into
    ``dom_colors`` groups. :meth:`dominant_colors` shows the cluster centres
    as solid colour swatches, and :meth:`draw_image` redraws the image using
    only those centre colours.

    Parameters
    ----------
    image : numpy.ndarray
        Image as returned by ``cv2.imread`` (BGR channel order). Assumes a
        3-channel image — TODO confirm behaviour for grayscale/alpha input.
    dom_colors : int
        Number of clusters (dominant colours) to extract.
    random_state : int or None, optional
        Seed forwarded to ``KMeans`` so results can be reproduced. The
        default ``None`` keeps the original unseeded behaviour.
    """

    def __init__(self, image, dom_colors, random_state=None):
        # cv2 reads images in BGR order; matplotlib displays RGB, so convert.
        self.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Attribute name (typo included) kept for backward compatibility.
        self.orginal_size = image.shape
        # Flatten to one row per pixel, one column per channel, so KMeans
        # treats every pixel as a 3-D sample.
        self.pixel_array = self.image.reshape((-1, 3))
        self.dom_colors = dom_colors
        self.km = KMeans(n_clusters=self.dom_colors, random_state=random_state)
        self.km.fit(self.pixel_array)

    @staticmethod
    def _centers_to_uint8(centers):
        """Round float cluster centres to the nearest valid 8-bit colour."""
        # A bare uint8 cast truncates (127.9 -> 127) and does not clip;
        # round first and clamp to [0, 255].
        return np.clip(np.rint(centers), 0, 255).astype(np.uint8)

    @staticmethod
    def _recolor(centers, labels, shape):
        """Map each pixel label to its centre colour and restore ``shape``."""
        # Fancy indexing replaces the per-pixel Python loop.
        return centers[labels].reshape(shape)

    def dominant_colors(self):
        """Display each cluster centre as a 100x100 solid colour swatch.

        Sets ``self.centers`` (uint8 array, one row per centre) and
        ``self.colors`` (list of those rows), then shows all swatches side by
        side in one figure.
        """
        self.centers = self._centers_to_uint8(self.km.cluster_centers_)
        self.colors = []
        plt.figure(0, (4, 4))
        n_swatches = len(self.centers)
        for position, center in enumerate(self.centers, start=1):
            # Exactly one column per centre, so no empty panel is left over.
            plt.subplot(1, n_swatches, position)
            # Applied per subplot so each swatch is drawn without axes.
            plt.axis("off")
            self.colors.append(center)
            swatch = np.full((100, 100, 3), center, dtype=np.uint8)
            plt.imshow(swatch)
        plt.show()

    def draw_image(self):
        """Redraw the image using only the cluster-centre colours and show it.

        Each pixel is replaced by the centre that ``self.km.predict`` assigns
        it to. Sets ``self.centers``; returns ``None``.
        """
        self.centers = self._centers_to_uint8(self.km.cluster_centers_)
        # Index of the nearest centre for each pixel, in pixel_array order.
        labels = self.km.predict(self.pixel_array)
        new_image = self._recolor(self.centers, labels, self.orginal_size)
        plt.axis("off")
        plt.title("No. of colors : " + str(self.dom_colors))
        plt.imshow(new_image)
        plt.show()
if __name__ == "__main__":
    img = cv2.imread("Mickey.jpg")
    # cv2.imread returns None instead of raising when the file is missing,
    # unreadable or in an unsupported format; fail with a clear message.
    if img is None:
        raise FileNotFoundError("Could not read image file 'Mickey.jpg'")
    IS = Segmentation(img, 5)