Skip to content

Commit

Permalink
Create repeat_indices helper and use it in _owen_sampling_shapley and…
Browse files Browse the repository at this point in the history
… _combinatorial_montecarlo_shapley
  • Loading branch information
AnesBenmerzoug committed Dec 10, 2023
1 parent efca423 commit 4ad3eeb
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/pydvl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .config import *
from .dataset import *
from .numeric import *
from .progress import *
from .score import *
from .status import *
from .types import *
Expand Down
30 changes: 30 additions & 0 deletions src/pydvl/utils/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from collections.abc import Iterator
from itertools import cycle, takewhile
from typing import Collection

from tqdm.auto import tqdm

from pydvl.value.result import ValuationResult
from pydvl.value.stopping import StoppingCriterion

__all__ = ["repeat_indices"]


def repeat_indices(
    indices: Collection[int], result: ValuationResult, done: StoppingCriterion, **kwargs
) -> Iterator[int]:
    """Cycle over a collection of indices until the stopping criterion is
    satisfied, displaying progress in a tqdm bar.

    The stopping criterion is evaluated once before each index is yielded.
    The progress bar is synchronised immediately after every evaluation,
    including the final one that ends the iteration, so the bar shows the
    completion reported for the most recent check instead of lagging one
    iteration behind and closing with a stale value.

    Args:
        indices: Collection of indices that will be cycled until done. If it
            is empty nothing is yielded and the criterion is never evaluated.
        result: Object containing the current results. It is passed to
            `done` on every check.
        done: Stopping criterion. It is called as `done(result)` and its
            `completion()` method is read for the progress display
            (presumably a fraction in [0, 1], since it is scaled by 100
            against a bar total of 100 -- confirm against StoppingCriterion).
        kwargs: Keyword arguments passed to tqdm.

    Yields:
        The next index from `indices`, restarting from the beginning once
        the collection is exhausted, for as long as `done(result)` is falsy.
    """
    with tqdm(total=100, unit="%", **kwargs) as pbar:

        def keep_going(_: int) -> bool:
            # Evaluate the criterion first and read completion() right
            # after it, so the displayed value belongs to this very check.
            finished = done(result)
            # tqdm.update() takes an increment, hence the difference with
            # the bar's current position pbar.n.
            pbar.update(100 * done.completion() - pbar.n)
            pbar.refresh()
            return not finished

        # takewhile pulls an index from the endless cycle, then asks
        # keep_going whether to hand it out; the first falsy answer stops
        # the iteration, after the bar has already been brought up to date.
        yield from takewhile(keep_going, cycle(indices))
11 changes: 5 additions & 6 deletions src/pydvl/value/shapley/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import operator
from concurrent.futures import FIRST_COMPLETED, Future, wait
from functools import reduce
from itertools import cycle, takewhile
from typing import Optional, Sequence, Union

import numpy as np
Expand All @@ -65,6 +64,7 @@
init_parallel_backend,
)
from pydvl.utils.numeric import random_powerset
from pydvl.utils.progress import repeat_indices
from pydvl.utils.types import Seed, ensure_seed_sequence
from pydvl.utils.utility import Utility
from pydvl.value.result import ValuationResult
Expand Down Expand Up @@ -281,11 +281,10 @@ def _combinatorial_montecarlo_shapley(
)

rng = np.random.default_rng(seed)
repeat_indices = takewhile(lambda _: not done(result), cycle(indices))
pbar = tqdm(disable=not progress, position=job_id, total=100, unit="%")
for idx in repeat_indices:
pbar.n = 100 * done.completion()
pbar.refresh()

for idx in repeat_indices(
indices, result=result, done=done, disable=not progress, position=job_id
):
# Randomly sample subsets of full dataset without idx
subset = np.setxor1d(u.data.indices, [idx], assume_unique=True)
s = next(random_powerset(subset, n_samples=1, seed=rng))
Expand Down
12 changes: 5 additions & 7 deletions src/pydvl/value/shapley/owen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import operator
from enum import Enum
from functools import reduce
from itertools import cycle, takewhile
from typing import Optional, Sequence

import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm

from pydvl.parallel import MapReduceJob, ParallelConfig
from pydvl.utils import Utility, random_powerset
from pydvl.utils.progress import repeat_indices
from pydvl.utils.types import Seed
from pydvl.value import ValuationResult
from pydvl.value.stopping import MinUpdates
Expand Down Expand Up @@ -76,11 +75,10 @@ def _owen_sampling_shapley(

rng = np.random.default_rng(seed)
done = MinUpdates(1)
repeat_indices = takewhile(lambda _: not done(result), cycle(indices))
pbar = tqdm(disable=not progress, position=job_id, total=100, unit="%")
for idx in repeat_indices:
pbar.n = 100 * done.completion()
pbar.refresh()

for idx in repeat_indices(
indices, result=result, done=done, disable=not progress, position=job_id
):
e = np.zeros(max_q)
subset = np.setxor1d(u.data.indices, [idx], assume_unique=True)
for j, q in enumerate(q_steps):
Expand Down

0 comments on commit 4ad3eeb

Please sign in to comment.