@@ -42,7 +42,6 @@ class FieldArray:
42
42
"""
43
43
44
44
def __init__ (self , length : int ):
45
- self ._length = length
46
45
self ._array = communicator_array (length )
47
46
self ._id = 0
48
47
return
@@ -117,8 +116,6 @@ class SPCommunicator:
117
116
receive_fields = ()
118
117
119
118
def __init__ (self , spbase_object , fullcomm , strata_comm , cylinder_comm , communicators , options = None ):
120
- # flag for if the windows have been constructed
121
- self ._windows_constructed = False
122
119
self .fullcomm = fullcomm
123
120
self .strata_comm = strata_comm
124
121
self .cylinder_comm = cylinder_comm
@@ -152,9 +149,9 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
152
149
153
150
self .register_send_fields ()
154
151
155
- self ._exchange_send_fields ()
156
- # TODO: here we can have a dynamic exchange of the send fields
157
- # so we can do error checking (all-to-all in send fields)
152
+ self ._make_windows ()
153
+ self . _create_field_rank_mappings ()
154
+
158
155
self .register_receive_fields ()
159
156
160
157
# TODO: check that we have something in receive_field_spcomms??
@@ -188,25 +185,37 @@ def _build_window_spec(self) -> dict[Field, int]:
188
185
## End for
189
186
return window_spec
190
187
191
- def _exchange_send_fields (self ) -> None :
192
- """ Do an all-to-all so we know what the other communicators are sending """
193
- send_buffers = tuple ((k , buff ._length ) for k , buff in self .send_buffers .items ())
194
- self .send_fields_lengths_by_rank = self .strata_comm .allgather (send_buffers )
195
-
196
- self .send_fields_by_rank = {}
188
+ def _create_field_rank_mappings (self ) -> None :
189
+ self .fields_to_ranks = {}
190
+ self .ranks_to_fields = {}
197
191
198
- self .available_receive_fields = {}
199
- for rank , fields_lengths in enumerate (self .send_fields_lengths_by_rank ):
192
+ for rank , buffer_layout in enumerate (self .window .strata_buffer_layouts ):
200
193
if rank == self .strata_rank :
201
194
continue
202
- self .send_fields_by_rank [rank ] = []
203
- for f , length in fields_lengths :
204
- if f not in self .available_receive_fields :
205
- self .available_receive_fields [f ] = []
206
- self .available_receive_fields [f ].append (rank )
207
- self .send_fields_by_rank [rank ].append (f )
208
-
209
- # print(f"{self.__class__.__name__}: {self.available_receive_fields=}")
195
+ self .ranks_to_fields [rank ] = []
196
+ for field in buffer_layout :
197
+ if field not in self .fields_to_ranks :
198
+ self .fields_to_ranks [field ] = []
199
+ self .fields_to_ranks [field ].append (rank )
200
+ self .ranks_to_fields [rank ].append (field )
201
+
202
+ # print(f"{self.__class__.__name__}: {self.fields_to_ranks=}, {self.ranks_to_fields=}")
203
+
204
+ def _validate_recv_field (self , field : Field , origin : int , length : int ):
205
+ remote_buffer_layout = self .window .strata_buffer_layouts [origin ]
206
+ if field not in remote_buffer_layout :
207
+ raise RuntimeError (f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
208
+ f"could not find { field = } on remote rank { origin } with "
209
+ f"class { self .communicators [origin ]['spcomm_class' ]} ."
210
+ )
211
+ _ , remote_length = remote_buffer_layout [field ]
212
+ if (length + 1 ) != remote_length :
213
+ raise RuntimeError (f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
214
+ f"{ field = } has length { length } on local "
215
+ f"{ self .strata_rank = } and length { remote_length } "
216
+ f"on remote rank { origin } with class "
217
+ f"{ self .communicators [origin ]['spcomm_class' ]} ."
218
+ )
210
219
211
220
def register_recv_field (self , field : Field , origin : int , length : int = - 1 ) -> RecvArray :
212
221
# print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}")
@@ -217,13 +226,7 @@ def register_recv_field(self, field: Field, origin: int, length: int = -1) -> Re
217
226
my_fa = self .receive_buffers [key ]
218
227
assert (length + 1 == np .size (my_fa .array ()))
219
228
else :
220
- available_fields_from_origin = self .send_fields_lengths_by_rank [origin ]
221
- for _field , _length in available_fields_from_origin :
222
- if field == _field :
223
- assert length == _length
224
- break
225
- else : # couldn't find field!
226
- raise RuntimeError (f"Couldn't find { field = } from { origin = } " )
229
+ self ._validate_recv_field (field , origin , length )
227
230
my_fa = RecvArray (length )
228
231
self .receive_buffers [key ] = my_fa
229
232
## End if
@@ -276,20 +279,10 @@ def hub_finalize(self):
276
279
def allreduce_or (self , val ):
277
280
return self .opt .allreduce_or (val )
278
281
279
- def free_windows (self ):
280
- """
281
- """
282
- if self ._windows_constructed :
283
- self .window .free ()
284
- self ._windows_constructed = False
285
-
286
- def make_windows (self ) -> None :
287
- if self ._windows_constructed :
288
- return
282
+ def _make_windows (self ) -> None :
289
283
290
284
window_spec = self ._build_window_spec ()
291
285
self .window = SPWindow (window_spec , self .strata_comm )
292
- self ._windows_constructed = True
293
286
294
287
return
295
288
@@ -305,6 +298,6 @@ def register_receive_fields(self) -> None:
305
298
if strata_rank == self .strata_rank :
306
299
continue
307
300
cls = comm ["spcomm_class" ]
308
- if field in self .send_fields_by_rank [strata_rank ]:
301
+ if field in self .ranks_to_fields [strata_rank ]:
309
302
buff = self .register_recv_field (field , strata_rank )
310
303
self .receive_field_spcomms [field ].append ((strata_rank , cls , buff ))
0 commit comments