Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1.0 #5

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,6 @@ dmypy.json
# Custom ones
figs_*
*.npz


**/.DS_Store
8 changes: 8 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.linting.enabled": true,
"python.formatting.provider": "black",
"python.linting.flake8Args": ["--max-line-length=130"],
"editor.formatOnSave": true,
}
61 changes: 29 additions & 32 deletions graphs/generate_traversals.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,80 @@
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from tensorflow import sigmoid
from scipy.stats import pearsonr, spearmanr, pointbiserialr
import torch
from torch.nn.functional import sigmoid
from scipy.stats import spearmanr
from sklearn.feature_selection import mutual_info_regression
import seaborn as sns

def generate_traversals(model, s_dim, s_sample, S_real, filenames=[], naive=False, colour=False):
elements = 10

fig = plt.figure(figsize=(8,10))
gs = gridspec.GridSpec(s_dim, 3, width_ratios=[5,1,1])
arg_max_hist_value = np.zeros(s_dim)
start_val = np.zeros(s_dim)
end_val = np.zeros(s_dim)
arg_max_hist_value = torch.zeros(s_dim)
start_val = torch.zeros(s_dim)
end_val = torch.zeros(s_dim)
for s_indx in range(s_dim):
plt.subplot(gs[s_indx*3+1])
hh = plt.hist(s_sample[:,s_indx])
hh = plt.hist(s_sample[:,s_indx].cpu().numpy())

if naive:
arg_max_hist_value[s_indx] = 0.0
start_val[s_indx] = -3.0
end_val[s_indx] = 3.0
else:
index_of_highest = np.argmax(hh[0])
index_of_highest = torch.argmax(torch.tensor(hh[0]))
arg_max_hist_value[s_indx] = (hh[1][index_of_highest]+hh[1][index_of_highest+1])/2.0
start_val[s_indx] = (hh[1][0]+hh[1][1])/2.0
end_val[s_indx] = (hh[1][-2]+hh[1][-1])/2.0

start_val = np.array([-5.0, -5.0, -2.0, -5.0, -1.3, -0.65, -2.0, -2.5, 0.4, -2.5])
arg_max_hist_value = np.array([-1.5,0.0,-1.5,0.0,1.0,0.0,0.0,0.0,0.0,0.0])
#arg_max_hist_value = np.array([-1.5,0.0,-1.5,0.0,0.75,1.0,0.0,0.0,0.75,0.0]) # for s8
end_val = np.array([4.0, 5.0, 2.0, 5.0, 4.75, 2.1, 2.0, 2.5, 3.45, 2.5])
start_val = torch.tensor([-5.0, -5.0, -2.0, -5.0, -1.3, -0.65, -2.0, -2.5, 0.4, -2.5])
arg_max_hist_value = torch.tensor([-1.5,0.0,-1.5,0.0,1.0,0.0,0.0,0.0,0.0,0.0])
end_val = torch.tensor([4.0, 5.0, 2.0, 5.0, 4.75, 2.1, 2.0, 2.5, 3.45, 2.5])

if len(S_real) > 0:
correlations = np.zeros((10,6))
correlations_cat = np.zeros((10,6))
correlations_p = np.zeros((10,6))
correlations = torch.zeros((10,6))
correlations_cat = torch.zeros((10,6))
correlations_p = torch.zeros((10,6))
labels = ['shape', 'scale', 'orientation', 'posX', 'posY', 'reward']
for real_s_indx in range(6):
for s_indx in range(s_dim):
correlations[s_indx,real_s_indx],correlations_p[s_indx,real_s_indx] = spearmanr(s_sample[:,s_indx],S_real[:,real_s_indx])
correlations[s_indx,real_s_indx] = abs(correlations[s_indx,real_s_indx])
correlations_cat[s_indx,real_s_indx] = mutual_info_regression(s_sample[:,s_indx].numpy().reshape(-1,1),S_real[:,real_s_indx])
corr, p_value = spearmanr(s_sample[:,s_indx].cpu().numpy(), S_real[:,real_s_indx].cpu().numpy())
correlations[s_indx,real_s_indx] = abs(corr)
correlations_p[s_indx,real_s_indx] = p_value
correlations_cat[s_indx,real_s_indx] = mutual_info_regression(s_sample[:,s_indx].cpu().numpy().reshape(-1,1), S_real[:,real_s_indx].cpu().numpy())

for s_indx in range(s_dim):
plt.subplot(gs[s_indx*3+2])
plt.plot(correlations[s_indx][1:])
if np.max(correlations[s_indx][1:]) < 0.5:
sns.lineplot(data=correlations[s_indx][1:].cpu().numpy())
if torch.max(correlations[s_indx][1:]) < 0.5:
plt.ylim(0.0,0.5)
plt.plot(correlations_cat[s_indx])
sns.lineplot(data=correlations_cat[s_indx].cpu().numpy())
plt.ylabel('Correlation')
plt.xticks(range(len(labels)-1),labels[1:], rotation='vertical')

for s_indx in range(s_dim):
plt.subplot(gs[s_indx*3])
plt.ylabel(r'$s_{'+str(s_indx)+'}$')
s = np.zeros((elements,s_dim))
s = torch.zeros((elements,s_dim))
for x in range(elements):
for y in range(s_dim):
s[x,y] = arg_max_hist_value[y]

for x,s_x in enumerate(np.linspace(start_val[s_indx],end_val[s_indx],elements)):
for x,s_x in enumerate(torch.linspace(start_val[s_indx],end_val[s_indx],elements)):
s[x,s_indx] = s_x
with torch.no_grad():
new_img = model.model_down.decoder(s)
if colour:
new_img = model.model_down.decoder(s)[:,:,:]
plt.imshow(np.hstack(new_img), vmin=0, vmax=1)
plt.imshow(torch.hstack(new_img[:,:,:3]).cpu(), vmin=0, vmax=1)
else:
new_img = model.model_down.decoder(s)[:,:,:,0]
plt.imshow(np.hstack(new_img), cmap='gray', vmin=0, vmax=1)
plt.imshow(torch.hstack(new_img[:,:,:,0]).cpu(), cmap='gray', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.xlabel(str(round(start_val[s_indx],4))+' <-- '+str(round(arg_max_hist_value[s_indx],4))+' --> '+str(round(end_val[s_indx],4)))

plt.xlabel(f"{start_val[s_indx]:.4f} <-- {arg_max_hist_value[s_indx]:.4f} --> {end_val[s_indx]:.4f}")

fig.set_tight_layout(True)
for filename in filenames:
plt.savefig(filename)
#plt.show()
plt.close()


#
43 changes: 43 additions & 0 deletions src/causal_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn

class StructuralCausalModel(nn.Module):
def __init__(self, s_dim, pi_dim, gamma, beta_s, beta_o, colour_channels, resolution):
super(StructuralCausalModel, self).__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(colour_channels, 32, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(128 * (resolution // 8) * (resolution // 8), s_dim)
)
# Decoder
self.decoder = nn.Sequential(
nn.Linear(s_dim, 128 * (resolution // 8) * (resolution // 8)),
nn.ReLU(),
nn.Unflatten(1, (128, resolution // 8, resolution // 8)),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, colour_channels, kernel_size=4, stride=2, padding=1),
nn.Sigmoid()
)
self.beta_s = beta_s
self.beta_o = beta_o
self.gamma = gamma

def forward(self, x):
s = self.encoder(x)
x_recon = self.decoder(s)
return x_recon, s

def counterfactual(self, x, intervention):
s = self.encoder(x)
s_intervened = s + intervention
x_recon = self.decoder(s_intervened)
return x_recon, s_intervened
Loading