diff --git a/py/rvspecfit/desi/desi_fit.py b/py/rvspecfit/desi/desi_fit.py index 392738a..71c138f 100644 --- a/py/rvspecfit/desi/desi_fit.py +++ b/py/rvspecfit/desi/desi_fit.py @@ -1356,6 +1356,12 @@ def __init__(self): def submit(self, f, *args, **kw): return FakeFuture(f(*args, **kw)) + def __enter__(self): + return self + + def __exit__(self): + pass + def proc_many(files, output_dir, @@ -1364,7 +1370,7 @@ def proc_many(files, figure_dir=None, figure_prefix=None, config_fname=None, - nthreads=1, + pool=None, fit_targetid=None, objtypes=None, minsn=-1e9, @@ -1423,10 +1429,6 @@ def proc_many(files, assert (config is not None) assert ('template_lib' in config) - if nthreads > 1: - parallel = True - else: - parallel = False if process_status_file is not None: update_process_status_file(process_status_file, None, @@ -1434,11 +1436,7 @@ def proc_many(files, None, None, start=True) - if parallel: - poolEx = concurrent.futures.ProcessPoolExecutor(nthreads) - else: - poolEx = FakeExecutor() - res = [] + for f in files: fname = f.split('/')[-1] if subdirs: @@ -1485,7 +1483,7 @@ def proc_many(files, doplot=doplot, minsn=minsn, expid_range=expid_range, - poolex=poolEx, + poolex=pool, fitarm=fitarm, cmdline=cmdline, zbest_select=zbest_select, @@ -1497,15 +1495,6 @@ def proc_many(files, throw_exceptions=throw_exceptions) proc_desi_wrapper(*args, **kwargs) - if parallel: - try: - poolEx.shutdown(wait=True) - except KeyboardInterrupt: - for r in res: - r.cancel() - poolEx.shutdown(wait=False) - raise - logging.info("Successfully finished processing") @@ -1661,6 +1650,10 @@ def main(args): help='Make plots', action='store_true', default=False) + parser.add_argument('--mpi', + help='MPI mode', + action='store_true', + default=False) parser.add_argument( '--no_ccf_continuum_normalize', @@ -1769,34 +1762,42 @@ def main(args): if args.overwrite is not None: logging.warning('overwrite keyword is meaningless now') - - proc_many( - files, - output_dir, - output_tab_prefix, - output_mod_prefix, - figure_dir=figure_dir, - figure_prefix=args.figure_prefix, - nthreads=nthreads, - config_fname=config_fname, - fit_targetid=fit_targetid, - objtypes=objtypes, - doplot=doplot, - subdirs=args.subdirs, - minsn=minsn, - process_status_file=args.process_status_file, - expid_range=(minexpid, maxexpid), - skipexisting=args.skipexisting, - fitarm=fitarm, - cmdline=cmdline, - zbest_select=zbest_select, - zbest_include=zbest_include, - ccf_continuum_normalize=ccf_continuum_normalize, - use_resolution_matrix=args.resolution_matrix, - ccf_init=ccf_init, - npoly=npoly, - throw_exceptions=args.throw_exceptions, - ) + if args.mpi: + from mpi4py.futures import MPICommExecutor + from mpi4py import MPI + else: + MPI, MPICommExecutor = None, None + + with (MPICommExecutor(MPI.COMM_WORLD, root=0) if args.mpi else + (concurrent.futures.ProcessPoolExecutor(nthreads) + if nthreads > 1 else FakeExecutor())) as poolEx: + proc_many( + files, + output_dir, + output_tab_prefix, + output_mod_prefix, + figure_dir=figure_dir, + figure_prefix=args.figure_prefix, + pool=poolEx, + config_fname=config_fname, + fit_targetid=fit_targetid, + objtypes=objtypes, + doplot=doplot, + subdirs=args.subdirs, + minsn=minsn, + process_status_file=args.process_status_file, + expid_range=(minexpid, maxexpid), + skipexisting=args.skipexisting, + fitarm=fitarm, + cmdline=cmdline, + zbest_select=zbest_select, + zbest_include=zbest_include, + ccf_continuum_normalize=ccf_continuum_normalize, + use_resolution_matrix=args.resolution_matrix, + ccf_init=ccf_init, + npoly=npoly, + throw_exceptions=args.throw_exceptions, + ) if __name__ == '__main__':