diff --git a/src/pydvl/utils/__init__.py b/src/pydvl/utils/__init__.py index 7beac241f..245c596dd 100644 --- a/src/pydvl/utils/__init__.py +++ b/src/pydvl/utils/__init__.py @@ -3,6 +3,7 @@ from .config import * from .dataset import * from .numeric import * +from .progress import * from .score import * from .status import * from .types import * diff --git a/src/pydvl/utils/progress.py b/src/pydvl/utils/progress.py new file mode 100644 index 000000000..bbcfca682 --- /dev/null +++ b/src/pydvl/utils/progress.py @@ -0,0 +1,30 @@ +from collections.abc import Iterator +from itertools import cycle, takewhile +from typing import Collection + +from tqdm.auto import tqdm + +from pydvl.value.result import ValuationResult +from pydvl.value.stopping import StoppingCriterion + +__all__ = ["repeat_indices"] + + +def repeat_indices( + indices: Collection[int], result: ValuationResult, done: StoppingCriterion, **kwargs +) -> Iterator[int]: + """Helper function to cycle indefinitely over a collection of indices + until the stopping criterion is satisfied while displaying progress. + + Args: + indices: Collection of indices that will be cycled until done. + result: Object containing the current results. + done: Stopping criterion. + kwargs: Keyword arguments passed to tqdm. + """ + with tqdm(total=100, unit="%", **kwargs) as pbar: + it = takewhile(lambda _: not done(result), cycle(indices)) + for i in it: + yield i + pbar.update(100 * done.completion() - pbar.n) + pbar.refresh() diff --git a/src/pydvl/value/shapley/montecarlo.py b/src/pydvl/value/shapley/montecarlo.py index e6f1dbf2a..aabc2d813 100644 --- a/src/pydvl/value/shapley/montecarlo.py +++ b/src/pydvl/value/shapley/montecarlo.py @@ -47,7 +47,6 @@ import operator from concurrent.futures import FIRST_COMPLETED, Future, wait from functools import reduce -from itertools import cycle, takewhile from typing import Optional, Sequence, Union import numpy as np @@ -65,6 +64,7 @@ init_parallel_backend, ) from pydvl.utils.numeric import random_powerset +from pydvl.utils.progress import repeat_indices from pydvl.utils.types import Seed, ensure_seed_sequence from pydvl.utils.utility import Utility from pydvl.value.result import ValuationResult @@ -281,11 +281,10 @@ def _combinatorial_montecarlo_shapley( ) rng = np.random.default_rng(seed) - repeat_indices = takewhile(lambda _: not done(result), cycle(indices)) - pbar = tqdm(disable=not progress, position=job_id, total=100, unit="%") - for idx in repeat_indices: - pbar.n = 100 * done.completion() - pbar.refresh() + + for idx in repeat_indices( + indices, result=result, done=done, disable=not progress, position=job_id + ): # Randomly sample subsets of full dataset without idx subset = np.setxor1d(u.data.indices, [idx], assume_unique=True) s = next(random_powerset(subset, n_samples=1, seed=rng)) diff --git a/src/pydvl/value/shapley/owen.py b/src/pydvl/value/shapley/owen.py index 07b9e972b..2d7cde6ba 100644 --- a/src/pydvl/value/shapley/owen.py +++ b/src/pydvl/value/shapley/owen.py @@ -9,15 +9,14 @@ import operator from enum import Enum from functools import reduce -from itertools import cycle, takewhile from typing import Optional, Sequence import numpy as np from numpy.typing import NDArray -from tqdm import tqdm from pydvl.parallel import MapReduceJob, ParallelConfig from pydvl.utils import Utility, random_powerset +from pydvl.utils.progress import repeat_indices from pydvl.utils.types import Seed from pydvl.value import ValuationResult from pydvl.value.stopping import MinUpdates @@ -76,11 +75,10 @@ def _owen_sampling_shapley( rng = np.random.default_rng(seed) done = MinUpdates(1) - repeat_indices = takewhile(lambda _: not done(result), cycle(indices)) - pbar = tqdm(disable=not progress, position=job_id, total=100, unit="%") - for idx in repeat_indices: - pbar.n = 100 * done.completion() - pbar.refresh() + + for idx in repeat_indices( + indices, result=result, done=done, disable=not progress, position=job_id + ): e = np.zeros(max_q) subset = np.setxor1d(u.data.indices, [idx], assume_unique=True) for j, q in enumerate(q_steps):