From 5b0c33d137d4c6eb637aafbc101046c09fa3819f Mon Sep 17 00:00:00 2001 From: Ivan Raikov Date: Fri, 5 Apr 2024 12:30:14 -0700 Subject: [PATCH] small refactor of measure_distances to avoid use of bcast_cell_attributes --- .../simulator/measure_distances.py | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/miv_simulator/simulator/measure_distances.py b/src/miv_simulator/simulator/measure_distances.py index dafcdf9..f701673 100644 --- a/src/miv_simulator/simulator/measure_distances.py +++ b/src/miv_simulator/simulator/measure_distances.py @@ -17,7 +17,7 @@ from mpi4py import MPI from neuroh5.io import ( append_cell_attributes, - bcast_cell_attributes, + read_cell_attributes, read_population_ranges, ) @@ -95,29 +95,40 @@ def measure_distances( comm = MPI.COMM_WORLD rank = comm.rank - soma_coords = {} - if rank == 0: logger.info("Reading population coordinates...") if not populations: populations = read_population_ranges(filepath, comm)[0].keys() - for population in sorted(populations): - coords = bcast_cell_attributes( - filepath, population, 0, namespace=coordinate_namespace, comm=comm - ) + if rank == 0: + color = 1 + else: + color = 0 + ## comm0 includes only rank 0 + comm0 = comm.Split(color, 0) - soma_coords[population] = { - k: ( - v["U Coordinate"][0], - v["V Coordinate"][0], - v["L Coordinate"][0], + soma_coords = {} + if rank == 0: + for population in sorted(populations): + coords_iter = read_cell_attributes( + filepath, + population, + mask={"U Coordinate", "V Coordinate", "L Coordinate"}, + namespace=coordinate_namespace, + comm=comm0, ) - for (k, v) in coords - } - del coords - gc.collect() + + soma_coords[population] = { + k: ( + v["U Coordinate"][0], + v["V Coordinate"][0], + v["L Coordinate"][0], + ) + for (k, v) in coords_iter + } + comm.barrier() + soma_coords = comm.bcast(soma_coords, root=0) has_ip_dist = False origin_ranges = None @@ -135,7 +146,7 @@ def measure_distances( base64.b64decode(ip_dist_dset[()]) ) f.close() - has_ip_dist = MPI.COMM_WORLD.bcast(has_ip_dist, root=0) + has_ip_dist = comm.bcast(has_ip_dist, root=0) if not has_ip_dist: if rank == 0: