-
Notifications
You must be signed in to change notification settings - Fork 0
/
Mean_Shift.py
108 lines (98 loc) · 4.26 KB
/
Mean_Shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
The algorithm process of MS(Mean Shift)
"""
import numpy
from Algo1 import Algo1
from Algo2 import Algo2
import numpy as np
import math
import sys
class MeanShift(object):
    """Mean Shift over 2-D points with a Gaussian kernel.

    Data loading, the distance metric and the mean/variance computation are
    delegated to an algorithm helper (``Algo1`` or ``Algo2``) stored in
    ``self.alg``.
    """

    def __init__(self, data_path, algo_name, bandwidth, threshold):
        """Store the configuration and instantiate the algorithm helper.

        Args:
            data_path: Passed unchanged to the helper's data-loading methods.
            algo_name: ``'Algo1'`` or ``'Algo2'``; selects the helper class.
            bandwidth: Kernel bandwidth. ``None`` means it is estimated from
                the data when ``ms_process`` runs.
            threshold: Convergence threshold on the distance a point moves
                in one iteration.

        Raises:
            ValueError: If ``algo_name`` is not one of the supported names.
        """
        self.data_path = data_path
        self.algo_name = algo_name
        if self.algo_name == 'Algo1':
            self.alg = Algo1()
        elif self.algo_name == 'Algo2':
            self.alg = Algo2()
        else:
            # Fail at construction instead of leaving `self.alg` undefined
            # and hitting an AttributeError later in `ms_process`.
            raise ValueError(
                "unsupported algo_name %r (expected 'Algo1' or 'Algo2')"
                % (algo_name,)
            )
        self.bandwidth = bandwidth
        self.threshold = threshold

    def ms_process(self):
        """Run mean shift until no point moves more than ``self.threshold``.

        Returns:
            A tuple ``(theta_raw_points, original_points, shifting_points)``
            of numpy arrays: the raw data used for plotting, the data the
            kernel is evaluated on, and the shifted positions of each point.
        """
        # Read the data through the helper.
        # For `Algo1`: the processed data. For `Algo2`: the original data.
        data = self.alg.get_data(self.data_path)
        original_points = np.array(data)
        # NOTE(review): if `data` holds integers, numpy picks an integer
        # dtype and the row assignment below truncates the shifted
        # coordinates -- confirm that `get_data` returns floats.
        shifting_points = np.array(data)

        # The circular-linear raw data is needed when plotting, so it is
        # returned as well. For `Algo2` it is simply the loaded data; for
        # `Algo1` it is read separately via `_get_original_data`.
        if self.algo_name == 'Algo1':
            theta_raw_points = np.array(
                self.alg._get_original_data(self.data_path)
            )
        else:
            theta_raw_points = np.array(data)

        # Start above the threshold so the loop body runs at least once.
        max_distance = self.threshold + 1
        # end_flag[i] is True once point i has converged.
        end_flag = [False] * original_points.shape[0]
        iteration_times = 0

        # If no bandwidth was given, derive one from the data.
        if self.bandwidth is None:
            self.bandwidth = self._compute_bandwidth(original_points)

        while max_distance > self.threshold:
            iteration_times += 1
            print("iteration times =", iteration_times, ",",
                  "max_distance =", max_distance)
            # Reset once per sweep. Resetting per point would make the
            # loop condition depend only on the last point and stop the
            # iteration while other points are still moving.
            max_distance = 0
            for i in range(len(original_points)):
                # The i-th point has already converged.
                if end_flag[i]:
                    continue
                p_old = shifting_points[i]
                p_new = self._shift_point(
                    p_old, original_points, self.bandwidth
                )
                old_new_distance = self.alg.calculate_distance(p_new, p_old)
                # Track the largest displacement of this sweep.
                if old_new_distance > max_distance:
                    max_distance = old_new_distance
                if old_new_distance < self.threshold:
                    end_flag[i] = True
                shifting_points[i] = p_new
        return theta_raw_points, original_points, shifting_points

    def _compute_bandwidth(self, points):
        """Estimate a bandwidth as ``1.05 * std * N ** -0.2``.

        ``std`` is the square root of the variance returned by the helper's
        ``calculate_mean_point`` and ``N`` is the number of points.
        """
        point_count = len(points)
        # Only the variance is needed; the mean is discarded.
        _, var = self.alg.calculate_mean_point(points)
        data_std = math.sqrt(var)
        bandwidth = (1.05 * data_std) * (pow(point_count, -0.2))
        print("set bandwidth=", bandwidth)
        return bandwidth

    def _shift_point(self, p_old, points, bandwidth):
        """Return the kernel-weighted mean of ``points`` around ``p_old``.

        Args:
            p_old: Current position, indexable as ``[x, y]``.
            points: Iterable of ``[x, y]`` points the mean is taken over.
            bandwidth: Kernel bandwidth.

        Returns:
            ``[x, y]`` of the shifted point. If the total weight is zero
            (no points, or every weight underflowed), the position is
            returned unchanged.
        """
        weights = self._kernel(p_old, points, bandwidth)
        p_new_x = float(0)
        p_new_y = float(0)
        weights_num = float(0)
        for p_tmp, weight in zip(points, weights):
            p_new_x += p_tmp[0] * weight
            p_new_y += p_tmp[1] * weight
            weights_num += weight
        if weights_num == 0:
            # Nothing within reach of the kernel: dividing would raise or
            # produce NaN, so leave the point where it is.
            return [p_old[0], p_old[1]]
        p_new_x /= weights_num
        p_new_y /= weights_num
        return [p_new_x, p_new_y]

    def _kernel(self, point, point_set, bandwidth):
        """Return the Gaussian kernel weight of every point in ``point_set``.

        Each weight is ``exp(-0.5 * d**2 / h**2) / (h * sqrt(2 * pi))``
        where ``d`` is the helper's distance between ``point`` and the set
        member and ``h`` is ``bandwidth``.
        """
        # The normalisation constant is the same for every point.
        scale = 1 / (bandwidth * math.sqrt(2 * math.pi))
        weights = []
        for p_tmp in point_set:
            distance = self.alg.calculate_distance(point, p_tmp)
            norm = (distance ** 2) / (bandwidth ** 2)
            weights.append(scale * math.exp(-0.5 * norm))
        return weights