# MYOPTICS.py
import numpy as np
import matplotlib.pyplot as plt
import time
import operator
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
def compute_squared_EDM(X):
    """Return the dense pairwise Euclidean distance matrix for the rows of X.

    Note: despite the function name, the distances are plain Euclidean
    distances, not squared ones.

    Args:
        X: 2-D array-like with one observation per row.

    Returns:
        A symmetric ``(n, n)`` array whose ``[i, j]`` entry is the Euclidean
        distance between row ``i`` and row ``j``.
    """
    # pdist yields the condensed (upper-triangle) form; squareform expands it.
    condensed = pdist(X, metric='euclidean')
    return squareform(condensed)
def plotReachability(data,eps):
    """Show the reachability plot together with a horizontal threshold line.

    Args:
        data: Sequence of reachability distances, already arranged in the
            order in which they should be drawn.
        eps: Height at which the horizontal threshold line is drawn.
    """
    count = len(data)
    plt.figure()
    # Reachability values against their position in the sequence.
    plt.plot(range(count), data)
    # Horizontal line at the eps threshold.
    plt.plot([0, count], [eps, eps])
    plt.show()
def plotFeature(data,labels):
    """Scatter-plot the first two columns of ``data``, coloured by label.

    Every distinct value found in ``labels`` gets its own scatter call, so
    label sets that are not contiguous (for example ``{-1, 0, 5}``) are drawn
    in full. The colour is ``scatterColors[label % len(scatterColors)]``;
    with the eight-entry palette below, label ``-1`` maps to the last entry
    and labels that differ by a multiple of eight share a colour.

    Args:
        data: 2-D NumPy array with at least two columns; columns 0 and 1 are
            used as the x and y coordinates.
        labels: One integer label per row of ``data`` (array or sequence).
    """
    # Accept plain sequences too; the boolean mask below needs an ndarray.
    labels = np.asarray(labels)
    scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown']
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Iterate over the labels that actually occur, in ascending order.
    for label in sorted(set(labels.tolist())):
        colorSytle = scatterColors[int(label) % len(scatterColors)]
        subCluster = data[labels == label]
        ax.scatter(subCluster[:, 0], subCluster[:, 1], c=colorSytle, s=12)
    plt.show()
def updateSeeds(seeds,core_PointId,neighbours,core_dists,reach_dists,disMat,isProcess):
    """Relax the reachability of the unprocessed neighbours of a core point.

    For each neighbour still marked ``-1`` in ``isProcess``, the candidate
    reachability is ``max(core distance of core_PointId, distance to it)``.
    The candidate is stored when the neighbour has no reachability yet (NaN)
    or when it is smaller than the stored value.

    Args:
        seeds: Dict mapping point index -> current reachability; updated in
            place.
        core_PointId: Index of the core point being expanded.
        neighbours: Iterable of candidate neighbour indices.
        core_dists: Per-point core distances, indexed by point index.
        reach_dists: Per-point reachability distances (NaN = undefined);
            updated in place.
        disMat: Pairwise distance matrix.
        isProcess: Per-point flags; ``-1`` means "not processed yet".

    Returns:
        The same ``seeds`` dict that was passed in.
    """
    core_dist = core_dists[core_PointId]
    dist_row = disMat[core_PointId]
    for nb in neighbours:
        # Points that are already processed keep their reachability.
        if isProcess[nb] != -1:
            continue
        candidate = max(core_dist, dist_row[nb])
        known = reach_dists[nb]
        # First time reached, or reached more cheaply than before.
        if np.isnan(known) or candidate < known:
            reach_dists[nb] = candidate
            seeds[nb] = candidate
    return seeds
