diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 3a0fc1ad8..089876243 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -255,16 +255,27 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { + // this computes the greatest positive power of 2 that is + // less than or equal to target. + fn non_zero_prev_power_of_two(target: usize) -> usize { + let bits = usize::BITS - target.leading_zeros(); + + 1 << (std::cmp::max(1, bits) - 1) + } + assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; // define read size as a multiplier of record size. The multiplier must be // a power of two to align perfectly with total capacity. We don't want to exceed - // the target read size, so we pick a power of two <= read_size. - let read_size_multiplier = - // this computes the highest power of 2 less than or equal to read_size / record_size. - // Note, that if record_size is greater than read_size, we round it to 1 - 1 << (std::cmp::max(1, usize::BITS - (gateway_config.read_size.get() / record_size).leading_zeros()) - 1); + // the target read size, so multiplier * record_size <= read_size. We want to get + // as close as possible to read_size. + let read_size_multiplier = { + let target = gateway_config.read_size.get() / record_size; + // If record_size is greater than read_size, we set the multiplier to 1 + // as read size cannot be 0. + non_zero_prev_power_of_two(target) + }; let this = Self { total_capacity: total_capacity.try_into().unwrap(), @@ -279,7 +290,11 @@ impl SendChannelConfig { total_records, }; + // If capacity can't fit all active work items, the protocol deadlocks on + // inserts above the total capacity. assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + // if capacity is not aligned with read size, we can get a deadlock + // described in ipa/1300 assert_eq!(0, this.total_capacity.get() % this.read_size.get()); this