Skip to content

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
fingoldo committed Jul 9, 2024
1 parent 58b86b7 commit 79ba8e6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mlframe/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,11 @@ def create_robustness_subgroups(
return subgroups


def create_robustness_subgroups_indices(subgroups: dict, train_idx: np.ndarray, val_idx: np.ndarray, group_weights: dict = {}, cont_nbins: int = 3) -> dict:
def create_robustness_subgroups_indices(subgroups: dict, train_idx: np.ndarray, val_idx: np.ndarray, test_idx: np.ndarray, group_weights: dict = {}, cont_nbins: int = 3) -> dict:
res = {}
for arr in (train_idx, val_idx):
if len(val_idx)==len(test_idx):
logger.warning(f"Validation and test sets have the same size. Robustness subgroups estimation will be incorrect.")
for arr in (train_idx, test_idx,val_idx):
npoints = len(arr)
robustness_subgroups_indices = {}
for group_name, group_params in subgroups.items():
Expand Down

0 comments on commit 79ba8e6

Please sign in to comment.