Added new RayPoolExecutor for limited concurrency with Ray.
adivekar-utexas committed Sep 16, 2024
1 parent 4bcd8bb commit c1442cf
Showing 2 changed files with 198 additions and 23 deletions.
160 changes: 141 additions & 19 deletions src/synthesizrr/base/util/concurrency.py
@@ -16,7 +16,7 @@
from ray.exceptions import GetTimeoutError
from ray.util.dask import RayDaskCallback
from pydantic import validate_arguments, conint, confloat
-from synthesizrr.base.util.language import ProgressBar, set_param_from_alias, type_str, get_default, first_item, if_else
+from synthesizrr.base.util.language import ProgressBar, set_param_from_alias, type_str, get_default, first_item, Parameters
from synthesizrr.base.constants.DataProcessingConstants import Parallelize, FailureAction, Status, COMPLETED_STATUSES

from functools import partial
@@ -26,6 +26,12 @@
import time, inspect
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, wait as wait_future

_RAY_ACCUMULATE_ITEM_WAIT: float = 10e-3 ## 10ms
_LOCAL_ACCUMULATE_ITEM_WAIT: float = 1e-3 ## 1ms

_RAY_ACCUMULATE_ITER_WAIT: float = 1000e-3 ## 1000ms
_LOCAL_ACCUMULATE_ITER_WAIT: float = 100e-3 ## 100ms


def _asyncio_start_event_loop(loop):
asyncio.set_event_loop(loop)
@@ -173,6 +179,25 @@ def run_concurrent(
raise e


class RestrictedConcurrencyThreadPoolExecutor(ThreadPoolExecutor):
    """
    A ThreadPoolExecutor which additionally caps the number of in-flight tasks: `submit` blocks
    on a semaphore until fewer than `max_active_threads` submitted tasks are pending or running.
    This applies backpressure to the submitter instead of queueing work without bound.
    """

    def __init__(self, max_active_threads: Optional[int] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if max_active_threads is None:
            max_active_threads: int = self._max_workers
        assert isinstance(max_active_threads, int)
        self._semaphore = Semaphore(max_active_threads)

    def submit(self, *args, **kwargs):
        self._semaphore.acquire()  ## Blocks while `max_active_threads` tasks are in flight
        future = super().submit(*args, **kwargs)
        future.add_done_callback(lambda _: self._semaphore.release())  ## Frees a slot on completion
        return future
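
Usage sketch (illustrative, not part of the commit). With `max_active_threads=4`, the fifth `submit` blocks until one of the first four tasks finishes:

    import time
    from concurrent.futures import wait
    executor = RestrictedConcurrencyThreadPoolExecutor(max_active_threads=4, max_workers=16)
    futures = [executor.submit(time.sleep, 1.0) for _ in range(8)]  ## 5th call blocks ~1s
    wait(futures)  ## All 8 finish in ~2s: two waves of 4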


_GLOBAL_PROCESS_POOL_EXECUTOR: ProcessPoolExecutor = ProcessPoolExecutor(
max_workers=max(1, min(32, mp.cpu_count() - 1))
)
Expand Down Expand Up @@ -255,10 +280,105 @@ def stop_executor(


@ray.remote(num_cpus=1)
-def __run_parallel_ray_executor(fn, *args, **kwargs):
+def _run_parallel_ray_executor(fn, *args, **kwargs):
return fn(*args, **kwargs)


def _ray_asyncio_start_event_loop(loop):
asyncio.set_event_loop(loop)
loop.run_forever()


class RayPoolExecutor(Parameters):
max_workers: Union[int, Literal[inf]]
iter_wait: float = _RAY_ACCUMULATE_ITER_WAIT
item_wait: float = _RAY_ACCUMULATE_ITEM_WAIT
    _asyncio_event_loop: Optional[Any] = None
    _asyncio_event_loop_thread: Optional[Any] = None
_submission_executor: Optional[ThreadPoolExecutor] = None
_running_tasks: Dict = {}
_latest_submit: Optional[int] = None

def _set_asyncio(self):
# Create a new loop and a thread running this loop
if self._asyncio_event_loop is None:
self._asyncio_event_loop = asyncio.new_event_loop()
# print(f'Started _asyncio_event_loop')
if self._asyncio_event_loop_thread is None:
self._asyncio_event_loop_thread = threading.Thread(
target=_ray_asyncio_start_event_loop,
args=(self._asyncio_event_loop,),
)
self._asyncio_event_loop_thread.start()
# print(f'Started _asyncio_event_loop_thread')

def submit(
self,
fn,
*args,
scheduling_strategy: str = "SPREAD",
num_cpus: int = 1,
num_gpus: int = 0,
max_retries: int = 0,
retry_exceptions: Union[List, bool] = True,
**kwargs,
):
# print(f'Running {fn_str(fn)} using {Parallelize.ray} with num_cpus={num_cpus}, num_gpus={num_gpus}')
def _submit_task():
return _run_parallel_ray_executor.options(
scheduling_strategy=scheduling_strategy,
num_cpus=num_cpus,
num_gpus=num_gpus,
max_retries=max_retries,
retry_exceptions=retry_exceptions,
).remote(fn, *args, **kwargs)

_task_uid = str(time.time_ns())

if self.max_workers == inf:
return _submit_task() ## Submit to Ray directly
self._set_asyncio()
## Create a coroutine (i.e. Future), but do not actually start executing it.
coroutine = self._ray_run_fn_async(
submit_task=_submit_task,
task_uid=_task_uid,
)

## Schedule the coroutine to execute on the event loop (which is running on thread _asyncio_event_loop).
fut = asyncio.run_coroutine_threadsafe(coroutine, self._asyncio_event_loop)
# while _task_uid not in self._running_tasks: ## Ensure task has started scheduling
# time.sleep(self.item_wait)
return fut

async def _ray_run_fn_async(
self,
submit_task: Callable,
task_uid: str,
):
# self._running_tasks[task_uid] = None
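        ## NB: time.sleep (not asyncio.sleep) below blocks the shared event-loop thread, so queued
        ## submissions poll for a free slot one at a time.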
while len(self._running_tasks) >= self.max_workers:
for _task_uid in sorted(self._running_tasks.keys()):
if is_done(self._running_tasks[_task_uid]):
self._running_tasks.pop(_task_uid, None)
# print(f'Popped {_task_uid}')
if len(self._running_tasks) < self.max_workers:
break
time.sleep(self.item_wait)
if len(self._running_tasks) < self.max_workers:
break
time.sleep(self.iter_wait)
fut = submit_task()
self._running_tasks[task_uid] = fut
# print(f'Started {task_uid}. Num running: {len(self._running_tasks)}')

# ## Cleanup any completed tasks:
# for k in list(self._running_tasks.keys()):
# if is_done(self._running_tasks[k]):
# self._running_tasks.pop(k, None)
# time.sleep(self.item_wait)
return fut
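
Usage sketch (illustrative, not part of the commit; `my_fn` is hypothetical and `ray.init()` is assumed). At most `max_workers` Ray tasks are in flight; excess submissions wait inside the background event loop:

    pool = RayPoolExecutor(max_workers=8)
    futs = [pool.submit(my_fn, i, num_cpus=1) for i in range(100)]  ## Each is a concurrent.futures.Future
    results = [get_result(f) for f in futs]  ## Resolves the Future, then the underlying ObjectRef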


def run_parallel_ray(
fn,
*args,
@@ -270,7 +390,7 @@
**kwargs,
):
# print(f'Running {fn_str(fn)} using {Parallelize.ray} with num_cpus={num_cpus}, num_gpus={num_gpus}')
-    return __run_parallel_ray_executor.options(
+    return _run_parallel_ray_executor.options(
scheduling_strategy=scheduling_strategy,
num_cpus=num_cpus,
num_gpus=num_gpus,
@@ -285,7 +405,7 @@ def dispatch(
parallelize: Parallelize,
forward_parallelize: bool = False,
delay: float = 0.0,
-    executor: Optional[Union[ThreadPoolExecutor, ProcessPoolExecutor]] = None,
+    executor: Optional[Union[ThreadPoolExecutor, ProcessPoolExecutor, RayPoolExecutor]] = None,
**kwargs
) -> Any:
parallelize: Parallelize = Parallelize.from_str(parallelize)
@@ -301,24 +421,26 @@
elif parallelize is Parallelize.processes:
return run_parallel(fn, *args, executor=executor, **kwargs)
elif parallelize is Parallelize.ray:
-        return run_parallel_ray(fn, *args, **kwargs)
+        return run_parallel_ray(fn, *args, executor=executor, **kwargs)
raise NotImplementedError(f'Unsupported parallelization: {parallelize}')


def dispatch_executor(
parallelize: Parallelize,
**kwargs
-) -> Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]]:
+) -> Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor, RayPoolExecutor]]:
parallelize: Parallelize = Parallelize.from_str(parallelize)
set_param_from_alias(kwargs, param='max_workers', alias=['num_workers'], default=None)
max_workers: Optional[int] = kwargs.pop('max_workers', None)
if max_workers is None:
-        ## Uses the default executor for threads/processes.
+        ## Uses the default executor for threads/processes/ray.
return None
if parallelize is Parallelize.threads:
return ThreadPoolExecutor(max_workers=max_workers)
elif parallelize is Parallelize.processes:
return ProcessPoolExecutor(max_workers=max_workers)
elif parallelize is Parallelize.ray:
return RayPoolExecutor(max_workers=max_workers)
else:
return None
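
Sketch of the new Ray branch (illustrative; `my_fn` is hypothetical, and this assumes `run_parallel_ray` accepts and uses the `executor` kwarg per its partly elided signature above):

    executor = dispatch_executor(parallelize=Parallelize.ray, max_workers=4)  ## -> RayPoolExecutor
    fut = dispatch(my_fn, 1, 2, parallelize=Parallelize.ray, executor=executor)
    result = get_result(fut)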

Expand All @@ -329,7 +451,7 @@ def get_result(
wait: float = 1.0, ## 1000 ms
) -> Optional[Any]:
if isinstance(x, Future):
-        return x.result()
+        return get_result(x.result(), wait=wait)
if isinstance(x, ray.ObjectRef):
while True:
try:
@@ -399,13 +521,6 @@ def is_failed(x, *, pending_returns_false: bool = False) -> Optional[bool]:
return True


-_RAY_ACCUMULATE_ITEM_WAIT: float = 100e-3 ## 100ms
-_LOCAL_ACCUMULATE_ITEM_WAIT: float = 10e-3 ## 10ms
-
-_RAY_ACCUMULATE_ITER_WAIT: float = 1000e-3 ## 1000ms
-_LOCAL_ACCUMULATE_ITER_WAIT: float = 100e-3 ## 100ms


def accumulate(
futures: Union[Tuple, List, Set, Dict, Any],
*,
@@ -678,8 +793,9 @@ def retry(
wait: confloat(ge=0.0) = 10.0,
jitter: confloat(gt=0.0) = 0.5,
silent: bool = True,
+        return_num_failures: bool = False,
**kwargs
-):
+) -> Union[Any, Tuple[Any, int]]:
"""
Retries a function call a certain number of times, waiting between calls (with a jitter in the wait period).
:param fn: the function to call.
@@ -694,15 +810,21 @@
"""
wait: float = float(wait)
latest_exception = None
+    num_failures: int = 0
for retry_num in range(retries + 1):
try:
-            return fn(*args, **kwargs)
+            out = fn(*args, **kwargs)
+            if return_num_failures:
+                return out, num_failures
+            else:
+                return out
except Exception as e:
+            num_failures += 1
latest_exception = traceback.format_exc()
if not silent:
-                logging.debug(f'Function call failed with the following exception:\n{latest_exception}')
+                print(f'Function call failed with the following exception:\n{latest_exception}')
if retry_num < (retries - 1):
-                logging.debug(f'Retrying {retries - (retry_num + 1)} more times...\n')
+                print(f'Retrying {retries - (retry_num + 1)} more times...\n')
time.sleep(np.random.uniform(wait - wait * jitter, wait + wait * jitter))
raise RuntimeError(f'Function call failed {retries} times.\nLatest exception:\n{latest_exception}\n')
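
Sketch of the new `return_num_failures` flag (illustrative; `flaky_call` is hypothetical):

    out, num_failures = retry(flaky_call, retries=5, wait=2.0, return_num_failures=True)
    print(f'Succeeded after {num_failures} failed attempt(s)')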

61 changes: 57 additions & 4 deletions src/synthesizrr/base/util/language.py
@@ -1111,6 +1111,38 @@ def as_list(l) -> List:
return [l]


def list_pop_inplace(l: List, *, pop_condition: Callable) -> List:
    assert isinstance(l, list)  ## Needs to be a mutable list
    ## Iterate backwards so that popping does not shift the indexes yet to be visited:
    for i in range(len(l) - 1, -1, -1):
        if pop_condition(l[i]):
            l.pop(i)  ## Remove the item inplace
    return l
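
Sketch (illustrative): items are removed from the same list object, not a copy:

    tasks = [1, 2, 3, 4, 5]
    list_pop_inplace(tasks, pop_condition=lambda x: x % 2 == 0)
    assert tasks == [1, 3, 5]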


def set_union(*args) -> Set:
_union: Set = set()
for s in args:
if isinstance(s, (pd.Series, np.ndarray)):
s: List = s.tolist()
s: Set = set(s)
_union: Set = _union.union(s)
return _union


def set_intersection(*args) -> Set:
_intersection: Optional[Set] = None
for s in args:
if isinstance(s, (pd.Series, np.ndarray)):
s: List = s.tolist()
s: Set = set(s)
if _intersection is None:
_intersection: Set = s
else:
_intersection: Set = _intersection.intersection(s)
return _intersection
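
Sketch (illustrative): both helpers coerce lists, tuples, Series and arrays to sets first:

    assert set_union([1, 2], (2, 3), pd.Series([3, 4])) == {1, 2, 3, 4}
    assert set_intersection([1, 2, 3], (2, 3), {3, 2}) == {2, 3}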


def filter_string_list(l: List[str], pattern: str, ignorecase: bool = False) -> List[str]:
"""
Filter a list of strings based on an exact match to a regex pattern. Leaves non-string items untouched.
@@ -1201,7 +1233,7 @@ def as_tuple(l) -> Tuple:


## ======================== Set utils ======================== ##
-def is_set_like(l: Union[Set, frozenset]) -> bool:
+def is_set_like(l: Any) -> bool:
return isinstance(l, (set, frozenset, KeysView))


@@ -1945,7 +1977,7 @@ def mean(vals):


def random_sample(
-    data: Union[List, Tuple, np.ndarray],
+    data: Union[List, Tuple, Set, np.ndarray],
n: SampleSizeType,
*,
replacement: bool = False,
@@ -1961,6 +1993,8 @@
"""
np_random = np.random.RandomState(seed)
py_random = random.Random(seed)
if is_set_like(data):
data: List = list(data)
if not is_list_like(data):
raise ValueError(
            f'Input `data` must be {list}, {tuple}, {set} or {np.ndarray}; '
@@ -2973,7 +3007,7 @@ def pd_partial_column_order(df: pd.DataFrame, columns: List) -> pd.DataFrame:

class ProgressBar(MutableParameters):
pbar: Optional[TqdmProgressBar] = None
-    style: Literal['auto', 'notebook', 'std'] = 'auto'
+    style: Literal['auto', 'notebook', 'std', 'ray'] = 'auto'
unit: str = 'row'
color: str = '#0288d1' ## Bluish
ncols: int = 100
@@ -2998,7 +3032,7 @@ def _set_params(cls, params: Dict) -> Dict:
@classmethod
def _create_pbar(
cls,
-        style: Literal['auto', 'notebook', 'std'],
+        style: Literal['auto', 'notebook', 'std', 'ray'],
**kwargs,
) -> TqdmProgressBar:
if style == 'auto':
@@ -3009,6 +3043,15 @@ def _create_pbar(
with optional_dependency('ipywidgets'):
kwargs['ncols']: Optional[int] = None
return NotebookTqdmProgressBar(**kwargs)
elif style == 'ray':
from ray.experimental import tqdm_ray
kwargs = filter_keys(
kwargs,
keys=set(get_fn_spec(tqdm_ray.tqdm).args + get_fn_spec(tqdm_ray.tqdm).kwargs),
how='include',
)
return tqdm_ray.tqdm(**kwargs)
else:
return StdTqdmProgressBar(**kwargs)
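
Sketch of the new 'ray' style (illustrative; assumes extra kwargs such as `total` are forwarded to the underlying bar). tqdm_ray routes progress updates through the Ray driver, so bars updated from remote workers do not garble the console:

    pbar = ProgressBar(style='ray', total=10_000, unit='row')  ## `total` forwarding is an assumption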

@@ -3363,3 +3406,13 @@ def plotsum(
else:
raise not_impl('how', how)
return plots


def to_pct(counts: pd.Series): ## Converts value counts to percentages
_sum = counts.sum()
return pd.DataFrame({
'value': counts.index.tolist(),
'count': counts.tolist(),
'pct': counts.apply(lambda x: 100 * x / _sum).tolist(),
'count_str': counts.apply(lambda x: f'{x} of {_sum}').tolist(),
})
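
Sketch (illustrative) of `to_pct` on a `value_counts()` Series:

    counts = pd.Series(['a', 'a', 'a', 'b']).value_counts()
    print(to_pct(counts))
    ##   value  count   pct count_str
    ## 0     a      3  75.0    3 of 4
    ## 1     b      1  25.0    1 of 4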
