Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LOFOImportance accepts groups params for GroupKFold cv #57

Merged
merged 3 commits into from
Jan 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions lofo/lofo_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ class LOFOImportance:
Same as cv in sklearn API
n_jobs: int, optional
Number of jobs for parallel computation
cv_groups: array-like, with shape (n_samples,), optional
Group labels for the samples, used while splitting the dataset into train/test sets.
Only used in conjunction with a “Group” cv instance (e.g., GroupKFold).
"""

def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=None):
def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=None, cv_groups=None):

self.fit_params = fit_params if fit_params else dict()
if model is None:
Expand All @@ -38,6 +41,7 @@ def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=N
self.dataset = dataset
self.scoring = scoring
self.cv = cv
self.cv_groups = cv_groups
self.n_jobs = n_jobs
if self.n_jobs is not None and self.n_jobs > 1:
warning_str = ("Warning: If your model is multithreaded, please initialise the number"
Expand All @@ -50,7 +54,7 @@ def _get_cv_score(self, feature_to_remove):

with warnings.catch_warnings():
warnings.simplefilter("ignore")
cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params)
cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params, groups=self.cv_groups)
return cv_results['test_score']

def _get_cv_score_parallel(self, feature, result_queue):
Expand Down