def OPTICS(data,eps=np.inf,minPts=15):
    """Compute an OPTICS cluster ordering and the reachability distances.

    A point is a core point when at least ``minPts`` points (itself included)
    lie within ``eps`` of it; its core distance is then the distance to its
    ``minPts``-th nearest point, counting itself as the first.

    Args:
        data: 2-D NumPy array with one observation per row.
        eps: Generating radius. The default ``np.inf`` puts every point in
            every neighbourhood.
        minPts: Minimum neighbourhood size (the point itself counts).

    Returns:
        A tuple ``(orders, reach_dists)``:

        * ``orders`` - list of point indices in processing order. Only points
          that are core points or are reached from one appear in it.
        * ``reach_dists`` - array of length ``n``, indexed by point index,
          holding each point's reachability distance. It stays NaN for the
          first point of every expansion and for points never reached.

    Raises:
        ValueError: If ``minPts`` is not in ``1..n``; no core distance can be
            defined in that case.
    """
    n = data.shape[0]
    if not 1 <= minPts <= n:
        raise ValueError(
            "minPts must be between 1 and the number of points (%d), got %r"
            % (n, minPts))

    orders = []
    # Pairwise distance matrix.
    disMat = compute_squared_EDM(data)
    # Distance from every point to its minPts-th nearest point (the point
    # itself is the nearest, at distance 0). A partial sort is enough to get
    # the k-th smallest value of each row.
    temp_core_distances = np.partition(disMat, minPts - 1, axis=1)[:, minPts - 1]
    # Core distance, or -1 for points that are not core points within eps.
    core_dists = np.where(temp_core_distances <= eps, temp_core_distances, -1)
    # Reachability starts out undefined for every point.
    reach_dists = np.full((n,), np.nan)
    # A point has >= minPts points within eps exactly when its minPts-th
    # smallest distance is <= eps.
    core_points_index = np.where(temp_core_distances <= eps)[0]
    # Processing flags: -1 = not processed yet, 1 = processed.
    isProcess = np.full((n,), -1)

    for pointId in core_points_index:
        # Skip core points already consumed by an earlier expansion.
        if isProcess[pointId] != -1:
            continue
        isProcess[pointId] = 1
        orders.append(pointId)
        # Unprocessed points in the eps-neighbourhood of the start point.
        # The start point is excluded by its processed flag, so exact
        # duplicates of it (distance 0) are still seeded.
        neighbours = np.where((disMat[:, pointId] <= eps) & (isProcess == -1))[0]
        seeds = updateSeeds({}, pointId, neighbours, core_dists, reach_dists, disMat, isProcess)
        while seeds:
            # Next point is the seed with the smallest reachability; ties go
            # to the seed that was inserted first.
            nextId = min(seeds, key=seeds.get)
            del seeds[nextId]
            isProcess[nextId] = 1
            orders.append(nextId)
            # Full eps-neighbourhood of the new point (itself included). It is
            # deliberately not restricted to unprocessed points: doing so
            # could make a genuine core point look like a non-core point.
            queryResults = np.where(disMat[:, nextId] <= eps)[0]
            if len(queryResults) >= minPts:
                seeds = updateSeeds(seeds, nextId, queryResults, core_dists, reach_dists, disMat, isProcess)
        # Seeds exhausted: this expansion is complete.
    return orders, reach_dists
def extract_dbscan(data,orders, reach_dists, eps):
    """Derive DBSCAN-style cluster labels from an OPTICS ordering.

    Positions in ``orders`` whose reachability is ``<= eps`` are grouped into
    maximal consecutive runs; each run becomes one cluster, numbered from 0.
    The point directly before a run is added to that run's cluster: in an
    OPTICS ordering it is the core point from which the run's first member
    was reached, and its own reachability is undefined or above ``eps``, so
    it would otherwise always be reported as noise.

    Args:
        data: 2-D array of observations; only its row count is used.
        orders: Sequence of point indices in OPTICS processing order.
        reach_dists: Reachability distance per point, indexed by point index
            (NaN = undefined).
        eps: Reachability threshold used for the extraction.

    Returns:
        Integer array of length ``n`` with a cluster id per point, ``-1`` for
        noise. All entries are ``-1`` when no point is reachable within
        ``eps``.
    """
    n = data.shape[0]
    labels = np.full((n,), -1)
    order_idx = np.asarray(orders, dtype=int)
    # Positions (in the ordering) whose reachability is within eps. NaN
    # compares False, so undefined reachabilities are never selected.
    reach_distIds = np.where(np.asarray(reach_dists)[order_idx] <= eps)[0]
    if reach_distIds.size == 0:
        # Nothing is reachable within eps: everything is noise.
        return labels

    clusterId = -1
    pre = None
    for current in reach_distIds:
        # A gap in the selected positions means a new cluster begins.
        if pre is None or current - pre != 1:
            clusterId = clusterId + 1
            if current > 0:
                # The predecessor in the ordering starts this cluster.
                labels[order_idx[current - 1]] = clusterId
        labels[order_idx[current]] = clusterId
        pre = current
    return labels
# Demo run: cluster the points in data/cluster.csv and plot the results.
data = np.loadtxt("data/cluster.csv", delimiter=",")
# time.clock() was removed in Python 3.8; perf_counter() is the replacement
# for measuring elapsed time.
start = time.perf_counter()
orders, reach_dists = OPTICS(data, np.inf, 30)
end = time.perf_counter()
print('finish all in %s' % str(end - start))
# Extract clusters at reachability threshold 3, then show both plots.
labels = extract_dbscan(data, orders, reach_dists, 3)
plotReachability(reach_dists[orders], 3)
plotFeature(data, labels)