Skip to content

Commit f2ff331

Browse files
committed
better validation; code cleanup
1 parent 4d3b668 commit f2ff331

7 files changed

+40
-66
lines changed

mpisppy/cylinders/hub.py

-16
Original file line numberDiff line numberDiff line change
@@ -442,14 +442,6 @@ class PHHub(Hub):
442442
receive_fields = (*Hub.receive_fields,)
443443

444444
def setup_hub(self):
445-
""" Must be called after make_windows(), so that
446-
the hub knows the sizes of all the spokes windows
447-
"""
448-
if not self._windows_constructed:
449-
raise RuntimeError(
450-
"Cannot call setup_hub before memory windows are constructed"
451-
)
452-
453445
## Generate some warnings if nothing is giving bounds
454446
if not self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
455447
logger.warn(
@@ -573,14 +565,6 @@ class LShapedHub(Hub):
573565
receive_fields = (*Hub.receive_fields,)
574566

575567
def setup_hub(self):
576-
""" Must be called after make_windows(), so that
577-
the hub knows the sizes of all the spokes windows
578-
"""
579-
if not self._windows_constructed:
580-
raise RuntimeError(
581-
"Cannot call setup_hub before memory windows are constructed"
582-
)
583-
584568
## Generate some warnings if nothing is giving bounds
585569
if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
586570
logger.warn(

mpisppy/cylinders/spcommunicator.py

+34-41
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class FieldArray:
4242
"""
4343

4444
def __init__(self, length: int):
45-
self._length = length
4645
self._array = communicator_array(length)
4746
self._id = 0
4847
return
@@ -117,8 +116,6 @@ class SPCommunicator:
117116
receive_fields = ()
118117

119118
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=None):
120-
# flag for if the windows have been constructed
121-
self._windows_constructed = False
122119
self.fullcomm = fullcomm
123120
self.strata_comm = strata_comm
124121
self.cylinder_comm = cylinder_comm
@@ -152,9 +149,9 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
152149

153150
self.register_send_fields()
154151

155-
self._exchange_send_fields()
156-
# TODO: here we can have a dynamic exchange of the send fields
157-
# so we can do error checking (all-to-all in send fields)
152+
self._make_windows()
153+
self._create_field_rank_mappings()
154+
158155
self.register_receive_fields()
159156

160157
# TODO: check that we have something in receive_field_spcomms??
@@ -188,25 +185,37 @@ def _build_window_spec(self) -> dict[Field, int]:
188185
## End for
189186
return window_spec
190187

191-
def _exchange_send_fields(self) -> None:
192-
""" Do an all-to-all so we know what the other communicators are sending """
193-
send_buffers = tuple((k, buff._length) for k, buff in self.send_buffers.items())
194-
self.send_fields_lengths_by_rank = self.strata_comm.allgather(send_buffers)
195-
196-
self.send_fields_by_rank = {}
188+
def _create_field_rank_mappings(self) -> None:
189+
self.fields_to_ranks = {}
190+
self.ranks_to_fields = {}
197191

198-
self.available_receive_fields = {}
199-
for rank, fields_lengths in enumerate(self.send_fields_lengths_by_rank):
192+
for rank, buffer_layout in enumerate(self.window.strata_buffer_layouts):
200193
if rank == self.strata_rank:
201194
continue
202-
self.send_fields_by_rank[rank] = []
203-
for f, length in fields_lengths:
204-
if f not in self.available_receive_fields:
205-
self.available_receive_fields[f] = []
206-
self.available_receive_fields[f].append(rank)
207-
self.send_fields_by_rank[rank].append(f)
208-
209-
# print(f"{self.__class__.__name__}: {self.available_receive_fields=}")
195+
self.ranks_to_fields[rank] = []
196+
for field in buffer_layout:
197+
if field not in self.fields_to_ranks:
198+
self.fields_to_ranks[field] = []
199+
self.fields_to_ranks[field].append(rank)
200+
self.ranks_to_fields[rank].append(field)
201+
202+
# print(f"{self.__class__.__name__}: {self.fields_to_ranks=}, {self.ranks_to_fields=}")
203+
204+
def _validate_recv_field(self, field: Field, origin: int, length: int):
205+
remote_buffer_layout = self.window.strata_buffer_layouts[origin]
206+
if field not in remote_buffer_layout:
207+
raise RuntimeError(f"{self.__class__.__name__} on local {self.strata_rank=} "
208+
f"could not find {field=} on remote rank {origin} with "
209+
f"class {self.communicators[origin]['spcomm_class']}."
210+
)
211+
_, remote_length = remote_buffer_layout[field]
212+
if (length + 1) != remote_length:
213+
raise RuntimeError(f"{self.__class__.__name__} on local {self.strata_rank=} "
214+
f"{field=} has length {length} on local "
215+
f"{self.strata_rank=} and length {remote_length} "
216+
f"on remote rank {origin} with class "
217+
f"{self.communicators[origin]['spcomm_class']}."
218+
)
210219

211220
def register_recv_field(self, field: Field, origin: int, length: int = -1) -> RecvArray:
212221
# print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}")
@@ -217,13 +226,7 @@ def register_recv_field(self, field: Field, origin: int, length: int = -1) -> Re
217226
my_fa = self.receive_buffers[key]
218227
assert(length + 1 == np.size(my_fa.array()))
219228
else:
220-
available_fields_from_origin = self.send_fields_lengths_by_rank[origin]
221-
for _field, _length in available_fields_from_origin:
222-
if field == _field:
223-
assert length == _length
224-
break
225-
else: # couldn't find field!
226-
raise RuntimeError(f"Couldn't find {field=} from {origin=}")
229+
self._validate_recv_field(field, origin, length)
227230
my_fa = RecvArray(length)
228231
self.receive_buffers[key] = my_fa
229232
## End if
@@ -276,20 +279,10 @@ def hub_finalize(self):
276279
def allreduce_or(self, val):
277280
return self.opt.allreduce_or(val)
278281

279-
def free_windows(self):
280-
"""
281-
"""
282-
if self._windows_constructed:
283-
self.window.free()
284-
self._windows_constructed = False
285-
286-
def make_windows(self) -> None:
287-
if self._windows_constructed:
288-
return
282+
def _make_windows(self) -> None:
289283

290284
window_spec = self._build_window_spec()
291285
self.window = SPWindow(window_spec, self.strata_comm)
292-
self._windows_constructed = True
293286

294287
return
295288

@@ -305,6 +298,6 @@ def register_receive_fields(self) -> None:
305298
if strata_rank == self.strata_rank:
306299
continue
307300
cls = comm["spcomm_class"]
308-
if field in self.send_fields_by_rank[strata_rank]:
301+
if field in self.ranks_to_fields[strata_rank]:
309302
buff = self.register_recv_field(field, strata_rank)
310303
self.receive_field_spcomms[field].append((strata_rank, cls, buff))

mpisppy/cylinders/spoke.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def got_kill_signal(self):
101101
""" Spoke should call this method at least every iteration
102102
to see if the Hub terminated
103103
"""
104-
self.updatereceive_buffers()
104+
self.update_receive_buffers()
105105
return self._got_kill_signal()
106106

107107
@abc.abstractmethod
@@ -114,7 +114,7 @@ def main(self):
114114
"""
115115
pass
116116

117-
def updatereceive_buffers(self):
117+
def update_receive_buffers(self):
118118
for (key, recv_buf) in self.receive_buffers.items():
119119
field, rank = self._split_key(key)
120120
# The below code will need to be updated for spoke to spoke communication

mpisppy/cylinders/spwindow.py

-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def __init__(self, my_fields: dict, strata_comm: MPI.Comm, field_order=None):
125125

126126
return
127127

128-
129128
def free(self):
130129

131130
if self.window_constructed:

mpisppy/extensions/cross_scen_extension.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def setup_hub(self):
280280

281281
def register_receive_fields(self):
282282
spcomm = self.opt.spcomm
283-
cross_scenario_cut_ranks = spcomm.available_receive_fields[Field.CROSS_SCENARIO_CUT]
283+
cross_scenario_cut_ranks = spcomm.fields_to_ranks[Field.CROSS_SCENARIO_CUT]
284284
assert len(cross_scenario_cut_ranks) == 1
285285
index = cross_scenario_cut_ranks[0]
286286

mpisppy/extensions/reduced_costs_fixer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def post_iter0_after_sync(self):
9595

9696
def register_receive_fields(self):
9797
spcomm = self.opt.spcomm
98-
expected_reduced_cost_ranks = spcomm.available_receive_fields[Field.EXPECTED_REDUCED_COST]
98+
expected_reduced_cost_ranks = spcomm.fields_to_ranks[Field.EXPECTED_REDUCED_COST]
9999
assert len(expected_reduced_cost_ranks) == 1
100100
index = expected_reduced_cost_ranks[0]
101101

mpisppy/spin_the_wheel.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ def run(self, comm_world=None):
127127
spcomm = sp_class(opt, fullcomm, strata_comm, cylinder_comm,
128128
communicator_list, **sp_kwargs)
129129

130-
# Create the windows, run main(), destroy the windows
131-
spcomm.make_windows()
130+
# Run main()
132131
if strata_rank == 0:
133132
spcomm.setup_hub()
134133

@@ -151,9 +150,8 @@ def run(self, comm_world=None):
151150
spcomm.hub_finalize()
152151

153152
fullcomm.Barrier()
153+
global_toc("Finalize Complete")
154154

155-
spcomm.free_windows()
156-
global_toc("Windows freed")
157155

158156
self.spcomm = spcomm
159157
self.spcomm_dict = spcomm_dict

0 commit comments

Comments
 (0)