diff --git a/src/sid/contacts.py b/src/sid/contacts.py index 299eea1c..aa08c6d1 100644 --- a/src/sid/contacts.py +++ b/src/sid/contacts.py @@ -621,28 +621,30 @@ def post_process_contacts(contacts, states, contact_models): contacts.loc[has_no_contacts, recurrent_models] = False if random_models: - random_contacts = contacts[random_models] - - integers = random_contacts.select_dtypes(include=np.integer).columns.tolist() + integers = ( + contacts[random_models].select_dtypes(include=np.integer).columns.tolist() + ) if integers: - random_contacts[integers] = random_contacts[integers].astype( - DTYPE_N_CONTACTS - ) + contacts[integers] = contacts[integers].astype(DTYPE_N_CONTACTS) - floats = random_contacts.select_dtypes(include=np.floating).columns.tolist() - if floats: - random_contacts[floats] = random_contacts[floats].apply( - lambda x: _sum_preserving_round(x.to_numpy()).astype(DTYPE_N_CONTACTS) - ) + floats = ( + contacts[random_models].select_dtypes(include=np.floating).columns.tolist() + ) + for float_col in floats: + contacts.loc[:, float_col] = _sum_preserving_round( + contacts[float_col].to_numpy() + ).astype(DTYPE_N_CONTACTS) - no_integers = random_contacts.select_dtypes(exclude=np.integer).columns.tolist() + no_integers = ( + contacts[random_models].select_dtypes(exclude=np.integer).columns.tolist() + ) if no_integers: - dtype_mapping = random_contacts[no_integers].dtypes.to_dict() + dtype_mapping = contacts[no_integers].dtypes.to_dict() raise ValueError( "The following contacts should be integers or floats, but they have a " f"different dtype.\n\n{dtype_mapping}" ) - + random_contacts = contacts[random_models] else: random_contacts = None