From b6aa23fdf808c3c71ccc274da4d9d61bd1991ec5 Mon Sep 17 00:00:00 2001 From: Remi Gau Date: Mon, 29 Jul 2024 10:58:32 +0200 Subject: [PATCH] refactor --- .github/workflows/run_tests_cli.yml | 14 +- miss_hit.cfg | 2 +- src/batches/stats/setBatchFactorialDesign.m | 48 ++---- .../stats/setBatchGroupLevelContrasts.m | 39 ++--- .../stats/setBatchSubjectLevelContrasts.m | 2 +- src/batches/stats/setBatchTwoSampleTTest.m | 3 +- src/bids_model/BidsModel.m | 157 +++++++++++++----- src/bids_model/checkContrast.m | 3 +- src/bids_model/getContrastsList.m | 6 +- .../getContrastsListForFactorialDesign.m | 15 +- src/bids_model/getDummyContrastsList.m | 6 +- src/stats/group_level/groupLevelGlmType.m | 42 ----- src/stats/subject_level/specifyContrasts.m | 2 +- .../subject_level/specifyDummyContrasts.m | 2 +- .../removeIntercept.m | 0 src/workflows/stats/bidsFFX.m | 6 +- src/workflows/stats/bidsRFX.m | 2 +- src/workflows/stats/bidsResults.m | 10 +- .../test_groupLevelGlmType.m | 25 +-- tests/tests_bids_model/test_validateGroupBy.m | 42 ++--- 20 files changed, 218 insertions(+), 208 deletions(-) delete mode 100644 src/stats/group_level/groupLevelGlmType.m rename src/stats/{subject_level => utils}/removeIntercept.m (100%) rename tests/{tests_stats/group_level => tests_bids_model}/test_groupLevelGlmType.m (63%) diff --git a/.github/workflows/run_tests_cli.yml b/.github/workflows/run_tests_cli.yml index a1a4a3988..eb3b15512 100644 --- a/.github/workflows/run_tests_cli.yml +++ b/.github/workflows/run_tests_cli.yml @@ -60,10 +60,10 @@ jobs: coverage run --source src -m pytest coverage xml - - name: Code coverage - uses: codecov/codecov-action@v4 - with: - file: coverage.xml - flags: cli - name: codecov-cli - fail_ci_if_error: false + # - name: Code coverage + # uses: codecov/codecov-action@v4 + # with: + # file: coverage.xml + # flags: cli + # name: codecov-cli + # fail_ci_if_error: false diff --git a/miss_hit.cfg b/miss_hit.cfg index e67dea617..cd8e82beb 100644 --- a/miss_hit.cfg +++ b/miss_hit.cfg @@ -37,6 +37,6 @@ tab_width: 2 # metrics limit for the code quality (https://florianschanda.github.io/miss_hit/metrics.html) metric "cnest": limit 5 -metric "file_length": limit 800 +metric "file_length": limit 1000 metric "cyc": limit 22 metric "parameters": limit 7 diff --git a/src/batches/stats/setBatchFactorialDesign.m b/src/batches/stats/setBatchFactorialDesign.m index a256f9836..280a628d5 100644 --- a/src/batches/stats/setBatchFactorialDesign.m +++ b/src/batches/stats/setBatchFactorialDesign.m @@ -64,17 +64,11 @@ % TODO all contrasts should have the same name contrasts = getContrastsListForFactorialDesign(opt, nodeName); - % TODO refactor - designMatrix = opt.model.bm.get_design_matrix('Name', nodeName); - designMatrix = cellfun(@(x) num2str(x), designMatrix, 'uniformoutput', false); - - groupColumnHdr = setxor(designMatrix, {'1'}); - groupColumnHdr = groupColumnHdr{1}; - % Sorting is important so that we know in which order % the groups are entered in the design matrix. % Otherwise it will be harder to properly design % the contrast vectors later. + groupColumnHdr = opt.model.bm.getGroupColumnHdrFromDesignMatrix(nodeName); availableGroups = getAvailableGroups(opt, groupColumnHdr); label = '1WayANOVA'; @@ -101,10 +95,9 @@ % Note that this will lead to different results % depending on the requested subejcts % - tsvFile = fullfile(opt.dir.raw, 'participants.tsv'); - tsv = bids.util.tsvread(tsvFile); - subjectsInGroup = strcmp(tsv.(groupColumnHdr), thisGroup); - subjectsLabel = regexprep(tsv.participant_id(subjectsInGroup), '^sub-', ''); + participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); + subjectsInGroup = strcmp(participants.(groupColumnHdr), thisGroup); + subjectsLabel = regexprep(participants.participant_id(subjectsInGroup), '^sub-', ''); subjectsLabel = intersect(subjectsLabel, opt.subjects); % collect all con images from all subjects @@ -183,13 +176,8 @@ contrasts = getContrastsListForFactorialDesign(opt, nodeName); - node = opt.model.bm.get_nodes('Name', nodeName); - groupBy = node.GroupBy; - - % TODO refactor - groupColumnHdr = setxor(groupBy, {'contrast'}); - groupColumnHdr = groupColumnHdr{1}; - + participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); + groupColumnHdr = opt.model.bm.getGroupColumnHdrFromGroupBy(nodeName, participants); availableGroups = getAvailableGroups(opt, groupColumnHdr); for iGroup = 1:numel(availableGroups) @@ -202,10 +190,8 @@ % Note that this will lead to different results depending on the requested % subejcts % - tsvFile = fullfile(opt.dir.raw, 'participants.tsv'); - tsv = bids.util.tsvread(tsvFile); - subjectsInGroup = strcmp(tsv.(groupColumnHdr), thisGroup); - subjectsLabel = regexprep(tsv.participant_id(subjectsInGroup), '^sub-', ''); + subjectsInGroup = strcmp(participants.(groupColumnHdr), thisGroup); + subjectsLabel = regexprep(participants.participant_id(subjectsInGroup), '^sub-', ''); subjectsLabel = intersect(subjectsLabel, opt.subjects); % collect all con images from all subjects @@ -301,6 +287,7 @@ % If 2 groups, then number of levels = 2 factorialDesign.des.fd.fact.name = thisGroup; factorialDesign.des.fd.fact.levels = numel(icell); + factorialDesign.des.fd.fact.dept = 0; % 1: Assumes that the variance is not the same across groups @@ -338,11 +325,11 @@ factorialDesign.masking.tm.tm_none = 1; factorialDesign.masking.im = 1; + factorialDesign.masking.em = {''}; - factorialDesign.globalc.g_omit = 1; - factorialDesign.globalm.gmsca.gmsca_no = 1; - factorialDesign.globalm.glonorm = 1; + factorialDesign = setBatchFactorialDesignImplicitMasking(factorialDesign); + factorialDesign = setBatchFactorialDesignGlobalCalcAndNorm(factorialDesign); matlabbatch{end + 1}.spm.stats.factorial_design = factorialDesign; @@ -354,11 +341,12 @@ % TODO refactor participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); - columns = fieldnames(participants); - status = opt.model.bm.validateGroupBy(nodeName, columns); + model = opt.model.bm; + + status = model.validateGroupBy(nodeName, participants); - [glmType, ~, groupBy] = groupLevelGlmType(opt, nodeName, participants); + [glmType, groupBy] = model.groupLevelGlmType(nodeName, participants); % only certain type of model supported for now if ismember(glmType, {'unknown', 'two_sample_t_test'}) @@ -370,8 +358,8 @@ return end - datasetLvlContrasts = opt.model.bm.get_contrasts('Name', nodeName); - datasetLvlDummyContrasts = opt.model.bm.get_dummy_contrasts('Name', nodeName); + datasetLvlContrasts = model.get_contrasts('Name', nodeName); + datasetLvlDummyContrasts = model.get_dummy_contrasts('Name', nodeName); if isempty(datasetLvlContrasts) && isempty(datasetLvlDummyContrasts) msg = sprintf('No contrast specified %s', commonMsg); diff --git a/src/batches/stats/setBatchGroupLevelContrasts.m b/src/batches/stats/setBatchGroupLevelContrasts.m index 2ae22aa13..8db9a2006 100644 --- a/src/batches/stats/setBatchGroupLevelContrasts.m +++ b/src/batches/stats/setBatchGroupLevelContrasts.m @@ -26,17 +26,21 @@ printBatchName('group level contrast estimation', opt); - % TODO refactor + model = opt.model.bm; + participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); - [groupGlmType, designMatrix, groupBy] = groupLevelGlmType(opt, nodeName, participants); + groupColumnHdr = model.getGroupColumnHdrFromGroupBy(nodeName, participants); + availableGroups = getAvailableGroups(opt, groupColumnHdr); + + [groupGlmType, groupBy] = model.groupLevelGlmType(nodeName, participants); switch groupGlmType case 'one_sample_t_test' contrastsList = getContrastsListForFactorialDesign(opt, nodeName); - if all(ismember(lower(groupBy), {'contrast'})) + if numel(groupBy) == 1 && ismember(lower(groupBy), {'contrast'}) for j = 1:numel(contrastsList) @@ -48,11 +52,7 @@ end - % TODO make more general than just with group - elseif all(ismember(lower(groupBy), {'contrast', 'group'})) - % TODO make more general than just with group - groupColumnHdr = groupBy{ismember(lower(groupBy), {'group'})}; - availableGroups = unique(participants.(groupColumnHdr)); + elseif numel(groupBy) == 2 && any(ismember(groupBy, fieldnames(participants))) for j = 1:numel(contrastsList) @@ -80,14 +80,10 @@ % through the Edge filter. % Then generate the between group contrasts. - designMatrix = removeIntercept(designMatrix); - - groups = unique(participants.(designMatrix{1})); - - edge = opt.model.bm.get_edge('Destination', nodeName); + edge = model.get_edge('Destination', nodeName); contrastsList = edge.Filter.contrast; - thisContrast = opt.model.bm.get_contrasts('Name', nodeName); + thisContrast = model.get_contrasts('Name', nodeName); for j = 1:numel(contrastsList) @@ -102,14 +98,14 @@ % Sort conditions and weights [ConditionList, I] = sort(thisContrast{iCon}.ConditionList); for iCdt = 1:numel(ConditionList) - ConditionList{iCdt} = strrep(ConditionList{iCdt}, [designMatrix{1}, '.'], ''); + ConditionList{iCdt} = strrep(ConditionList{iCdt}, [groupColumnHdr, '.'], ''); end Weights = thisContrast{iCon}.Weights(I); % Create contrast vectors by what was passed in the model - convec = zeros(size(groups)); - for iGroup = 1:numel(groups) - index = strcmp(groups{iGroup}, ConditionList); + convec = zeros(size(availableGroups)); + for iGroup = 1:numel(availableGroups) + index = strcmp(availableGroups{iGroup}, ConditionList); if any(index) convec(iGroup) = Weights(index); end @@ -128,17 +124,16 @@ designMatrix = removeIntercept(designMatrix); - % TODO make more general than just with group - if ismember(lower(designMatrix), {'group'}) + if any(ismember(designMatrix, fieldnames(participants))) % TODO will this ignore the contrasts define at other levels % and not passed through the filter ? - edge = opt.model.bm.get_edge('Destination', nodeName); + edge = model.get_edge('Destination', nodeName); contrastsList = edge.Filter.contrast; end for j = 1:numel(contrastsList) - thisContrast = opt.model.bm.get_contrasts('Name', nodeName); + thisContrast = model.get_contrasts('Name', nodeName); spmMatFile = fullfile(getRFXdir(opt, nodeName, contrastsList{j}), 'SPM.mat'); diff --git a/src/batches/stats/setBatchSubjectLevelContrasts.m b/src/batches/stats/setBatchSubjectLevelContrasts.m index 70d58ba25..8b4ebdd7c 100644 --- a/src/batches/stats/setBatchSubjectLevelContrasts.m +++ b/src/batches/stats/setBatchSubjectLevelContrasts.m @@ -4,7 +4,7 @@ % % USAGE:: % - % matlabbatch = setBatchSubjectLevelContrasts(matlabbatch, opt, subLabel, funcFWHM) + % matlabbatch = setBatchSubjectLevelContrasts(matlabbatch, opt, subLabel, nodeName) % % :param matlabbatch: % :type matlabbatch: structure diff --git a/src/batches/stats/setBatchTwoSampleTTest.m b/src/batches/stats/setBatchTwoSampleTTest.m index eb9071a6d..4fd4fb182 100644 --- a/src/batches/stats/setBatchTwoSampleTTest.m +++ b/src/batches/stats/setBatchTwoSampleTTest.m @@ -66,9 +66,7 @@ group1 = group1{2}; group2 = group2{2}; - % TODO refactor availableGroups = getAvailableGroups(opt, groupField); - participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); if any(~ismember({group1, group2}, availableGroups)) error(['Some requested group is not present: %s.', ... @@ -85,6 +83,7 @@ conImages{iSub} = findSubjectConImage(opt, subLabel, contrastsList); end + participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); % set up the batch for iCon = 1:numel(contrastsList) diff --git a/src/bids_model/BidsModel.m b/src/bids_model/BidsModel.m index 70fc649ff..0c49baeae 100644 --- a/src/bids_model/BidsModel.m +++ b/src/bids_model/BidsModel.m @@ -347,6 +347,82 @@ end + function [type, groupBy] = groupLevelGlmType(obj, nodeName, participants) + % + % Return type of GLM for a dataset level node. + % + % USAGE:: + % + % [type, groupBy] = bm.groupLevelGlmType(obj, nodeName, participants) + % + % :param nodeName: + % :type nodeName: char + % + % :param participants: content of participants.tsv + % :type participants: struct + % + if nargin < 3 + participants = struct(); + end + + node = obj.get_nodes('Name', nodeName); + + groupBy = node.GroupBy; + type = 'unknown'; + + if ~strcmpi(node.Level, 'Dataset') + return + end + + designMatrix = node.Model.X; + + if isnumeric(designMatrix) && designMatrix == 1 + type = 'one_sample_t_test'; + + elseif iscell(designMatrix) && numel(designMatrix) == 2 + + groupColumHdr = getGroupColumnHdrFromDesignMatrix(obj, nodeName); + + if isempty(groupColumHdr) || ~isfield(participants, groupColumHdr) + type = 'unknown'; + else + levels = participants.(groupColumHdr); + switch numel(unique(levels)) + case 1 + type = 'one_sample_t_test'; + case 2 + type = 'two_sample_t_test'; + otherwise + type = 'one_way_anova'; + end + end + + end + + end + + function groupColumnHdr = getGroupColumnHdrFromDesignMatrix(obj, nodeName) + node = obj.get_nodes('Name', nodeName); + designMatrix = node.Model.X; + if iscell(designMatrix) && numel(designMatrix) == 2 + designMatrix = removeIntercept(designMatrix); + groupColumnHdr = designMatrix{1}; + else + groupColumnHdr = ''; + end + end + + function groupColumnHdr = getGroupColumnHdrFromGroupBy(obj, nodeName, participants) + node = obj.get_nodes('Name', nodeName); + groupBy = node.GroupBy; + groupColumnHdr = intersect(groupBy, fieldnames(participants)); + if isempty(groupColumnHdr) + groupColumnHdr = ''; + else + groupColumnHdr = groupColumnHdr{1}; + end + end + function obj = addConfoundsToDesignMatrix(obj, varargin) % % Add some typical confounds to the design matrix of bids stat model. @@ -517,7 +593,11 @@ function validateRootNode(obj) end end - function status = validateGroupBy(obj, node, extraVar) + function status = validateGroupBy(obj, nodeName, participants) + % + % USAGE:: + % + % bm = bm.validateGroupBy(nodeName, participants); % % Only certain type of GroupBy supported for now for each level % @@ -526,76 +606,60 @@ function validateRootNode(obj) % (C) Copyright 2022 bidspm developers - status = true; + status = false; - if ischar(node) - node = obj.get_nodes('Name', node); - if isempty(node) - return - end + node = obj.get_nodes('Name', nodeName); + if isempty(node) + return end groupBy = sort(node.GroupBy); if nargin < 3 - extraVar = {}; + participants = struct(); end + extraVar = fieldnames(participants); + switch lower(node.Level) case 'run' % only certain type of GroupBy supported for now - if ~ismember('run', groupBy) || ... - ~all(ismember(groupBy, {'run', 'session', 'subject'})) - - status = false; - - supportedGroupBy = {'["run", "subject"]', ... - '["run", "session", "subject"]'}; - + supportedGroupBy = {'["run", "subject"]', '["run", "session", "subject"]'}; + if ismember('run', groupBy) && ... + all(ismember(groupBy, {'run', 'session', 'subject'})) + status = true; end case 'session' - if ~(numel(groupBy) == 3) || ... - ~all(ismember(groupBy, {'contrast', 'session', 'subject'})) - - status = false; - - supportedGroupBy = {'["contrast", "session", "subject"]'}; - + % only certain type of GroupBy supported for now + supportedGroupBy = {'["contrast", "session", "subject"]'}; + if numel(groupBy) == 3 && ... + all(ismember(groupBy, {'contrast', 'session', 'subject'})) + status = true; end case 'subject' - if ~(numel(groupBy) == 2) || ... - not(all([ismember('contrast', groupBy) ismember('subject', groupBy)])) - - status = false; - - supportedGroupBy = {'["contrast", "subject"]'}; - + supportedGroupBy = {'["contrast", "subject"]'}; + if numel(groupBy) == 2 && ... + all(ismember(groupBy, {'contrast', 'subject'})) + status = true; end case 'dataset' + % only certain type of GroupBy supported for now supportedGroupBy = {'["contrast"]', ... '["contrast", "x"] for "x" being a participant.tsv column name.'}; - % only certain type of GroupBy supported for now - status = false; - if numel(groupBy) == 1 && all(ismember(lower(groupBy), {'contrast'})) + if numel(groupBy) == 1 && ismember(lower(groupBy), {'contrast'}) + status = true; + elseif numel(groupBy) == 2 && any(ismember(lower(groupBy), {'contrast'})) && ... + iscellstr(extraVar) && numel(extraVar) > 0 && any(ismember(groupBy, extraVar)) status = true; - - elseif numel(groupBy) == 2 && iscellstr(extraVar) && numel(extraVar) > 0 - for i = 1:numel(extraVar) - if all(ismember(groupBy, {'contrast', extraVar{i}})) - status = true; - break - end - end - end end @@ -758,3 +822,12 @@ function bidsModelError(obj, id, msg) end end + +function designMatrix = removeIntercept(designMatrix) + % + % remove intercept because SPM includes it anyway + % + + isIntercept = cellfun(@(x) (numel(x) == 1) && (x == 1), designMatrix, 'UniformOutput', true); + designMatrix(isIntercept) = []; +end diff --git a/src/bids_model/checkContrast.m b/src/bids_model/checkContrast.m index 76b52c003..d40c514f7 100644 --- a/src/bids_model/checkContrast.m +++ b/src/bids_model/checkContrast.m @@ -15,8 +15,7 @@ model.validate_constrasts(node); - if ~ismember(lower(node.Level), {'run', 'session', 'subject'}) && ... - ~isTtest(node.Contrasts(iCon)) + if ismember(lower(node.Level), {'session'}) && ~isTtest(node.Contrasts(iCon)) notImplemented(mfilename(), 'Only t test implemented for Contrasts'); contrast = []; return diff --git a/src/bids_model/getContrastsList.m b/src/bids_model/getContrastsList.m index ee02a4274..323a800e5 100644 --- a/src/bids_model/getContrastsList.m +++ b/src/bids_model/getContrastsList.m @@ -1,4 +1,4 @@ -function contrastsList = getContrastsList(model, node, extraVar) +function contrastsList = getContrastsList(model, node, participants) % % Get list of names of Contrast from this Node % or gets its from the previous Nodes @@ -19,7 +19,7 @@ % (C) Copyright 2022 bidspm developers if nargin < 3 - extraVar = {}; + participants = struct(); end contrastsList = {}; @@ -48,7 +48,7 @@ % TODO relax those assumptions % assumptions - assert(model.validateGroupBy(node, extraVar)); + assert(model.validateGroupBy(node.Name participants)); assert(node.Model.X == 1); contrastsList = getContrastsFromParentNode(model, node); diff --git a/src/bids_model/getContrastsListForFactorialDesign.m b/src/bids_model/getContrastsListForFactorialDesign.m index acda9f41c..3307a1dc3 100644 --- a/src/bids_model/getContrastsListForFactorialDesign.m +++ b/src/bids_model/getContrastsListForFactorialDesign.m @@ -14,17 +14,16 @@ % (C) Copyright 2022 bidspm developers - % assuming we want to only average at the group level - - % TODO refactor + % assuming we want to only average / comparisons at the group level participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv')); - columns = fieldnames(participants); - [groupGlmType, ~, ~] = groupLevelGlmType(opt, nodeName, participants); + model = opt.model.bm; + + groupGlmType = model.groupLevelGlmType(nodeName, participants); if ismember(groupGlmType, {'one_sample_t_test', 'one_way_anova'}) - edge = opt.model.bm.get_edge('Destination', nodeName); + edge = model.get_edge('Destination', nodeName); if isfield(edge, 'Filter') && ... isfield(edge.Filter, 'contrast') && ... @@ -35,9 +34,9 @@ else % this assumes DummyContrasts exist - contrastsList = getDummyContrastsList(opt.model.bm, nodeName, columns); + contrastsList = getDummyContrastsList(opt.model.bm, nodeName, participants); - node = opt.model.bm.get_nodes('Name', nodeName); + node = model.get_nodes('Name', nodeName); % if no specific dummy contrasts mentioned also include all contrasts from previous levels % or if contrasts are mentioned we grab them diff --git a/src/bids_model/getDummyContrastsList.m b/src/bids_model/getDummyContrastsList.m index bb2b6c8c0..6e6cd1dec 100644 --- a/src/bids_model/getDummyContrastsList.m +++ b/src/bids_model/getDummyContrastsList.m @@ -1,4 +1,4 @@ -function dummyContrastsList = getDummyContrastsList(model, node, extraVar) +function dummyContrastsList = getDummyContrastsList(model, node, participants) % % Get list of names of DummyContrast from this Node or gets its from the % previous Nodes @@ -19,7 +19,7 @@ % (C) Copyright 2022 bidspm developers if nargin < 3 - extraVar = {}; + participants = struct(); end dummyContrastsList = {}; @@ -37,7 +37,7 @@ else - assert(model.validateGroupBy(node, extraVar)); + assert(model.validateGroupBy(node.Name, participants)); switch lower(node.Level) diff --git a/src/stats/group_level/groupLevelGlmType.m b/src/stats/group_level/groupLevelGlmType.m deleted file mode 100644 index ba048465e..000000000 --- a/src/stats/group_level/groupLevelGlmType.m +++ /dev/null @@ -1,42 +0,0 @@ -function [type, srcDesignMatrix, groupBy] = groupLevelGlmType(opt, nodeName, participants) - % - - % (C) Copyright 2022 bidspm developers - - if nargin < 3 - participants = struct(); - end - - % TODO refactor - columns = fieldnames(participants); - - node = opt.model.bm.get_nodes('Name', nodeName); - groupBy = node.GroupBy; - srcDesignMatrix = node.Model.X; - - type = 'unknown'; - if isnumeric(srcDesignMatrix) && srcDesignMatrix == 1 - type = 'one_sample_t_test'; - - elseif iscell(srcDesignMatrix) && numel(srcDesignMatrix) == 2 - - designMatrix = cellfun(@(x) num2str(x), srcDesignMatrix, 'uniformoutput', false); - - for i = 1:numel(columns) - if all(ismember(designMatrix, {'1', columns{i}})) - levels = participants.(columns{i}); - switch numel(unique(levels)) - case 1 - type = 'one_sample_t_test'; - case 2 - type = 'two_sample_t_test'; - otherwise - type = 'one_way_anova'; - end - break - end - end - - end - -end diff --git a/src/stats/subject_level/specifyContrasts.m b/src/stats/subject_level/specifyContrasts.m index 9992be005..b96924d6f 100644 --- a/src/stats/subject_level/specifyContrasts.m +++ b/src/stats/subject_level/specifyContrasts.m @@ -57,7 +57,7 @@ node = node{1}; end - if ismember(lower(node.Level), {'session', 'subject'}) && ~model.validateGroupBy(node) + if ismember(lower(node.Level), {'session', 'subject'}) && ~model.validateGroupBy(node.Name) continue end diff --git a/src/stats/subject_level/specifyDummyContrasts.m b/src/stats/subject_level/specifyDummyContrasts.m index bf2047d41..217c8bd93 100644 --- a/src/stats/subject_level/specifyDummyContrasts.m +++ b/src/stats/subject_level/specifyDummyContrasts.m @@ -26,7 +26,7 @@ return end - if ismember(node.Level, {'Subject', 'Session'}) && ~model.validateGroupBy(node) + if ismember(node.Level, {'Subject', 'Session'}) && ~model.validateGroupBy(node.Name) return end if ismember(node.Level, {'Dataset'}) diff --git a/src/stats/subject_level/removeIntercept.m b/src/stats/utils/removeIntercept.m similarity index 100% rename from src/stats/subject_level/removeIntercept.m rename to src/stats/utils/removeIntercept.m diff --git a/src/workflows/stats/bidsFFX.m b/src/workflows/stats/bidsFFX.m index c04d50be7..394aa0edf 100644 --- a/src/workflows/stats/bidsFFX.m +++ b/src/workflows/stats/bidsFFX.m @@ -117,11 +117,7 @@ opt.model.bm.validateConstrasts(); - if isempty(nodeName) - matlabbatch = setBatchSubjectLevelContrasts(matlabbatch, opt, subLabel); - else - matlabbatch = setBatchSubjectLevelContrasts(matlabbatch, opt, subLabel, nodeName); - end + matlabbatch = setBatchSubjectLevelContrasts(matlabbatch, opt, subLabel, nodeName); end diff --git a/src/workflows/stats/bidsRFX.m b/src/workflows/stats/bidsRFX.m index 035f960d2..b73bc94a0 100644 --- a/src/workflows/stats/bidsRFX.m +++ b/src/workflows/stats/bidsRFX.m @@ -99,7 +99,7 @@ nodeName = datasetNodes{i}.Name; - switch groupLevelGlmType(opt, nodeName, participants) + switch opt.model.bm.groupLevelGlmType(nodeName, participants) case {'one_sample_t_test', 'one_way_anova'} [matlabbatch, contrastsList, groups] = setBatchFactorialDesign(matlabbatch, ... diff --git a/src/workflows/stats/bidsResults.m b/src/workflows/stats/bidsResults.m index 286061c57..d9ff61123 100644 --- a/src/workflows/stats/bidsResults.m +++ b/src/workflows/stats/bidsResults.m @@ -443,7 +443,9 @@ matlabbatch = {}; - node = opt.model.bm.get_nodes('Name', opt.results(iRes).nodeName); + model = opt.model.bm; + + node = model.get_nodes('Name', opt.results(iRes).nodeName); opt = checkMontage(opt, iRes, node); @@ -461,7 +463,7 @@ logger('WARNING', msg, 'id', id, 'options', opt, 'filename', mfilename()); end - [glmType, ~, groupBy] = groupLevelGlmType(opt, result.nodeName, participants); + [glmType, groupBy] = model.groupLevelGlmType(result.nodeName, participants); switch glmType @@ -480,7 +482,7 @@ % TODO make more general than just with group groupColumnHdr = groupBy{ismember(lower(groupBy), {'group'})}; - availableGroups = unique(participants.(groupColumnHdr)); + availableGroups = getAvailableGroups(opt, groupColumnHdr); for iGroup = 1:numel(availableGroups) @@ -497,7 +499,7 @@ case 'two_sample_t_test' - thisContrast = opt.model.bm.get_contrasts('Name', result.nodeName); + thisContrast = model.get_contrasts('Name', result.nodeName); result.dir = getRFXdir(opt, result.nodeName, name); diff --git a/tests/tests_stats/group_level/test_groupLevelGlmType.m b/tests/tests_bids_model/test_groupLevelGlmType.m similarity index 63% rename from tests/tests_stats/group_level/test_groupLevelGlmType.m rename to tests/tests_bids_model/test_groupLevelGlmType.m index 5deb2c053..e26dfb2ba 100644 --- a/tests/tests_stats/group_level/test_groupLevelGlmType.m +++ b/tests/tests_bids_model/test_groupLevelGlmType.m @@ -11,36 +11,37 @@ function test_groupLevelGlmType_basic() opt = setOptions('vismotion', {'01' 'ctrl01'}, 'pipelineType', 'stats'); - opt.model.bm.Nodes{3}.GroupBy = {'contrast', 'group'}; - type = groupLevelGlmType(opt, 'dataset_level'); + model = opt.model.bm; + model.Nodes{3}.GroupBy = {'contrast', 'group'}; + type = model.groupLevelGlmType('dataset_level'); assertEqual(type, 'one_sample_t_test'); - opt.model.bm.Nodes{3}.Model.X = {1, 'group'}; - type = groupLevelGlmType(opt, 'dataset_level'); + model.Nodes{3}.Model.X = {1, 'group'}; + type = model.groupLevelGlmType('dataset_level'); assertEqual(type, 'unknown'); - opt.model.bm.Nodes{3}.Model.X = {1, 'diagnostic'}; + model.Nodes{3}.Model.X = {1, 'diagnostic'}; participants = struct('participant_id', {{'01', '02'}}, ... 'diagnostic', {{'ctrl', 'ctrl'}}); - type = groupLevelGlmType(opt, 'dataset_level', participants); + type = model.groupLevelGlmType('dataset_level', participants); assertEqual(type, 'one_sample_t_test'); - opt.model.bm.Nodes{3}.Model.X = {1, 'diagnostic'}; + model.Nodes{3}.Model.X = {1, 'diagnostic'}; participants = struct('participant_id', {{'01', '02'}}, ... 'diagnostic', {{'ctrl', 'patient'}}); - type = groupLevelGlmType(opt, 'dataset_level', participants); + type = model.groupLevelGlmType('dataset_level', participants); assertEqual(type, 'two_sample_t_test'); - opt.model.bm.Nodes{3}.Model.X = {1, 'group'}; + model.Nodes{3}.Model.X = {1, 'group'}; participants = struct('participant_id', {{'01', '02'}}, ... 'diagnostic', {{'ctrl', 'patient'}}); - type = groupLevelGlmType(opt, 'dataset_level', participants); + type = model.groupLevelGlmType('dataset_level', participants); assertEqual(type, 'unknown'); - opt.model.bm.Nodes{3}.Model.X = {1, 'group'}; + model.Nodes{3}.Model.X = {1, 'group'}; participants = struct('participant_id', {{'01', '02'}}, ... 'group', {{'ctrl', 'patient', 'relative'}}); - type = groupLevelGlmType(opt, 'dataset_level', participants); + type = model.groupLevelGlmType('dataset_level', participants); assertEqual(type, 'one_way_anova'); end diff --git a/tests/tests_bids_model/test_validateGroupBy.m b/tests/tests_bids_model/test_validateGroupBy.m index 40f1e43ce..103950f16 100644 --- a/tests/tests_bids_model/test_validateGroupBy.m +++ b/tests/tests_bids_model/test_validateGroupBy.m @@ -14,14 +14,13 @@ function test_validateGroupBy_run() bm = BidsModel('file', opt.model.file); bm.verbose = false; - node = bm.Nodes{1}; + nodeName = bm.Nodes{1}.Name; - node.GroupBy = {'subject'}; - assertWarning(@()bm.validateGroupBy(node), 'BidsModel:notImplemented'); + bm.Nodes{1}.GroupBy = {'subject'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); - node.GroupBy = {'run', 'dataset'}; - bm.Nodes{1} = node; - assertWarning(@()bm.validateGroupBy(node.Name), 'BidsModel:notImplemented'); + bm.Nodes{1}.GroupBy = {'run', 'dataset'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); end @@ -30,13 +29,13 @@ function test_validateGroupBy_subject() opt = setOptions('vismotion', {'01' 'ctrl01'}, 'pipelineType', 'stats'); bm = BidsModel('file', opt.model.file); bm.verbose = false; - node = bm.Nodes{2}; + nodeName = bm.Nodes{2}.Name; - node.GroupBy = {'subject'}; - assertWarning(@()bm.validateGroupBy(node), 'BidsModel:notImplemented'); + bm.Nodes{2}.GroupBy = {'subject'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); - node.GroupBy = {'session', 'subject'}; - assertWarning(@()bm.validateGroupBy(node), 'BidsModel:notImplemented'); + bm.Nodes{2}.GroupBy = {'session', 'subject'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); end function test_checkGroupBy_dataset() @@ -45,21 +44,21 @@ function test_checkGroupBy_dataset() bm = BidsModel('file', opt.model.file); bm.verbose = false; + nodeName = bm.Nodes{3}.Name; % should be fine - node = bm.Nodes{3}; - node.GroupBy = {'contrast'}; - status = bm.validateGroupBy(node); + bm.Nodes{3}.GroupBy = {'contrast'}; + status = bm.validateGroupBy(nodeName); assertEqual(status, true); - node.GroupBy = {'subject'}; - assertWarning(@()bm.validateGroupBy(node), 'BidsModel:notImplemented'); + bm.Nodes{3}.GroupBy = {'subject'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); - node.GroupBy = {'session', 'subject'}; - assertWarning(@()bm.validateGroupBy(node), 'BidsModel:notImplemented'); + bm.Nodes{3}.GroupBy = {'session', 'subject'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); - node.GroupBy = {'session', 'subject', 'foo'}; - assertWarning(@()bm.validateGroupBy(node), 'BidsModel:notImplemented'); + bm.Nodes{3}.GroupBy = {'session', 'subject', 'foo'}; + assertWarning(@()bm.validateGroupBy(nodeName), 'BidsModel:notImplemented'); end @@ -69,8 +68,9 @@ function test_checkGroupBy_dataset_group_from_participant() bm = BidsModel('file', opt.model.file); bm.verbose = false; + nodeName = bm.Nodes{3}.Name; bm.Nodes{3}.GroupBy = {'contrast', 'diagnostic'}; - bm.validateGroupBy(bm.Nodes{3}, {'diagnostic'}); + bm.validateGroupBy(nodeName, {'diagnostic'}); end