-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdither_shr.pyx
431 lines (373 loc) · 19.7 KB
/
dither_shr.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
# cython: infer_types=True
# cython: profile=False
# cython: boundscheck=False
# cython: wraparound=False
cimport cython
import colour
import numpy as np
cimport common
def dither_shr_perfect(
        float[:, :, ::1] input_rgb, float[:, ::1] full_palette_cam, float[:, ::1] full_palette_rgb,
        float[:,::1] rgb_to_cam16ucs):
    """Dither the 320x200 image against a single palette shared by every line.

    Each pixel is replaced by the palette entry with the smallest
    common.colour_distance_squared to the pixel's CAM16UCS coordinates, and the
    per-channel RGB quantisation error is diffused to neighbouring pixels.

    Args:
        input_rgb: image indexed as [y, x, channel].  Must be at least
            200x320x3; only the top-left 200x320 region is processed.  The
            input is not modified, a copy is dithered.
        full_palette_cam: one row per palette colour, compared against the
            converted pixel.  Rows correspond to the rows of full_palette_rgb.
        full_palette_rgb: one RGB row per palette colour; the chosen row is
            written into the output image.
        rgb_to_cam16ucs: passed through unchanged to
            common.convert_rgb_to_cam16ucs.

    Returns:
        A tuple (total_image_error, working_image): the sum over all pixels of
        the squared distance to the selected palette entry, and the dithered
        copy of the image.

    Raises:
        ValueError: if the image or palettes are too small for the fixed
            200x320x3 processing extents.  Bounds checking is disabled for
            this module, so undersized inputs would otherwise be read and
            written out of bounds.
    """
    cdef int y, x, idx, best_colour_idx, i, j
    cdef double best_distance, distance, total_image_error
    cdef float[::1] best_colour_rgb
    cdef float quant_error
    cdef float[:, :, ::1] working_image = np.copy(input_rgb)
    cdef int palette_size = full_palette_rgb.shape[0]
    # Error pushed down to the next line is scaled by this factor (squared for two lines down).
    cdef float decay = 0.5
    # Selects the diffusion kernel.  Fixed at 1, so the Jarvis branch below is currently never taken.
    cdef int floyd_steinberg = 1
    cdef common.float3 cam, pixel_cam

    if input_rgb.shape[0] < 200 or input_rgb.shape[1] < 320 or input_rgb.shape[2] < 3:
        raise ValueError("input_rgb must be at least 200x320x3")
    if palette_size < 1 or full_palette_rgb.shape[1] < 3:
        raise ValueError("full_palette_rgb must contain at least one RGB row")
    if full_palette_cam.shape[0] < palette_size or full_palette_cam.shape[1] < 3:
        raise ValueError("full_palette_cam must have a 3-component row for every row of full_palette_rgb")

    total_image_error = 0.0
    for y in range(200):
        for x in range(320):
            # Convert at the point of use: earlier pixels have already diffused error into this one.
            pixel_cam = common.convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])

            best_distance = 1e9
            # Start from a valid index so that the palette is never indexed with -1 (wraparound is disabled).
            best_colour_idx = 0
            for idx in range(palette_size):
                for j in range(3):
                    cam.data[j] = full_palette_cam[idx, j]
                distance = common.colour_distance_squared(pixel_cam.data, cam.data)
                if distance < best_distance:
                    best_distance = distance
                    best_colour_idx = idx
            best_colour_rgb = full_palette_rgb[best_colour_idx, :]
            total_image_error += best_distance

            for i in range(3):
                quant_error = working_image[y, x, i] - best_colour_rgb[i]
                working_image[y, x, i] = best_colour_rgb[i]
                # Weights are written as float literals so they do not depend on integer-division semantics.
                if floyd_steinberg:
                    # Floyd-Steinberg dither
                    #  0 * 7
                    #  3 5 1
                    if x < 319:
                        working_image[y, x + 1, i] = common.clip(
                            working_image[y, x + 1, i] + quant_error * (7.0 / 16.0), 0, 1)
                    if y < 199:
                        if x > 0:
                            working_image[y + 1, x - 1, i] = common.clip(
                                working_image[y + 1, x - 1, i] + decay * quant_error * (3.0 / 16.0), 0, 1)
                        working_image[y + 1, x, i] = common.clip(
                            working_image[y + 1, x, i] + decay * quant_error * (5.0 / 16.0), 0, 1)
                        if x < 319:
                            working_image[y + 1, x + 1, i] = common.clip(
                                working_image[y + 1, x + 1, i] + decay * quant_error * (1.0 / 16.0), 0, 1)
                else:
                    # Jarvis
                    #  0 0 X 7 5
                    #  3 5 7 5 3
                    #  1 3 5 3 1
                    if x < 319:
                        working_image[y, x + 1, i] = common.clip(
                            working_image[y, x + 1, i] + quant_error * (7.0 / 48.0), 0, 1)
                    if x < 318:
                        working_image[y, x + 2, i] = common.clip(
                            working_image[y, x + 2, i] + quant_error * (5.0 / 48.0), 0, 1)
                    if y < 199:
                        if x > 1:
                            working_image[y + 1, x - 2, i] = common.clip(
                                working_image[y + 1, x - 2, i] + decay * quant_error * (3.0 / 48.0), 0, 1)
                        if x > 0:
                            working_image[y + 1, x - 1, i] = common.clip(
                                working_image[y + 1, x - 1, i] + decay * quant_error * (5.0 / 48.0), 0, 1)
                        working_image[y + 1, x, i] = common.clip(
                            working_image[y + 1, x, i] + decay * quant_error * (7.0 / 48.0), 0, 1)
                        if x < 319:
                            working_image[y + 1, x + 1, i] = common.clip(
                                working_image[y + 1, x + 1, i] + decay * quant_error * (5.0 / 48.0), 0, 1)
                        if x < 318:
                            working_image[y + 1, x + 2, i] = common.clip(
                                working_image[y + 1, x + 2, i] + decay * quant_error * (3.0 / 48.0), 0, 1)
                    if y < 198:
                        if x > 1:
                            working_image[y + 2, x - 2, i] = common.clip(
                                working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1.0 / 48.0), 0, 1)
                        if x > 0:
                            working_image[y + 2, x - 1, i] = common.clip(
                                working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3.0 / 48.0), 0, 1)
                        working_image[y + 2, x, i] = common.clip(
                            working_image[y + 2, x, i] + decay * decay * quant_error * (5.0 / 48.0), 0, 1)
                        if x < 319:
                            working_image[y + 2, x + 1, i] = common.clip(
                                working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3.0 / 48.0), 0, 1)
                        if x < 318:
                            working_image[y + 2, x + 2, i] = common.clip(
                                working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1.0 / 48.0), 0, 1)

    return total_image_error, working_image
