-
Notifications
You must be signed in to change notification settings - Fork 3
/
cbs.py
152 lines (131 loc) · 5.23 KB
/
cbs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# -*-coding:utf-8-*-
# Code taken from: https://github.com/jeremy9959/cbs
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import logging
# Root logger (no name is passed to getLogger); the functions in this module
# report their progress through it.
log = logging.getLogger()
# Configure root logging with a WARNING threshold.  basicConfig does nothing
# if the root logger already has handlers installed.
logging.basicConfig(level=logging.WARN)
def cbs_stat(x):
    '''Locate the sub-interval of x with the maximal segmentation statistic.

    The data are centred on their mean and cumulatively summed; the positions
    of the smallest and the largest partial sum delimit the candidate interval.

    Returns (t, i0, i1): the statistic t and the bounds of the slice x[i0:i1].
    '''
    centred = x - np.mean(x)
    n = len(centred)
    partial_sums = np.cumsum(centred)
    # Order the two extreme positions so that lo <= hi.
    lo, hi = sorted((np.argmin(partial_sums), np.argmax(partial_sums)))
    jump = partial_sums[hi] - partial_sums[lo]
    inside = hi - lo + 1
    outside = n + 1 - hi + lo
    t = jump ** 2 * n / inside / outside
    return t, lo, hi + 1
def tstat(x, i):
    '''Return the segmentation statistic t for a single split of x at index i.

    The statistic is the squared difference between the mean of x[:i] and the
    mean of x[i:], weighted by (n - i) * i / n where n is len(x).
    '''
    total = len(x)
    mean_before = np.mean(x[:i])
    mean_after = np.mean(x[i:])
    weight = (total - i) * i / total
    return weight * (mean_before - mean_after) ** 2
def cbs(x, shuffles=1000, p=.05):
    '''Find the interval x[i0:i1] with maximal segmentation statistic t and test it by permutation.

    The statistic of the observed data is compared with the statistic of up to
    `shuffles` random shufflings of a copy of x (x itself is not modified).
    The interval is declared not significant as soon as more than
    shuffles * p shufflings reach or exceed the observed statistic.

    Returns (significant, t, i0, i1) where significant is True/False.
    '''
    max_t, max_start, max_end = cbs_stat(x)
    n = len(x)
    if max_end - max_start == n:
        # The candidate interval is the whole input: nothing to split.
        return False, max_t, max_start, max_end

    # Snap interval ends lying within 5 points of an edge onto that edge.
    if max_start < 5:
        max_start = 0
    if n - max_end < 5:
        max_end = n

    allowed = shuffles * p
    exceed_count = 0
    shuffled = x.copy()
    for _ in range(shuffles):
        np.random.shuffle(shuffled)
        shuffled_t = cbs_stat(shuffled)[0]
        if shuffled_t >= max_t:
            exceed_count += 1
        if exceed_count > allowed:
            return False, max_t, max_start, max_end
    return True, max_t, max_start, max_end
def rsegment(x, start, end, L=None, shuffles=1000, p=.05):
    '''Recursively segment the interval x[start:end].

    x        -- the data being segmented
    start    -- first index of the interval to examine
    end      -- one past the last index of the interval to examine
    L        -- list that accepted segments (i, j) are appended to; a fresh
                list is created when omitted
    shuffles -- number of permutations used by each significance test
    p        -- significance level of each test

    Returns L, the list of pairs (i, j) where each (i, j) is a significant segment.
    '''
    # A literal [] default would be shared by every call that omits L.
    if L is None:
        L = []
    significant, t, s, e = cbs(x[start:end], shuffles=shuffles, p=p)
    log.info(
        'Proposed partition of {} to {} from {} to {} with t value {} is {}'.format(start, end, start + s, start + e, t,
                                                                                     significant))
    # Stop splitting when the proposal is not significant, is shorter than
    # 5 points, or spans the whole interval.
    if not significant or e - s < 5 or e - s == end - start:
        L.append((start, end))
        return L

    # Recurse left-to-right so L stays ordered by position.  The caller's
    # shuffles/p are forwarded so every level uses the same test settings.
    if s > 0:
        rsegment(x, start, start + s, L, shuffles=shuffles, p=p)
    # The middle piece is never empty here: e - s >= 5 was checked above.
    rsegment(x, start + s, start + e, L, shuffles=shuffles, p=p)
    if start + e < end:
        rsegment(x, start + e, end, L, shuffles=shuffles, p=p)
    return L
