Skip to content

Commit

Permalink
basic test
Browse files Browse the repository at this point in the history
  • Loading branch information
HeikoSchuett committed Aug 1, 2024
1 parent b94dd6c commit 52902a8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/rsatoolbox/rdm/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def rescale(rdms, method: str = 'evidence', threshold=1e-8):
)


def _mean(vectors: ndarray, weights: ndarray = None) -> ndarray:
def _mean(vectors: ndarray, weights: Optional[ndarray] = None) -> ndarray:
"""Weighted mean of RDM vectors, ignores nans
See :meth:`rsatoolbox.rdm.rdms.RDMs.mean`
Expand Down
19 changes: 19 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,25 @@ def test_two_rdms(self):
np.nanmean(np.abs(rdiff_reg_opt)), 0.001,
msg_tem.format('regression', 'optimization', i_method))

def test_normalize_flag(self):
from rsatoolbox.model import ModelWeighted
from rsatoolbox.model.fitter import fit_regress
from rsatoolbox.rdm import concat, compare
self.sample_data()
model_rdms = concat([self.rdms[0], self.rdms[1]])
model_weighted = ModelWeighted(
'm_weighted',
model_rdms)
for i_method in ['cosine', 'corr', 'cosine_cov', 'corr_cov']:
theta = fit_regress(
model_weighted, self.rdms, method=i_method)
theta_no_normalized = fit_regress(
model_weighted, self.rdms, method=i_method, normalize=False)
rdm = model_weighted.predict(theta)
rdm_no_normalized = model_weighted.predict(theta_no_normalized)
assert compare(rdm, rdm_no_normalized) > 0.999


def test_two_rdms_nn(self):
from rsatoolbox.model import ModelInterpolate, ModelWeighted
from rsatoolbox.model.fitter import fit_regress_nn, fit_optimize_positive
Expand Down

0 comments on commit 52902a8

Please sign in to comment.