def dither_shr(
        float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
        float[:,::1] rgb_to_cam16ucs):
    """Dither the 320x200 image using one of 16 palettes per line.

    For every line the palette with the lowest total error is chosen by
    best_palette_for_line, then each pixel on the line is matched to the
    nearest of that palette's 16 entries and the per-channel RGB quantisation
    error is diffused to neighbouring pixels.

    Args:
        input_rgb: image indexed as [y, x, channel].  Must be at least
            200x320x3; only the top-left 200x320 region is processed.  The
            input is not modified, a copy is dithered.
        palettes_cam: palettes indexed as [palette, entry, component], used
            for distance matching.  Must be at least 16x16x3.
        palettes_rgb: the same palettes as RGB values, indexed identically.
        rgb_to_cam16ucs: passed through unchanged to
            common.convert_rgb_to_cam16ucs.

    Returns:
        A tuple (output_4bit, line_to_palette, total_image_error,
        palette_line_errors):
          - output_4bit: 200x320 uint8 array holding the chosen palette entry
            index for each pixel.
          - line_to_palette: int32 buffer of length 200 holding the palette
            index chosen for each line.
          - total_image_error: sum over all pixels of the squared distance to
            the selected palette entry.
          - palette_line_errors: float64 array of length 200 holding the
            error reported by best_palette_for_line for each line.

    Raises:
        ValueError: if the image or palettes are smaller than the fixed
            processing extents, or if no palette could be selected for a
            line.  Bounds checking and wraparound are disabled for this
            module, so these cases would otherwise access memory out of
            bounds.
    """
    cdef int y, x, idx, best_colour_idx, best_palette, i, j
    cdef double best_distance, distance, total_image_error
    cdef float[::1] best_colour_rgb
    cdef float quant_error
    cdef float[:, ::1] palette_rgb, palette_cam

    cdef (unsigned char)[:, ::1] output_4bit = np.zeros((200, 320), dtype=np.uint8)
    cdef float[:, :, ::1] working_image = np.copy(input_rgb)
    cdef float[:, ::1] line_cam = np.zeros((320, 3), dtype=np.float32)
    cdef int[::1] line_to_palette = np.zeros(200, dtype=np.int32)
    cdef double[::1] palette_line_errors = np.zeros(200, dtype=np.float64)
    cdef PaletteSelection palette_line

    # Error pushed down to the next line is scaled by this factor (squared for two lines down).
    cdef float decay = 0.5
    # Selects the diffusion kernel.  Fixed at 1, so the Jarvis branch below is currently never taken.
    cdef int floyd_steinberg = 1
    cdef common.float3 pixel_cam, cam

    if input_rgb.shape[0] < 200 or input_rgb.shape[1] < 320 or input_rgb.shape[2] < 3:
        raise ValueError("input_rgb must be at least 200x320x3")
    if palettes_cam.shape[0] < 16 or palettes_cam.shape[1] < 16 or palettes_cam.shape[2] < 3:
        raise ValueError("palettes_cam must be at least 16x16x3")
    if palettes_rgb.shape[0] < 16 or palettes_rgb.shape[1] < 16 or palettes_rgb.shape[2] < 3:
        raise ValueError("palettes_rgb must be at least 16x16x3")

    best_palette = -1
    total_image_error = 0.0
    for y in range(200):
        # Snapshot the line (including error diffused from earlier lines) to choose its palette.
        for x in range(320):
            pixel_cam = common.convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])
            for j in range(3):
                line_cam[x, j] = pixel_cam.data[j]

        palette_line = best_palette_for_line(line_cam, palettes_cam, best_palette)
        best_palette = palette_line.palette_idx
        if best_palette < 0:
            # best_palette_for_line reports -1 when no palette beat its initial error bound.
            raise ValueError("no palette could be selected for line %d" % y)
        palette_line_errors[y] = palette_line.total_error
        palette_rgb = palettes_rgb[best_palette, :, :]
        palette_cam = palettes_cam[best_palette, :, :]
        line_to_palette[y] = best_palette

        for x in range(320):
            # Convert again at the point of use: pixels to the left have diffused error into this one.
            pixel_cam = common.convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])

            best_distance = 1e9
            # Start from a valid index so that neither the palette nor the uint8 output ever sees -1.
            best_colour_idx = 0
            for idx in range(16):
                for j in range(3):
                    cam.data[j] = palette_cam[idx, j]
                distance = common.colour_distance_squared(pixel_cam.data, cam.data)
                if distance < best_distance:
                    best_distance = distance
                    best_colour_idx = idx
            best_colour_rgb = palette_rgb[best_colour_idx]
            output_4bit[y, x] = best_colour_idx
            total_image_error += best_distance

            for i in range(3):
                quant_error = working_image[y, x, i] - best_colour_rgb[i]
                working_image[y, x, i] = best_colour_rgb[i]
                # Weights are written as float literals so they do not depend on integer-division semantics.
                if floyd_steinberg:
                    # Floyd-Steinberg dither
                    #  0 * 7
                    #  3 5 1
                    if x < 319:
                        working_image[y, x + 1, i] = common.clip(
                            working_image[y, x + 1, i] + quant_error * (7.0 / 16.0), 0, 1)
                    if y < 199:
                        if x > 0:
                            working_image[y + 1, x - 1, i] = common.clip(
                                working_image[y + 1, x - 1, i] + decay * quant_error * (3.0 / 16.0), 0, 1)
                        working_image[y + 1, x, i] = common.clip(
                            working_image[y + 1, x, i] + decay * quant_error * (5.0 / 16.0), 0, 1)
                        if x < 319:
                            working_image[y + 1, x + 1, i] = common.clip(
                                working_image[y + 1, x + 1, i] + decay * quant_error * (1.0 / 16.0), 0, 1)
                else:
                    # Jarvis
                    #  0 0 X 7 5
                    #  3 5 7 5 3
                    #  1 3 5 3 1
                    if x < 319:
                        working_image[y, x + 1, i] = common.clip(
                            working_image[y, x + 1, i] + quant_error * (7.0 / 48.0), 0, 1)
                    if x < 318:
                        working_image[y, x + 2, i] = common.clip(
                            working_image[y, x + 2, i] + quant_error * (5.0 / 48.0), 0, 1)
                    if y < 199:
                        if x > 1:
                            working_image[y + 1, x - 2, i] = common.clip(
                                working_image[y + 1, x - 2, i] + decay * quant_error * (3.0 / 48.0), 0, 1)
                        if x > 0:
                            working_image[y + 1, x - 1, i] = common.clip(
                                working_image[y + 1, x - 1, i] + decay * quant_error * (5.0 / 48.0), 0, 1)
                        working_image[y + 1, x, i] = common.clip(
                            working_image[y + 1, x, i] + decay * quant_error * (7.0 / 48.0), 0, 1)
                        if x < 319:
                            working_image[y + 1, x + 1, i] = common.clip(
                                working_image[y + 1, x + 1, i] + decay * quant_error * (5.0 / 48.0), 0, 1)
                        if x < 318:
                            working_image[y + 1, x + 2, i] = common.clip(
                                working_image[y + 1, x + 2, i] + decay * quant_error * (3.0 / 48.0), 0, 1)
                    if y < 198:
                        if x > 1:
                            working_image[y + 2, x - 2, i] = common.clip(
                                working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1.0 / 48.0), 0, 1)
                        if x > 0:
                            working_image[y + 2, x - 1, i] = common.clip(
                                working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3.0 / 48.0), 0, 1)
                        working_image[y + 2, x, i] = common.clip(
                            working_image[y + 2, x, i] + decay * decay * quant_error * (5.0 / 48.0), 0, 1)
                        if x < 319:
                            working_image[y + 2, x + 1, i] = common.clip(
                                working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3.0 / 48.0), 0, 1)
                        if x < 318:
                            working_image[y + 2, x + 2, i] = common.clip(
                                working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1.0 / 48.0), 0, 1)

    return (
        np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error,
        np.array(palette_line_errors, dtype=np.float64)
    )
