-
Notifications
You must be signed in to change notification settings - Fork 84
/
generators.py
146 lines (121 loc) · 4.29 KB
/
generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python
import bisect
import functools
import random

from range_map import RangeMap
class SegmentTreeGenerator(object):
    """
    Weighted random sampler backed by a RangeMap.

    Elements are drawn with a probability proportional to their weight in the
    population in O(log(N)) time. Uses up to 2N space but supports
    O(log(N)) weight updates.
    """

    def __init__(self, population):
        """
        Build the underlying RangeMap from a weight mapping.

        :type population: Dict[T, double]
        """
        # Every weight is coerced to float before being handed to RangeMap.
        pairs = [(key, float(population[key])) for key in population.keys()]
        self._range_map = RangeMap(pairs)

    def random(self):
        """
        Draw one element with probability proportional to its weight.
        O(log(N)) time.

        :rtype: T
        """
        # Scale a uniform [0, 1) sample by the map's range() (presumably the
        # total weight -- confirm against RangeMap) and look up its owner.
        return self._range_map.get(random.random() * self._range_map.range())

    def update(self, value, weight):
        """
        Replace the weight of ``value`` in the population. O(log(N)) time.
        """
        self._range_map.update(value, weight)
class ArrayGenerator(object):
    """
    Returns random elements with a probability proportional to the frequency
    distribution of each element in the population in O(log(N)) time. Uses
    N space but weight updates take O(N) time.

    Each element owns the half-open interval [offset, offset + weight) of
    [0, total); random() binary-searches the sorted start offsets. If every
    weight is zero, the last element is returned.
    """

    def __init__(self, population):
        """
        :type population: Dict[T, double]
        :raises ValueError: if the population is empty or any weight is < 0.
        """
        if not population:
            raise ValueError('population must not be empty')
        self._data = list()
        self._index = dict()
        # Start offsets parallel to self._data. They are non-decreasing
        # because weights are non-negative, which is what makes bisect valid.
        self._offsets = list()
        self._total = float(sum(population.values()))
        offset = 0
        for value, weight in population.items():
            if weight < 0:
                raise ValueError('weights must be >= 0')
            self._index[value] = len(self._data)
            self._data.append(ArrayGenerator.Entry(value, offset, weight))
            self._offsets.append(offset)
            offset += weight

    def random(self):
        """
        Returns an element from the original population with probability
        proportional to its relative frequency. O(log(N)) time.
        :rtype: T
        """
        i = random.random() * self._total
        # Rightmost entry whose start offset is <= i. offsets[0] is always 0
        # (updates never shift index 0); the clamp keeps index 0 if float
        # drift ever makes the total, and hence i, slightly negative.
        j = max(bisect.bisect_right(self._offsets, i) - 1, 0)
        return self._data[j].value

    def update(self, value, weight):
        """
        Updates the weight of the value in the population. O(N) time.

        :raises KeyError: if value is not part of the population.
        :raises ValueError: if weight is negative.
        """
        # Validate with real exceptions: asserts vanish under `python -O`,
        # and a negative weight would break the sorted-offsets invariant.
        if value not in self._index:
            raise KeyError(value)
        if weight < 0:
            raise ValueError('weights must be >= 0')
        i = self._index[value]
        delta = weight - self._data[i].weight
        self._data[i].weight = weight
        # Every entry after i shifts by the same delta, so both the Entry
        # objects and the parallel offsets list stay sorted.
        for k in range(i + 1, len(self._data)):
            self._data[k].offset += delta
            self._offsets[k] += delta
        self._total += delta

    class Entry(object):
        """One population element with its start offset and weight."""

        def __init__(self, value, offset, weight):
            self.value = value
            self.offset = offset
            self.weight = weight
class DictGenerator(object):
    """
    Returns random elements with a probability proportional to the frequency
    distribution of each element in the population in O(N) time but weight
    updates take O(1) time.
    """

    def __init__(self, population):
        """
        :type population: Dict[T, double]
        :raises ValueError: if any weight is negative.
        """
        self._population = dict(population)
        # Sum left to right, the same order random() accumulates its prefix
        # in, so the initial total equals the final prefix exactly.
        total = 0
        for weight in self._population.values():
            if weight < 0:
                raise ValueError('weights must be >= 0')
            total += weight
        self._total = total

    def random(self):
        """
        Returns an element from the original population with probability
        proportional to its relative frequency. O(N) time.
        :rtype: T
        :raises RuntimeError: if the population is empty.
        """
        i = random.random() * self._total
        missing = object()  # sentinel: keys may be any hashable, even None
        last_positive = missing
        prefix = 0
        for key, weight in self._population.items():
            prefix += weight
            if prefix >= i:
                return key
            if weight > 0:
                last_positive = key
        # Reaching here means i exceeded the true sum of weights. update()
        # maintains self._total incrementally, so rounding error can drift it
        # above that sum; fall back instead of failing.
        if last_positive is not missing:
            return last_positive
        if self._population:
            # All weights are zero: mirror the drift-free case, where the
            # first key satisfies prefix 0 >= i 0.
            return next(iter(self._population))
        raise RuntimeError('This should never happen...')

    def update(self, value, weight):
        """
        Updates the weight of the value in the population. O(1) time.

        :raises KeyError: if value is not part of the population.
        :raises ValueError: if weight is negative.
        """
        # Real exceptions rather than asserts, which vanish under `python -O`.
        if value not in self._population:
            raise KeyError(value)
        if weight < 0:
            raise ValueError('weights must be >= 0')
        self._total += weight - self._population[value]
        self._population[value] = weight