Skip to content

Commit

Permalink
add mpi mode
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Nov 19, 2024
1 parent d05af18 commit 419db9a
Showing 1 changed file with 49 additions and 48 deletions.
97 changes: 49 additions & 48 deletions py/rvspecfit/desi/desi_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,12 @@ def __init__(self):
def submit(self, f, *args, **kw):
return FakeFuture(f(*args, **kw))

def __enter__(self):
return self

def __exit__(self):
pass


def proc_many(files,
output_dir,
Expand All @@ -1364,7 +1370,7 @@ def proc_many(files,
figure_dir=None,
figure_prefix=None,
config_fname=None,
nthreads=1,
pool=None,
fit_targetid=None,
objtypes=None,
minsn=-1e9,
Expand Down Expand Up @@ -1423,22 +1429,14 @@ def proc_many(files,
assert (config is not None)
assert ('template_lib' in config)

if nthreads > 1:
parallel = True
else:
parallel = False
if process_status_file is not None:
update_process_status_file(process_status_file,
None,
None,
None,
None,
start=True)
if parallel:
poolEx = concurrent.futures.ProcessPoolExecutor(nthreads)
else:
poolEx = FakeExecutor()
res = []

for f in files:
fname = f.split('/')[-1]
if subdirs:
Expand Down Expand Up @@ -1485,7 +1483,7 @@ def proc_many(files,
doplot=doplot,
minsn=minsn,
expid_range=expid_range,
poolex=poolEx,
poolex=pool,
fitarm=fitarm,
cmdline=cmdline,
zbest_select=zbest_select,
Expand All @@ -1497,15 +1495,6 @@ def proc_many(files,
throw_exceptions=throw_exceptions)
proc_desi_wrapper(*args, **kwargs)

if parallel:
try:
poolEx.shutdown(wait=True)
except KeyboardInterrupt:
for r in res:
r.cancel()
poolEx.shutdown(wait=False)
raise

logging.info("Successfully finished processing")


Expand Down Expand Up @@ -1661,6 +1650,10 @@ def main(args):
help='Make plots',
action='store_true',
default=False)
parser.add_argument('--mpi',
help='MPI mode',
action='store_true',
default=False)

parser.add_argument(
'--no_ccf_continuum_normalize',
Expand Down Expand Up @@ -1769,34 +1762,42 @@ def main(args):

if args.overwrite is not None:
logging.warning('overwrite keyword is meaningless now')

proc_many(
files,
output_dir,
output_tab_prefix,
output_mod_prefix,
figure_dir=figure_dir,
figure_prefix=args.figure_prefix,
nthreads=nthreads,
config_fname=config_fname,
fit_targetid=fit_targetid,
objtypes=objtypes,
doplot=doplot,
subdirs=args.subdirs,
minsn=minsn,
process_status_file=args.process_status_file,
expid_range=(minexpid, maxexpid),
skipexisting=args.skipexisting,
fitarm=fitarm,
cmdline=cmdline,
zbest_select=zbest_select,
zbest_include=zbest_include,
ccf_continuum_normalize=ccf_continuum_normalize,
use_resolution_matrix=args.resolution_matrix,
ccf_init=ccf_init,
npoly=npoly,
throw_exceptions=args.throw_exceptions,
)
if args.mpi:
from mpi4py.futures import MPICommExecutor
from mpi4py import MPI
else:
MPI, MPICommExecutor = None, None

with (MPICommExecutor(MPI.COMM_WORLD, root=0) if args.mpi else
(concurrent.futures.ProcessPoolExecutor(nthreads)
if nthreads > 1 else FakeExecutor())) as poolEx:
proc_many(
files,
output_dir,
output_tab_prefix,
output_mod_prefix,
figure_dir=figure_dir,
figure_prefix=args.figure_prefix,
pool=poolEx,
config_fname=config_fname,
fit_targetid=fit_targetid,
objtypes=objtypes,
doplot=doplot,
subdirs=args.subdirs,
minsn=minsn,
process_status_file=args.process_status_file,
expid_range=(minexpid, maxexpid),
skipexisting=args.skipexisting,
fitarm=fitarm,
cmdline=cmdline,
zbest_select=zbest_select,
zbest_include=zbest_include,
ccf_continuum_normalize=ccf_continuum_normalize,
use_resolution_matrix=args.resolution_matrix,
ccf_init=ccf_init,
npoly=npoly,
throw_exceptions=args.throw_exceptions,
)


if __name__ == '__main__':
Expand Down

0 comments on commit 419db9a

Please sign in to comment.