# Result returned by best_palette_for_line: which palette was chosen for a line and how well it fits.
cdef struct PaletteSelection:
    # Index of the chosen palette; -1 if no palette's total error was below the initial bound of 1e9.
    int palette_idx
    # Sum over the line's pixels of the squared distance to the nearest entry of the chosen palette.
    double total_error
# Choose, from the first 16 palettes, the one whose entries best approximate a line of pixels.
#
# For each candidate palette the squared distance from every pixel of line_cam to its nearest of the 16 palette
# entries is summed; the palette with the smallest sum wins (the earliest one on ties).  If no sum is below 1e9
# the returned palette_idx is -1 and total_error is 1e9.
#
# last_palette_idx is accepted for the caller's convenience but is not consulted.
cdef PaletteSelection best_palette_for_line(
        float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx) nogil:
    cdef PaletteSelection selection
    cdef common.float3 pixel, entry
    cdef double line_error, nearest, candidate
    cdef int n_pixels = line_cam.shape[0]
    cdef int palette, slot, column, component

    # The result struct doubles as the running best.
    selection.palette_idx = -1
    selection.total_error = 1e9

    for palette in range(16):
        line_error = 0.0
        for column in range(n_pixels):
            for component in range(3):
                pixel.data[component] = line_cam[column, component]

            nearest = 1e9
            for slot in range(16):
                for component in range(3):
                    entry.data[component] = palettes_cam[palette, slot, component]
                candidate = common.colour_distance_squared(pixel.data, entry.data)
                if candidate < nearest:
                    nearest = candidate
            line_error += nearest

        if line_error < selection.total_error:
            selection.total_error = line_error
            selection.palette_idx = palette

    return selection
