Skip to content

Commit

Permalink
Merge pull request #208 from amarquand/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
amarquand authored Jun 13, 2024
2 parents 8556b07 + 0c0f692 commit c40a2d3
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 112 deletions.
2 changes: 1 addition & 1 deletion pcntoolkit/model/hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def predict_on_new_site(self, X, batch_effects):
modeler = self.get_modeler()
with modeler(X, y, batch_effects, self.configs, idata=self.idata):
self.idata = pm.sample_posterior_predictive(
self.idata, extend_inferencedata=True, progressbar=True
self.idata, extend_inferencedata=True, progressbar=True, var_names=self.vars_to_sample
)
pred_mean = self.idata.posterior_predictive["y_like"].mean(axis=(0, 1))
pred_var = self.idata.posterior_predictive["y_like"].var(axis=(0, 1))
Expand Down
16 changes: 13 additions & 3 deletions pcntoolkit/normative.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,6 @@ def predict(covfile, respfile, maskfile=None, **kwargs):
Xz = X

# estimate the models for all variabels
# TODO Z-scores adaptation for SHASH HBR
for i, m in enumerate(models):
print("Prediction by model ", i+1, "of", feature_num)
nm = norm_init(Xz)
Expand All @@ -842,6 +841,19 @@ def predict(covfile, respfile, maskfile=None, **kwargs):
else:
Yhat[:, i] = yhat.squeeze()
S2[:, i] = s2.squeeze()
if respfile is not None:
Y, maskvol = load_response_vars(respfile, maskfile)
Y = Y[:, i:i+1]
if alg == 'hbr':
if outscaler in ['standardize', 'minmax', 'robminmax']:
Yz = scaler_resp[fold].transform(Y)
else:
Yz = Y
Z[:,i] = nm.get_mcmc_zscores(Xz, Yz, **kwargs)
else:
Z[:,i] = (Y - Yhat[:, i]) / np.sqrt(S2[:, i])



if respfile is None:
save_results(None, Yhat, S2, None, outputsuffix=outputsuffix)
Expand Down Expand Up @@ -881,8 +893,6 @@ def predict(covfile, respfile, maskfile=None, **kwargs):
else:
warp = False

Z = (Y - Yhat) / np.sqrt(S2)

print("Evaluating the model ...")
if meta_data and not warp:

Expand Down
21 changes: 14 additions & 7 deletions pcntoolkit/normative_model/norm_hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,21 @@ def get_mcmc_quantiles(self, X, batch_effects=None, z_scores=None):
if 'posterior_predictive' in self.hbr.idata.groups():
del self.hbr.idata.posterior_predictive

if self.configs["transferred"] == True:
self.predict_on_new_sites(
X=X,
batch_effects=batch_effects
)
#var_names = ["y_like"]
else:
self.hbr.predict(
# Do a forward to get the posterior predictive in the idata
self.hbr.predict(
X=X,
batch_effects=batch_effects,
batch_effects_maps=self.batch_effects_maps,
pred="single",
var_names=var_names+["y_like"],
)
X=X,
batch_effects=batch_effects,
batch_effects_maps=self.batch_effects_maps,
pred="single",
var_names=var_names+["y_like"],
)

# Extract the relevant samples from the idata
post_pred = az.extract(
Expand Down
Loading

0 comments on commit c40a2d3

Please sign in to comment.