# disentanglement_utils.py
from sklearn import metrics
from sklearn.model_selection import GridSearchCV
from sklearn import linear_model, kernel_ridge
import torch
import numpy as np
import scipy as sp
from typing import Union
from typing_extensions import Literal

# Supported disentanglement measures (both are handled by _disentanglement below).
__Mode = Union[Literal["r2"], Literal["accuracy"]]


def _disentanglement(z, hz, mode: __Mode = "r2", reorder=None):
    """Measure how well hz reconstructs z, either by the coefficient of
    determination (R^2) or by classification accuracy."""
    assert mode in ("r2", "accuracy")

    if mode == "r2":
        return metrics.r2_score(z, hz), None
    elif mode == "accuracy":
        return metrics.accuracy_score(z, hz), None


def nonlinear_disentanglement(z, hz, mode: __Mode = "r2", train_test_split=False,
                              alpha=1.0, gamma=None, train_mode=False, model=None,
                              scaler_z=None, scaler_hz=None):
    """Calculate disentanglement up to nonlinear transformations.

    Args:
        z: Ground-truth latents.
        hz: Reconstructed latents.
        mode: Disentanglement measure; "r2" (coefficient of determination) is the
            supported setting here.
        train_test_split: Use the first half of the samples to fit a kernel ridge
            regression and the second half to evaluate it. Only relevant if there
            are fewer samples than latent dimensions.
        alpha: Regularization strength of the kernel ridge regression.
        gamma: Kernel coefficient of the kernel ridge regression.
        train_mode: If True (and train_test_split is False), fit and return a
            cross-validated kernel ridge model instead of evaluating.
        model: Fitted model returned by a previous call with train_mode=True.
        scaler_z: Unused in this function.
        scaler_hz: Unused in this function.
    """
    if torch.is_tensor(hz):
        hz = hz.detach().cpu().numpy()
    if torch.is_tensor(z):
        z = z.detach().cpu().numpy()
    assert isinstance(z, np.ndarray), "Either pass a torch tensor or numpy array as z"
    assert isinstance(hz, np.ndarray), "Either pass a torch tensor or numpy array as hz"

    # Split z, hz into train and test halves for the nonlinear regression.
    if train_test_split:
        n_train = len(z) // 2
        z_1 = z[:n_train]
        hz_1 = hz[:n_train]
        z_2 = z[n_train:]
        hz_2 = hz[n_train:]

        model = kernel_ridge.KernelRidge(kernel="linear", alpha=alpha, gamma=gamma)
        model.fit(hz_1, z_1)
        hz_2 = model.predict(hz_2)

        inner_result = _disentanglement(z_2, hz_2, mode=mode, reorder=False)
        return inner_result, (z_2, hz_2)
    else:
        if train_mode:
            # Fit an RBF kernel ridge regression with cross-validated alpha and gamma.
            model = GridSearchCV(
                kernel_ridge.KernelRidge(kernel="rbf", gamma=0.1),
                param_grid={"alpha": [1e0, 0.1, 1e-2, 1e-3],
                            "gamma": np.logspace(-2, 2, 4)},
                cv=3, n_jobs=-1)
            model.fit(hz, z)
            return model
        else:
            hz = model.predict(hz)
            inner_result = _disentanglement(z, hz, mode=mode, reorder=False)
            return inner_result, (z, hz)
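

# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# A minimal sketch of the two-step workflow for nonlinear_disentanglement: first fit
# the cross-validated kernel ridge model with train_mode=True, then reuse it to score
# held-out samples. The helper name, array shapes, and synthetic data below are
# assumptions made purely for illustration.
def _example_nonlinear_usage():
    rng = np.random.RandomState(0)
    z_train = rng.randn(200, 5)                              # ground-truth latents (train split)
    hz_train = np.tanh(z_train) + 0.1 * rng.randn(200, 5)    # reconstructed latents (train split)
    z_test = rng.randn(100, 5)
    hz_test = np.tanh(z_test) + 0.1 * rng.randn(100, 5)

    # Step 1: fit the kernel ridge regressor from hz to z on the training split.
    fitted = nonlinear_disentanglement(z_train, hz_train, train_mode=True)
    # Step 2: evaluate R^2 on the test split with the fitted model.
    (r2, _), _ = nonlinear_disentanglement(z_test, hz_test, model=fitted)
    return r2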


def linear_disentanglement(z, hz, mode: __Mode = "r2", train_test_split=False,
                           train_mode=False, model=None):
    """Calculate disentanglement up to linear transformations.

    Args:
        z: Ground-truth latents.
        hz: Reconstructed latents.
        mode: "r2" fits a linear regression and reports the coefficient of
            determination; "accuracy" fits a logistic regression and reports
            classification accuracy.
        train_test_split: Use the first half of the samples to fit the linear model
            and the second half to evaluate it. Only relevant if there are fewer
            samples than latent dimensions.
        train_mode: If True (and train_test_split is False), fit and return the
            linear model instead of evaluating.
        model: Fitted model returned by a previous call with train_mode=True.
    """
    if torch.is_tensor(hz):
        hz = hz.detach().cpu().numpy()
    if torch.is_tensor(z):
        z = z.detach().cpu().numpy()
    assert isinstance(z, np.ndarray), "Either pass a torch tensor or numpy array as z"
    assert isinstance(hz, np.ndarray), "Either pass a torch tensor or numpy array as hz"

    # Split z, hz into train and test halves for the linear model.
    if train_test_split:
        n_train = len(z) // 2
        z_1 = z[:n_train]
        hz_1 = hz[:n_train]
        z_2 = z[n_train:]
        hz_2 = hz[n_train:]

        if mode == "accuracy":
            model = linear_model.LogisticRegression()
        else:
            model = linear_model.LinearRegression()
        model.fit(hz_1, z_1)
        hz_2 = model.predict(hz_2)

        inner_result = _disentanglement(z_2, hz_2, mode=mode, reorder=False)
        return inner_result, (z_2, hz_2)
    else:
        if train_mode:
            if mode == "accuracy":
                model = linear_model.LogisticRegression()
            else:
                model = linear_model.LinearRegression()
            model.fit(hz, z)
            return model
        else:
            hz = model.predict(hz)
            inner_result = _disentanglement(z, hz, mode=mode, reorder=False)
            return inner_result, (z, hz)
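

# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# A minimal sketch of the train_test_split variant of linear_disentanglement: the first
# half of the samples fits a linear regression from hz to z, the second half is scored
# with R^2. The synthetic data below is an assumption for illustration only; it mimics
# reconstructions that are an orthogonal mixing of the true latents plus noise.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    z = rng.randn(400, 3)                          # ground-truth latents
    mixing, _ = np.linalg.qr(rng.randn(3, 3))      # random orthogonal "entangling" map
    hz = z @ mixing + 0.05 * rng.randn(400, 3)     # reconstructed latents

    (linear_r2, _), _ = linear_disentanglement(z, hz, mode="r2", train_test_split=True)
    print(f"Linear disentanglement (R^2): {linear_r2:.3f}")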