# Look up the CAM16UCS coordinates of a //gs RGB colour in a precomputed table.
#
# The three components of point_rgb12 are packed into a row index as (c0 << 8) | (c1 << 4) | c2, and that row of
# rgb12_iigs_to_cam16ucs is returned as a float3.  Components are not masked before packing.
cdef common.float3 _convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) nogil:
    cdef common.float3 cam
    cdef int component
    cdef int row = point_rgb12[2] | (point_rgb12[1] << 4) | (point_rgb12[0] << 8)

    for component in range(3):
        cam.data[component] = rgb12_iigs_to_cam16ucs[row, component]
    return cam
# Python-callable wrapper around _convert_rgb12_iigs_to_cam; cython code should call the cdef function directly to stay
# on the fast path.
def convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) -> float[::1]:
    """Return the CAM16UCS coordinates of a //gs RGB colour as a 3-element float32 buffer.

    Args:
        rgb12_iigs_to_cam16ucs: lookup table passed to _convert_rgb12_iigs_to_cam.
        point_rgb12: the three colour components to look up.
    """
    cdef common.float3 looked_up = _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, point_rgb12)
    cdef float[::1] out = np.empty(3, dtype=np.float32)

    # Copy the C struct into a buffer that can be handed back to python.
    out[0] = looked_up.data[0]
    out[1] = looked_up.data[1]
    out[2] = looked_up.data[2]
    return out
