Skip to content

Commit

Permalink
A quick patch to fix z-scores in SHASH HBR.
Browse files Browse the repository at this point in the history
  • Loading branch information
smkia committed Nov 8, 2023
1 parent 933b553 commit ec9c7cc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
15 changes: 12 additions & 3 deletions pcntoolkit/normative.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def estimate(covfile, respfile, **kwargs):
kwargs['trbefile'] = 'be_kfold_tr_tempfile.pkl'
kwargs['tsbefile'] = 'be_kfold_ts_tempfile.pkl'

# estimate the models for all subjects
# estimate the models for all response variables
for i in range(0, len(nz)):
print("Estimating model ", i+1, "of", len(nz))
nm = norm_init(Xz_tr, Yz_tr[:, i], alg=alg, **kwargs)
Expand Down Expand Up @@ -500,7 +500,14 @@ def estimate(covfile, respfile, **kwargs):
else:
Ytest = Y[ts, nz[i]]

Z[ts, nz[i]] = (Ytest - Yhat[ts, nz[i]]) / \
if alg=='hbr':
if outscaler in ['standardize', 'minmax', 'robminmax']:
Ytestz = Y_scaler.transform(Ytest.reshape(-1,1), index=i)
else:
Ytestz = Ytest.reshape(-1,1)
Z[ts, nz[i]] = nm.get_mcmc_zscores(Xz_ts, Ytestz, **kwargs)
else:
Z[ts, nz[i]] = (Ytest - Yhat[ts, nz[i]]) / \
np.sqrt(S2[ts, nz[i]])

except Exception as e:
Expand Down Expand Up @@ -750,6 +757,7 @@ def predict(covfile, respfile, maskfile=None, **kwargs):
Xz = X

# estimate the models for all subjects
#TODO Z-scores adaptation for SHASH HBR
for i, m in enumerate(models):
print("Prediction by model ", i+1, "of", feature_num)
nm = norm_init(Xz)
Expand Down Expand Up @@ -806,7 +814,7 @@ def predict(covfile, respfile, maskfile=None, **kwargs):

warp_param = nm.blr.hyp[1:nm.blr.warp.get_n_params()+1]
Yw[:,i] = nm.blr.warp.f(Y[:,i], warp_param)
Y = Yw;
Y = Yw
else:
warp = False

Expand Down Expand Up @@ -1063,6 +1071,7 @@ def transfer(covfile, respfile, testcov=None, testresp=None, maskfile=None,
else:
warp = False

#TODO Z-scores adaptation for SHASH HBR
Z = (Yte - Yhat) / np.sqrt(S2)

print("Evaluating the model ...")
Expand Down
22 changes: 13 additions & 9 deletions pcntoolkit/normative_model/norm_hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,21 +488,25 @@ def get_mcmc_quantiles(self, X, batch_effects=None, z_scores=None):
return quantiles.mean(axis=-1)


def get_mcmc_zscores(self, X, y, batch_effects=None):
def get_mcmc_zscores(self, X, y, **kwargs):

"""
Computes zscores of data given an estimated model
Args:
X ([N*p]ndarray): covariates
y ([N*1]ndarray): response variables
batch_effects (ndarray): the batch effects corresponding to X
"""
# Set batch effects to zero if none are provided

print(self.configs['likelihood'])
if batch_effects is None:
batch_effects = batch_effects_test = np.zeros([X.shape[0], 1])


tsbefile = kwargs.get("tsbefile", None)
if tsbefile is not None:
batch_effects_test = fileio.load(tsbefile)
else: # Set batch effects to zero if none are provided
print("Could not find batch-effects file! Initializing all as zeros ...")
batch_effects_test = np.zeros([X.shape[0], 1])

# Determine the variables to predict
if self.configs["likelihood"] == "Normal":
var_names = ["mu_samples", "sigma_samples","sigma_plus_samples"]
Expand All @@ -525,7 +529,7 @@ def get_mcmc_zscores(self, X, y, batch_effects=None):
# Do a forward to get the posterior predictive in the idata
self.hbr.predict(
X=X,
batch_effects=batch_effects,
batch_effects=batch_effects_test,
batch_effects_maps=self.batch_effects_maps,
pred="single",
var_names=var_names+["y_like"],
Expand All @@ -536,7 +540,7 @@ def get_mcmc_zscores(self, X, y, batch_effects=None):
self.hbr.idata, "posterior_predictive", var_names=var_names
)

# Remove superfluous var_nammes
# Remove superfluous var_names
var_names.remove('sigma_samples')
if 'delta_samples' in var_names:
var_names.remove('delta_samples')
Expand All @@ -553,7 +557,7 @@ def get_mcmc_zscores(self, X, y, batch_effects=None):
*array_of_vars,
kwargs={"y": y, "likelihood": self.configs['likelihood']},
)
return z_scores.mean(axis=-1)
return z_scores.mean(axis=-1).values



Expand Down

0 comments on commit ec9c7cc

Please sign in to comment.