-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcolor_torch.py
79 lines (70 loc) · 3.25 KB
/
color_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
def rgb2hsl_torch(rgb: torch.Tensor) -> torch.Tensor:
    """Convert an RGB tensor to HSL.

    Args:
        rgb: Tensor whose dim 1 holds the R, G, B channels, e.g. of shape
            ``(N, 3, H, W)``. Values are presumably in ``[0, 1]`` — TODO
            confirm with callers.

    Returns:
        Tensor of the same shape with channels ``(H, S, L)`` on dim 1. Hue is
        a fraction of a full turn (sextant index / 6). Hue is 0 for gray
        pixels. Saturation is 0 wherever L is not strictly inside ``(0, 1)``,
        which includes out-of-range or NaN lightness.

    The function is differentiable. Divisions whose results are discarded by
    a mask use a denominator of 1 on those pixels, so the discarded ``0 / 0``
    terms cannot inject NaN into the backward pass.
    """
    cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
    cmin = torch.min(rgb, dim=1, keepdim=True)[0]
    delta = cmax - cmin
    red, green, blue = rgb[:, 0:1], rgb[:, 1:2], rgb[:, 2:3]

    # Hue: choose the formula by which channel holds the maximum. Gray pixels
    # (delta == 0) are excluded from every mask and keep hue 0. cmax_idx is
    # deliberately NOT modified in place: torch.max saves it for its own
    # backward, and mutating it would make backward() raise.
    chromatic = delta != 0
    safe_delta = torch.where(chromatic, delta, torch.ones_like(delta))
    hue_candidates = (
        ((green - blue) / safe_delta) % 6,  # red is the max
        (blue - red) / safe_delta + 2,      # green is the max
        (red - green) / safe_delta + 4,     # blue is the max
    )
    hsl_h = torch.zeros_like(cmax)
    for channel, candidate in enumerate(hue_candidates):
        mask = (cmax_idx == channel) & chromatic
        hsl_h[mask] = candidate[mask]
    hsl_h /= 6.

    hsl_l = (cmax + cmin) / 2.

    # Saturation: delta / (2L) for L in (0, 0.5], delta / (2 - 2L) for
    # L in (0.5, 1). Every other pixel keeps the zero it was initialised with.
    hsl_s = torch.zeros_like(hsl_l)
    one = torch.ones_like(hsl_l)
    inside = (hsl_l > 0) & (hsl_l < 1)
    lower = inside & (hsl_l <= 0.5)
    upper = inside & (hsl_l > 0.5)
    hsl_s[lower] = (delta / torch.where(lower, hsl_l * 2., one))[lower]
    hsl_s[upper] = (delta / torch.where(upper, -hsl_l * 2. + 2., one))[upper]
    return torch.cat([hsl_h, hsl_s, hsl_l], dim=1)
def rgb2hsv_torch(rgb: torch.Tensor) -> torch.Tensor:
    """Convert an RGB tensor to HSV.

    Args:
        rgb: Tensor whose dim 1 holds the R, G, B channels, e.g. of shape
            ``(N, 3, H, W)``. Values are presumably in ``[0, 1]`` — TODO
            confirm with callers.

    Returns:
        Tensor of the same shape with channels ``(H, S, V)`` on dim 1. Hue is
        a fraction of a full turn (sextant index / 6) and is 0 for gray
        pixels. Saturation is ``delta / V`` and is 0 for black pixels
        (``V == 0``). V is the per-pixel channel maximum.

    The function is differentiable. Divisions whose results are discarded use
    a denominator of 1 on those pixels, so no NaN reaches the backward pass
    (``torch.where`` still backpropagates through the unselected branch).
    """
    cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
    cmin = torch.min(rgb, dim=1, keepdim=True)[0]
    delta = cmax - cmin
    red, green, blue = rgb[:, 0:1], rgb[:, 1:2], rgb[:, 2:3]

    # Hue: choose the formula by which channel holds the maximum. Gray pixels
    # (delta == 0) are excluded from every mask and keep hue 0. cmax_idx is
    # deliberately NOT modified in place: torch.max saves it for its own
    # backward, and mutating it would make backward() raise.
    chromatic = delta != 0
    safe_delta = torch.where(chromatic, delta, torch.ones_like(delta))
    hue_candidates = (
        ((green - blue) / safe_delta) % 6,  # red is the max
        (blue - red) / safe_delta + 2,      # green is the max
        (red - green) / safe_delta + 4,     # blue is the max
    )
    hsv_h = torch.zeros_like(cmax)
    for channel, candidate in enumerate(hue_candidates):
        mask = (cmax_idx == channel) & chromatic
        hsv_h[mask] = candidate[mask]
    hsv_h /= 6.

    # Saturation: delta / V, defined as 0 for black pixels.
    black = cmax == 0
    safe_cmax = torch.where(black, torch.ones_like(cmax), cmax)
    hsv_s = torch.where(black, torch.zeros_like(cmax), delta / safe_cmax)
    hsv_v = cmax
    return torch.cat([hsv_h, hsv_s, hsv_v], dim=1)