@cython.cdivision(True)
cdef float[:, ::1] linear_to_srgb_array(float[:, ::1] a, float gamma=2.4):
    # Apply the piecewise sRGB transfer curve element-wise to a 2-D array.
    #
    # Values <= 0.0031308 are scaled by 12.92; larger values map to 1.055 * v ** (1 / gamma) - 0.055.
    # A new float32 array of the same shape is returned; the input is not modified.
    #
    # Every column of the input is converted (not just the first three), so no element of the freshly
    # allocated result is left uninitialised and narrower inputs are not read out of bounds.
    cdef int i, j
    cdef int n_rows = a.shape[0]
    cdef int n_cols = a.shape[1]
    # Loop-invariant exponent, kept in double precision as in the inline expression it replaces.
    cdef double inv_gamma = 1.0 / gamma
    cdef float[:, ::1] res = np.empty_like(a, dtype=np.float32)

    for i in range(n_rows):
        for j in range(n_cols):
            if a[i, j] <= 0.0031308:
                res[i, j] = a[i, j] * 12.92
            else:
                res[i, j] = 1.055 * a[i, j] ** inv_gamma - 0.055
    return res
# TODO: optimize
cdef (unsigned char)[:, ::1] _convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
    # Convert rows of CAM16UCS coordinates to //gs RGB values, rounded to integers in 0..15 per channel.
    cdef float[:, ::1] rgb_linear
    cdef (float)[:, ::1] scaled_iigs

    # Convert CAM16UCS input to RGB.  Even though this dynamically constructs a path on the graph of colour conversions
    # every time, in practice this seems to have a negligible overhead compared to the actual conversion functions.
    with colour.utilities.suppress_warnings(python_warnings=True):
        rgb_linear = colour.convert(point_cam, "CAM16UCS", "RGB").astype(np.float32)

    # TODO: precompute this conversion matrix since it's static.  This accounts for about 10% of the CPU time here.

    # Gamma correct and convert Rec.709 R'G'B' to YCbCr
    ycbcr = colour.RGB_to_YCbCr(linear_to_srgb_array(rgb_linear), K=colour.WEIGHTS_YCBCR['ITU-R BT.709'])
    # Convert to Rec.601 R'G'B'
    rec601_rgb = colour.YCbCr_to_RGB(ycbcr, K=colour.WEIGHTS_YCBCR['ITU-R BT.601'])

    # Clamp to [0, 1], then scale so that each channel spans 0..15 before rounding.
    scaled_iigs = np.ascontiguousarray(np.clip(rec601_rgb, 0, 1)).astype(np.float32) * 15
    return np.round(scaled_iigs).astype(np.uint8)
# Wrapper around _convert_cam16ucs_to_rgb12_iigs to allow calling from python while retaining fast path for cython
# calls.
def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
    """Convert rows of CAM16UCS coordinates to //gs RGB values.

    Args:
        point_cam: 2-D float buffer with one CAM16UCS colour per row.

    Returns:
        The result of _convert_cam16ucs_to_rgb12_iigs: an unsigned 8-bit 2-D
        buffer with one row per input row, each channel rounded to an integer
        in 0..15.
    """
    return _convert_cam16ucs_to_rgb12_iigs(point_cam)
@cython.cdivision(True)
def k_means_with_fixed_centroids(
        int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations,
        float [:, ::1] rgb12_iigs_to_cam16ucs):
    """Run k-means over CAM16UCS samples while holding the first n_fixed centroids in place.

    Centroid positions are tracked as //gs RGB values; distances to samples
    are measured with common.colour_distance_squared after looking each
    centroid up in rgb12_iigs_to_cam16ucs.

    Args:
        n_clusters: total number of centroids used.
        n_fixed: number of leading centroids that are never moved; must be in
            the range 0..n_clusters.
        samples: one CAM16UCS sample per row (the first three columns are
            used).
        initial_centroids: starting centroid positions, at least n_clusters
            rows of three components.  Not modified; a copy is updated.
        max_iterations: upper bound on the number of assignment/update rounds.
        rgb12_iigs_to_cam16ucs: lookup table passed to
            _convert_rgb12_iigs_to_cam.

    Returns:
        A copy of initial_centroids in which the non-fixed rows hold the
        final centroid positions.

    Raises:
        ValueError: if n_fixed is outside 0..n_clusters, or if
            initial_centroids or samples are too small.  Bounds checking is
            disabled for this module, so these cases would otherwise access
            memory out of bounds.
    """
    cdef double error, best_error, total_error, last_total_error
    cdef int centroid_idx, closest_centroid_idx, i, j, point_idx, iteration

    cdef (unsigned char)[:, ::1] centroids_rgb12
    cdef (unsigned char)[:, ::1] new_centroids_rgb12

    cdef common.float3 point_cam, current_cam
    cdef float[:, ::1] new_centroids_cam
    cdef float[:, ::1] centroid_cam_sample_positions_total
    cdef int[::1] centroid_sample_counts

    if n_fixed < 0 or n_fixed > n_clusters:
        raise ValueError("n_fixed must be between 0 and n_clusters")
    if initial_centroids.shape[0] < n_clusters or initial_centroids.shape[1] < 3:
        raise ValueError("initial_centroids must have at least n_clusters rows of 3 components")
    if samples.shape[0] > 0 and samples.shape[1] < 3:
        raise ValueError("samples must have at least 3 columns")

    centroids_rgb12 = np.copy(initial_centroids)
    new_centroids_cam = np.empty((n_clusters - n_fixed, 3), dtype=np.float32)

    last_total_error = 1e9
    for iteration in range(max_iterations):
        total_error = 0.0
        # Sized by n_clusters (rather than a fixed 16) so that any cluster count stays in bounds.
        centroid_cam_sample_positions_total = np.zeros((n_clusters, 3), dtype=np.float32)
        centroid_sample_counts = np.zeros(n_clusters, dtype=np.int32)

        # For each sample, associate it to the closest centroid.  We want to compute the mean of all associated samples
        # but we do this by accumulating the (coordinate vector) total and number of associated samples.
        #
        # Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour
        # space.
        for point_idx in range(samples.shape[0]):
            for j in range(3):
                point_cam.data[j] = samples[point_idx, j]

            best_error = 1e9
            closest_centroid_idx = 0
            for centroid_idx in range(n_clusters):
                error = common.colour_distance_squared(
                    _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :]).data,
                    point_cam.data)
                if error < best_error:
                    best_error = error
                    closest_centroid_idx = centroid_idx

            for i in range(3):
                centroid_cam_sample_positions_total[closest_centroid_idx, i] += point_cam.data[i]
            centroid_sample_counts[closest_centroid_idx] += 1
            total_error += best_error

        # Since the allowed centroid positions are discrete (and not uniformly spaced in CAM16UCS colour space), we
        # can't rely on measuring total centroid movement as a termination condition.  e.g. sometimes the nearest
        # available point to an intended next centroid position will increase the total distance, or centroids may
        # oscillate between two neighbouring positions.  Instead, we terminate when the total error stops decreasing.
        #
        # NOTE(review): at this point centroids_rgb12 holds the positions whose error was just found to be no better
        # than the previous round's, and those are what get returned - confirm whether the previous positions should
        # be returned instead.
        if total_error >= last_total_error:
            break
        last_total_error = total_error

        # Compute new centroid positions in CAM16UCS colour space
        for centroid_idx in range(n_fixed, n_clusters):
            if centroid_sample_counts[centroid_idx]:
                for i in range(3):
                    new_centroids_cam[centroid_idx - n_fixed, i] = (
                        centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
            else:
                # No samples were assigned to this centroid, so there is no mean to move to.  Fill the row with the
                # centroid's current position so the batch conversion below never sees uninitialised or stale data.
                current_cam = _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :])
                for i in range(3):
                    new_centroids_cam[centroid_idx - n_fixed, i] = current_cam.data[i]

        # Convert all new centroids back to //gs RGB colour space (done as a single matrix since
        # _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead)
        new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)

        # Update positions for non-fixed centroids
        for centroid_idx in range(n_fixed, n_clusters):
            if not centroid_sample_counts[centroid_idx]:
                # Empty cluster: leave the centroid exactly where it is.
                continue
            for i in range(3):
                centroids_rgb12[centroid_idx, i] = new_centroids_rgb12[centroid_idx - n_fixed, i]

    return centroids_rgb12