From 67d787a1903339cabfda19f3b688b905c8595846 Mon Sep 17 00:00:00 2001 From: Anthony Chiu Date: Tue, 16 Jan 2024 16:02:07 +0800 Subject: [PATCH] LOFOImportance accepts groups params for GroupKFold cv (#57) * Add groups option * Add groups option * Update lofo_importance.py --- lofo/lofo_importance.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lofo/lofo_importance.py b/lofo/lofo_importance.py index 1e2b313..7ed749f 100644 --- a/lofo/lofo_importance.py +++ b/lofo/lofo_importance.py @@ -24,9 +24,12 @@ class LOFOImportance: Same as cv in sklearn API n_jobs: int, optional Number of jobs for parallel computation + cv_groups: array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into train/test set. + Only used in conjunction with a “Group” cv instance (e.g., GroupKFold). """ - def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=None): + def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=None, cv_groups=None): self.fit_params = fit_params if fit_params else dict() if model is None: @@ -38,6 +41,7 @@ def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=N self.dataset = dataset self.scoring = scoring self.cv = cv + self.cv_groups = cv_groups self.n_jobs = n_jobs if self.n_jobs is not None and self.n_jobs > 1: warning_str = ("Warning: If your model is multithreaded, please initialise the number" @@ -50,7 +54,7 @@ def _get_cv_score(self, feature_to_remove): with warnings.catch_warnings(): warnings.simplefilter("ignore") - cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params) + cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params, groups=self.cv_groups) return cv_results['test_score'] def _get_cv_score_parallel(self, feature, result_queue):