Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/amarquand/PCNtoolkit
Browse files Browse the repository at this point in the history
  • Loading branch information
AuguB committed Aug 9, 2023
2 parents ca5619d + b402548 commit ad5b2b6
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 8 deletions.
Binary file removed dist/pcntoolkit-0.27-py3.11.egg
Binary file not shown.
Binary file removed dist/pcntoolkit-0.28-py3.10.egg
Binary file not shown.
5 changes: 3 additions & 2 deletions pcntoolkit/model/hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,13 @@ def create_dummy_inputs(self, covariate_ranges=[[0.1, 0.9, 0.01]]):
batch_effects_dummy = np.repeat(batch_effects, X.shape[0], axis=0)
return X_dummy, batch_effects_dummy

def Rhats(self, var_names, thin = 1, resolution = 100):
def Rhats(self, var_names=None, thin = 1, resolution = 100):
"""Get Rhat of posterior samples as function of sampling iteration"""
idata = self.idata
testvars = az.extract(idata, group='posterior', var_names=var_names, combined=False)
testvar_names = [var for var in list(testvars.data_vars.keys()) if not '_samples' in var]
rhat_dict={}
for var_name in var_names:
for var_name in testvar_names:
var = np.stack(testvars[var_name].to_numpy())[:,::thin]
var = var.reshape((var.shape[0], var.shape[1], -1))
vardim = var.shape[2]
Expand Down
22 changes: 18 additions & 4 deletions pcntoolkit/normative_model/norm_hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,16 @@ def get_mcmc_quantiles(self, X, batch_effects=None, z_scores=None):
z_scores = np.arange(-3, 4)
likelihood=self.configs['likelihood']

# Determine the variables to predict
# Determine the variables to predict
if self.configs["likelihood"] == "Normal":
var_names = ["mu_samples", "sigma_plus_samples"]
var_names = ["mu_samples", "sigma_samples","sigma_plus_samples"]
elif self.configs["likelihood"].startswith("SHASH"):
var_names = [
"mu_samples",
"sigma_samples",
"sigma_plus_samples",
"epsilon_samples",
"delta_samples",
"delta_plus_samples",
]
else:
Expand All @@ -462,6 +464,11 @@ def get_mcmc_quantiles(self, X, batch_effects=None, z_scores=None):
post_pred = az.extract(
self.hbr.idata, "posterior_predictive", var_names=var_names
)

# Remove superfluous var_names
var_names.remove('sigma_samples')
if 'delta_samples' in var_names:
var_names.remove('delta_samples')

# Separate the samples into a list so that they can be unpacked
array_of_vars = list(map(lambda x: post_pred[x], var_names))
Expand Down Expand Up @@ -498,12 +505,14 @@ def get_mcmc_zscores(self, X, y, batch_effects=None):

# Determine the variables to predict
if self.configs["likelihood"] == "Normal":
var_names = ["mu_samples", "sigma_plus_samples"]
var_names = ["mu_samples", "sigma_samples","sigma_plus_samples"]
elif self.configs["likelihood"].startswith("SHASH"):
var_names = [
"mu_samples",
"sigma_samples",
"sigma_plus_samples",
"epsilon_samples",
"delta_samples",
"delta_plus_samples",
]
else:
Expand All @@ -526,6 +535,11 @@ def get_mcmc_zscores(self, X, y, batch_effects=None):
post_pred = az.extract(
self.hbr.idata, "posterior_predictive", var_names=var_names
)

# Remove superfluous var_names
var_names.remove('sigma_samples')
if 'delta_samples' in var_names:
var_names.remove('delta_samples')

# Separate the samples into a list so that they can be unpacked
array_of_vars = list(map(lambda x: post_pred[x], var_names))
Expand All @@ -541,7 +555,7 @@ def get_mcmc_zscores(self, X, y, batch_effects=None):
)
return z_scores.mean(axis=-1)



def S_inv(x, e, d):
return np.sinh((np.arcsinh(x) + e) / d)
Expand Down
13 changes: 11 additions & 2 deletions tests/testHBR.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

for model_type in model_types:

nm = norm_init(X_train, Y_train, alg='hbr',likelihood='Normal', model_type=model_type,n_samples=100,n_tuning=10)
nm = norm_init(X_train, Y_train, alg='hbr',likelihood='SHASHb', model_type=model_type,n_samples=100,n_tuning=10)
nm.estimate(X_train, Y_train, trbefile=working_dir+'trbefile.pkl')
yhat, ys2 = nm.predict(X_test, tsbefile=working_dir+'tsbefile.pkl')

Expand All @@ -63,19 +63,28 @@

plt.figure()
for j in range(n_grps):
plt.scatter(temp_X[temp_be==j,], temp_Y[temp_be==j,],
scat1 = plt.scatter(temp_X[temp_be==j,], temp_Y[temp_be==j,],
label='Group' + str(j))
plt.plot(temp_X[temp_be==j,], temp_yhat[temp_be==j,])
plt.fill_between(temp_X[temp_be==j,], temp_yhat[temp_be==j,] -
1.96 * np.sqrt(temp_s2[temp_be==j,]),
temp_yhat[temp_be==j,] +
1.96 * np.sqrt(temp_s2[temp_be==j,]),
color='gray', alpha=0.2)

# Showing the quantiles
resolution = 200
synth_X = np.linspace(-3, 3, resolution)
q = nm.get_mcmc_quantiles(synth_X, batch_effects=j*np.ones(resolution))
col = scat1.get_facecolors()[0]
plt.plot(synth_X, q.T, linewidth=1, color=col, zorder = 0)

plt.title('Model %s, Feature %d' %(model_type, i))
plt.legend()
plt.show()



############################## Normative Modelling Test #######################


Expand Down

0 comments on commit ad5b2b6

Please sign in to comment.