diff --git a/py/rvspecfit/desi/desi_fit.py b/py/rvspecfit/desi/desi_fit.py index 3dbd5d8..392738a 100644 --- a/py/rvspecfit/desi/desi_fit.py +++ b/py/rvspecfit/desi/desi_fit.py @@ -1356,12 +1356,6 @@ def __init__(self): def submit(self, f, *args, **kw): return FakeFuture(f(*args, **kw)) - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - def proc_many(files, output_dir, @@ -1370,7 +1364,7 @@ def proc_many(files, figure_dir=None, figure_prefix=None, config_fname=None, - pool=None, + nthreads=1, fit_targetid=None, objtypes=None, minsn=-1e9, @@ -1429,6 +1423,10 @@ def proc_many(files, assert (config is not None) assert ('template_lib' in config) + if nthreads > 1: + parallel = True + else: + parallel = False if process_status_file is not None: update_process_status_file(process_status_file, None, @@ -1436,7 +1434,11 @@ def proc_many(files, None, None, start=True) - + if parallel: + poolEx = concurrent.futures.ProcessPoolExecutor(nthreads) + else: + poolEx = FakeExecutor() + res = [] for f in files: fname = f.split('/')[-1] if subdirs: @@ -1483,7 +1485,7 @@ def proc_many(files, doplot=doplot, minsn=minsn, expid_range=expid_range, - poolex=pool, + poolex=poolEx, fitarm=fitarm, cmdline=cmdline, zbest_select=zbest_select, @@ -1495,6 +1497,15 @@ def proc_many(files, throw_exceptions=throw_exceptions) proc_desi_wrapper(*args, **kwargs) + if parallel: + try: + poolEx.shutdown(wait=True) + except KeyboardInterrupt: + for r in res: + r.cancel() + poolEx.shutdown(wait=False) + raise + logging.info("Successfully finished processing") @@ -1650,10 +1661,6 @@ def main(args): help='Make plots', action='store_true', default=False) - parser.add_argument('--mpi', - help='MPI mode', - action='store_true', - default=False) parser.add_argument( '--no_ccf_continuum_normalize', @@ -1762,42 +1769,34 @@ def main(args): if args.overwrite is not None: logging.warning('overwrite keyword is meaningless now') - if args.mpi: - from mpi4py.futures import MPICommExecutor - from mpi4py import MPI - else: - MPI, MPICommExecutor = None, None - - with (MPICommExecutor(MPI.COMM_WORLD, root=0) if args.mpi else - (concurrent.futures.ProcessPoolExecutor(nthreads) - if nthreads > 1 else FakeExecutor())) as poolEx: - proc_many( - files, - output_dir, - output_tab_prefix, - output_mod_prefix, - figure_dir=figure_dir, - figure_prefix=args.figure_prefix, - pool=poolEx, - config_fname=config_fname, - fit_targetid=fit_targetid, - objtypes=objtypes, - doplot=doplot, - subdirs=args.subdirs, - minsn=minsn, - process_status_file=args.process_status_file, - expid_range=(minexpid, maxexpid), - skipexisting=args.skipexisting, - fitarm=fitarm, - cmdline=cmdline, - zbest_select=zbest_select, - zbest_include=zbest_include, - ccf_continuum_normalize=ccf_continuum_normalize, - use_resolution_matrix=args.resolution_matrix, - ccf_init=ccf_init, - npoly=npoly, - throw_exceptions=args.throw_exceptions, - ) + + proc_many( + files, + output_dir, + output_tab_prefix, + output_mod_prefix, + figure_dir=figure_dir, + figure_prefix=args.figure_prefix, + nthreads=nthreads, + config_fname=config_fname, + fit_targetid=fit_targetid, + objtypes=objtypes, + doplot=doplot, + subdirs=args.subdirs, + minsn=minsn, + process_status_file=args.process_status_file, + expid_range=(minexpid, maxexpid), + skipexisting=args.skipexisting, + fitarm=fitarm, + cmdline=cmdline, + zbest_select=zbest_select, + zbest_include=zbest_include, + ccf_continuum_normalize=ccf_continuum_normalize, + use_resolution_matrix=args.resolution_matrix, + ccf_init=ccf_init, + npoly=npoly, + throw_exceptions=args.throw_exceptions, + ) if __name__ == '__main__':