def hsv2rgb_torch(hsv: torch.Tensor) -> torch.Tensor:
    """Convert an HSV tensor back to RGB.

    Args:
        hsv: Tensor whose dim 1 holds the H, S, V channels. It must be 4-D,
            e.g. ``(N, 3, H, W)``, because the per-pixel sextant index is
            broadcast with ``expand(-1, 3, -1, -1)``. Hue is a fraction of a
            full turn. Any real hue (negative, or >= 1) is wrapped around the
            colour wheel.

    Returns:
        Tensor of the same shape with channels ``(R, G, B)`` on dim 1.
    """
    hsv_h, hsv_s, hsv_v = hsv[:, 0:1], hsv[:, 1:2], hsv[:, 2:3]
    _c = hsv_v * hsv_s  # chroma: the largest component before the offset
    # Intermediate component within the current 60-degree sextant.
    _x = _c * (1. - torch.abs(hsv_h * 6. % 2. - 1))
    _m = hsv_v - _c  # offset added to every channel
    _o = torch.zeros_like(_c)
    # Sextant index 0..5. floor() + int64 modulo wraps any real hue correctly.
    # A direct float -> uint8 cast is undefined for negative values or values
    # >= 256, so it is avoided here.
    idx = (torch.floor(hsv_h * 6.).long() % 6).expand(-1, 3, -1, -1)
    # (R, G, B) before the offset, one entry per sextant.
    sextants = (
        (_c, _x, _o),
        (_x, _c, _o),
        (_o, _c, _x),
        (_o, _x, _c),
        (_x, _o, _c),
        (_c, _o, _x),
    )
    rgb = torch.zeros_like(hsv)
    for sextant, channels in enumerate(sextants):
        mask = idx == sextant
        rgb[mask] = torch.cat(channels, dim=1)[mask]
    rgb += _m
    return rgb
def hsl2rgb_torch(hsl: torch.Tensor) -> torch.Tensor:
    """Convert an HSL tensor back to RGB.

    Args:
        hsl: Tensor whose dim 1 holds the H, S, L channels. It must be 4-D,
            e.g. ``(N, 3, H, W)``, because the per-pixel sextant index is
            broadcast with ``expand(-1, 3, -1, -1)``. Hue is a fraction of a
            full turn. Any real hue (negative, or >= 1) is wrapped around the
            colour wheel.

    Returns:
        Tensor of the same shape with channels ``(R, G, B)`` on dim 1.
    """
    hsl_h, hsl_s, hsl_l = hsl[:, 0:1], hsl[:, 1:2], hsl[:, 2:3]
    _c = (1. - torch.abs(hsl_l * 2. - 1.)) * hsl_s  # chroma
    # Intermediate component within the current 60-degree sextant.
    _x = _c * (1. - torch.abs(hsl_h * 6. % 2. - 1))
    _m = hsl_l - _c / 2.  # offset added to every channel
    _o = torch.zeros_like(_c)
    # Sextant index 0..5. floor() + int64 modulo wraps any real hue correctly.
    # A direct float -> uint8 cast is undefined for negative values or values
    # >= 256, so it is avoided here.
    idx = (torch.floor(hsl_h * 6.).long() % 6).expand(-1, 3, -1, -1)
    # (R, G, B) before the offset, one entry per sextant.
    sextants = (
        (_c, _x, _o),
        (_x, _c, _o),
        (_o, _c, _x),
        (_o, _x, _c),
        (_x, _o, _c),
        (_c, _o, _x),
    )
    rgb = torch.zeros_like(hsl)
    for sextant, channels in enumerate(sextants):
        mask = idx == sextant
        rgb[mask] = torch.cat(channels, dim=1)[mask]
    rgb += _m
    return rgb