def segment(x, shuffles=1000, p=.05):
    '''Segment the whole array x.

    Runs the recursive segmentation over x[0:len(x)], using a significance
    test based on `shuffles` rearrangements at significance level p, and
    returns the resulting list of (start, end) pairs.
    '''
    segments = []
    rsegment(x, 0, len(x), segments, shuffles=shuffles, p=p)
    return segments
def validate(x, L, shuffles=1000, p=.01):
    '''Re-test every breakpoint of a segmentation with a one-sided permutation test.

    x        -- the data that was segmented
    L        -- list of (start, end) segments, ordered by position
    shuffles -- maximum number of permutations per breakpoint
    p        -- significance level; a breakpoint is rejected once
                p * shuffles permutations give a larger statistic

    Each interior boundary is tested on the stretch of data running from the
    last accepted boundary to the next candidate boundary.

    Returns the list of validated boundaries: 0, every accepted breakpoint,
    and len(x).
    '''
    # Candidate boundaries: the start of every segment plus the end of the data.
    S = [seg[0] for seg in L] + [len(x)]
    SV = [0]
    max_exceed = p * shuffles
    # A rejected breakpoint merges its two neighbouring segments, so the next
    # test has to start from the last *accepted* boundary, not merely from the
    # previous candidate.
    left_bound = S[0]
    for idx in range(1, len(S) - 1):
        candidate, right_bound = S[idx], S[idx + 1]
        window = x[left_bound:right_bound]
        site = candidate - left_bound
        t = tstat(window, site)
        log.info('Testing validity of {} in interval from {} to {} yields statistic {}'.format(candidate, left_bound,
                                                                                                right_bound, t))
        exceed_count = 0
        accepted = True
        shuffled = window.copy()
        for _ in range(shuffles):
            np.random.shuffle(shuffled)
            if tstat(shuffled, site) > t:
                exceed_count += 1
            if exceed_count >= max_exceed:
                accepted = False
                log.info('Breakpoint {} rejected'.format(candidate))
                break
        if accepted:
            log.info('Breakpoint {} accepted'.format(candidate))
            SV.append(candidate)
            left_bound = candidate
    SV.append(S[-1])
    return SV
def generate_normal_time_series(num, minl=50, maxl=1000):
    '''Generate a synthetic time series made of num normally distributed segments.

    Segment lengths are drawn with np.random.randint(minl, maxl, num).  Each
    segment is filled by np.random.normal with a mean equal to 10 times a
    standard-normal draw and a scale equal to the absolute value of a
    standard-normal draw.

    Returns the concatenated segments as a one-dimensional float64 array.
    '''
    lengths = np.random.randint(minl, maxl, num)
    # Start from an empty float64 array so that num == 0 still yields an array.
    chunks = [np.array([], dtype=np.float64)]
    for length in lengths:
        centre = np.random.randn() * 10
        spread = np.random.randn() * 1
        # The scale passed to np.random.normal must not be negative.
        spread = spread * -1 if spread < 0 else spread
        chunks.append(np.random.normal(centre, spread, length))
    return np.concatenate(chunks)
def draw_segmented_data(data, S, title=None):
    '''Draw a scatterplot of the data with the segmentation overlaid.

    data  -- the series to plot, one point per index
    S     -- list of segment boundaries; a vertical line is drawn at each
             boundary and a green horizontal line at the mean of the data
             between each pair of consecutive boundaries
    title -- optional title for the axes

    Returns the matplotlib axes the plot was drawn on.
    '''
    # x and y are passed by keyword: current seaborn releases no longer accept
    # them positionally.
    ax = sns.scatterplot(x=np.arange(len(data)), y=data, color='black', size=.1, legend=False)
    for boundary in S:
        ax.axvline(boundary)
    for left, right in zip(S[:-1], S[1:]):
        ax.hlines(np.mean(data[left:right]), left, right, color='green')
    ax.set_title(title)
    ax.get_figure().set_size_inches(16, 4)
    return ax
if __name__ == '__main__':
    # Demo run: segment a random 5-segment series, validate the breakpoints
    # and save the annotated plot to plot.png.
    log.setLevel(logging.INFO)
    series = generate_normal_time_series(5)
    boundaries = validate(series, segment(series))
    figure_axes = draw_segmented_data(series, boundaries, title='Circular Binary Segmentation of Data')
    figure_axes.get_figure().savefig('plot.png')