Skip to content

Commit

Permalink
Limit memory consuption for tracking closed stream ids (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
iyangsj authored Apr 26, 2024
1 parent b6173ba commit f410451
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 14 deletions.
149 changes: 135 additions & 14 deletions src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ pub struct StreamMap {
/// a STOP_SENDING frame.
stopped: StreamIdHashMap<u64>,

/// Keep track of IDs of previously closed streams, to prevent peers from
/// re-creating them.
/// Keep track of IDs of previously closed streams. It can grow and use up a
/// lot of memory, so it is used only in unit tests.
#[cfg(test)]
closed: StreamIdHashSet,

/// Streams that peer are almost out of flow control capacity, and
Expand Down Expand Up @@ -647,12 +648,13 @@ impl StreamMap {
return Err(Error::ProtocolViolation);
}

let closed = self.is_closed(id);
match self.streams.entry(id) {
// 1.Can not find any stream with the given stream ID.
// It may not be created yet or it has been closed.
hash_map::Entry::Vacant(v) => {
// Stream has already been closed and collected into `closed`.
if self.closed.contains(&id) {
if closed {
return Err(Error::Done);
}

Expand Down Expand Up @@ -700,6 +702,7 @@ impl StreamMap {
self.next_stream_id_uni = self.next_stream_id_uni.saturating_add(4);
}

self.concurrency_control.remove_avail_id(id, self.is_server);
self.events.add(Event::StreamCreated(id));
Ok(v.insert(new_stream))
}
Expand Down Expand Up @@ -871,7 +874,7 @@ impl StreamMap {
/// Note that this method does not check if the stream id is complied with
/// the role of the endpoint.
fn mark_closed(&mut self, stream_id: u64, local: bool) {
if self.closed.contains(&stream_id) {
if self.is_closed(stream_id) {
return;
}

Expand All @@ -888,6 +891,10 @@ impl StreamMap {

self.mark_readable(stream_id, false);
self.mark_writable(stream_id, false);
if let Some(stream) = self.get_mut(stream_id) {
stream.mark_closed();
}
#[cfg(test)]
self.closed.insert(stream_id);

if self.events.add(Event::StreamClosed(stream_id)) {
Expand Down Expand Up @@ -1093,9 +1100,23 @@ impl StreamMap {
self.stopped.iter()
}

/// Return true if the stream has been closed and collected to `closed`.
/// Return true if the stream has been closed.
pub fn is_closed(&self, stream_id: u64) -> bool {
self.closed.contains(&stream_id)
// It is an existing stream
if let Some(stream) = self.get(stream_id) {
return stream.is_closed();
}

// It is a stream to be create
let is_server = self.is_server;
if self.concurrency_control.is_available(stream_id, is_server)
|| self.concurrency_control.is_limited(stream_id, is_server)
{
return false;
}

// It is a destroyed stream
true
}

/// Return true if there are any streams that have buffered data to send.
Expand Down Expand Up @@ -1703,8 +1724,11 @@ enum StreamFlags {
/// Upper layer want to read data from stream.
WantRead = 1 << 0,

// Upper layer want to write data to stream.
/// Upper layer want to write data to stream.
WantWrite = 1 << 1,

/// The stream has been closed and is waiting to release its resources.
Closed = 1 << 2,
}

#[derive(Default)]
Expand Down Expand Up @@ -1903,6 +1927,16 @@ impl Stream {

Ok(())
}

/// Check whether the stream is closed.
pub fn is_closed(&self) -> bool {
self.flags.contains(Closed)
}

/// Mark the stream as closed.
pub fn mark_closed(&mut self) {
self.flags.insert(Closed);
}
}

/// Return true if the stream was created locally.
Expand Down Expand Up @@ -2976,8 +3010,8 @@ impl StreamTransportParams {
}

/// Concurrency control for streams.
// RFC9000 4.6 Controlling Concurrency
// https://www.rfc-editor.org/rfc/rfc9000.html#name-controlling-concurrency
/// RFC9000 4.6 Controlling Concurrency
/// https://www.rfc-editor.org/rfc/rfc9000.html#name-controlling-concurrency
#[derive(Clone, Debug, PartialEq, Default)]
struct ConcurrencyControl {
/// Maximum bidirectional streams that the peer allow local endpoint to open.
Expand Down Expand Up @@ -3017,15 +3051,34 @@ struct ConcurrencyControl {
/// peer's concurrency control limit, we need to send a STREAMS_BLOCKED(type 0x17)
/// frame to notify peer.
streams_blocked_at_uni: Option<u64>,

/// Available stream ids for peer initiated bidirectional streams.
peer_bidi_avail_ids: ranges::RangeSet,

/// Available stream ids for peer initiated unidirectional streams.
peer_uni_avail_ids: ranges::RangeSet,

/// Available stream ids for local initiated bidirectional streams.
local_bidi_avail_ids: ranges::RangeSet,

/// Available stream ids for local initiated unidirectional streams.
local_uni_avail_ids: ranges::RangeSet,
}

impl ConcurrencyControl {
fn new(local_max_streams_bidi: u64, local_max_streams_uni: u64) -> ConcurrencyControl {
let mut peer_bidi_avail_ids = ranges::RangeSet::default();
peer_bidi_avail_ids.insert(0..local_max_streams_bidi);
let mut peer_uni_avail_ids = ranges::RangeSet::default();
peer_uni_avail_ids.insert(0..local_max_streams_uni);

ConcurrencyControl {
local_max_streams_bidi,
local_max_streams_bidi_next: local_max_streams_bidi,
local_max_streams_uni,
local_max_streams_uni_next: local_max_streams_uni,
peer_bidi_avail_ids,
peer_uni_avail_ids,
..ConcurrencyControl::default()
}
}
Expand All @@ -3035,7 +3088,12 @@ impl ConcurrencyControl {
fn update_peer_max_streams(&mut self, bidi: bool, max_streams: u64) {
match bidi {
true => {
self.peer_max_streams_bidi = cmp::max(self.peer_max_streams_bidi, max_streams);
if self.peer_max_streams_bidi < max_streams {
// insert available ids for local initiated bidi-streams
let ids = self.peer_max_streams_bidi..max_streams;
self.insert_avail_id(ids, true, true);
self.peer_max_streams_bidi = max_streams;
}

// Cancel the concurrency control blocked state if the max_streams_bidi limit
// is increased, avoid sending redundant STREAMS_BLOCKED(0x16) frames.
Expand All @@ -3045,7 +3103,12 @@ impl ConcurrencyControl {
}

false => {
self.peer_max_streams_uni = cmp::max(self.peer_max_streams_uni, max_streams);
if self.peer_max_streams_uni < max_streams {
// insert available ids for local initiated uni-streams
let ids = self.peer_max_streams_uni..max_streams;
self.insert_avail_id(ids, true, false);
self.peer_max_streams_uni = max_streams;
}

// Cancel the concurrency control blocked state if the max_streams_uni limit
// is increased, avoid sending redundant STREAMS_BLOCKED(type: 0x17) frames.
Expand All @@ -3058,9 +3121,16 @@ impl ConcurrencyControl {

/// After sending a MAX_STREAMS(type: 0x12..0x13) frame, update local max_streams limit.
fn update_local_max_streams(&mut self, bidi: bool) {
match bidi {
true => self.local_max_streams_bidi = self.local_max_streams_bidi_next,
false => self.local_max_streams_uni = self.local_max_streams_uni_next,
if bidi {
// insert available ids for peer initiated bidi-streams
let ids = self.local_max_streams_bidi..self.local_max_streams_bidi_next;
self.insert_avail_id(ids, false, true);
self.local_max_streams_bidi = self.local_max_streams_bidi_next;
} else {
// insert available ids for peer initiated uni-streams
let ids = self.local_max_streams_uni..self.local_max_streams_uni_next;
self.insert_avail_id(ids, false, false);
self.local_max_streams_uni = self.local_max_streams_uni_next;
}
}

Expand Down Expand Up @@ -3188,6 +3258,49 @@ impl ConcurrencyControl {

Ok(())
}

/// Check whether the given stream ID exceeds stream limits.
fn is_limited(&self, stream_id: u64, is_server: bool) -> bool {
let seq = (stream_id >> 2) + 1;
match (is_local(stream_id, is_server), is_bidi(stream_id)) {
(true, true) => seq > self.peer_max_streams_bidi,
(true, false) => seq > self.peer_max_streams_uni,
(false, true) => seq > self.local_max_streams_bidi,
(false, false) => seq > self.local_max_streams_uni,
}
}

/// Check whether the given stream id is available for stream creation.
fn is_available(&self, stream_id: u64, is_server: bool) -> bool {
let id = stream_id >> 2;
match (is_local(stream_id, is_server), is_bidi(stream_id)) {
(true, true) => self.local_bidi_avail_ids.contains(id),
(true, false) => self.local_uni_avail_ids.contains(id),
(false, true) => self.peer_bidi_avail_ids.contains(id),
(false, false) => self.peer_uni_avail_ids.contains(id),
}
}

/// Inset the given stream ids into available set.
fn insert_avail_id(&mut self, ids: Range<u64>, is_local: bool, is_bidi: bool) {
match (is_local, is_bidi) {
(true, true) => self.local_bidi_avail_ids.insert(ids),
(true, false) => self.local_uni_avail_ids.insert(ids),
(false, true) => self.peer_bidi_avail_ids.insert(ids),
(false, false) => self.peer_uni_avail_ids.insert(ids),
}
}

/// Remove the given stream id from available set.
fn remove_avail_id(&mut self, stream_id: u64, is_server: bool) {
let id = stream_id >> 2;
match (is_local(stream_id, is_server), is_bidi(stream_id)) {
(true, true) => self.local_bidi_avail_ids.remove_elem(id),
(true, false) => self.local_uni_avail_ids.remove_elem(id),
(false, true) => self.peer_bidi_avail_ids.remove_elem(id),
(false, false) => self.peer_uni_avail_ids.remove_elem(id),
}
}
}

/// Connection-level send capacity for all streams
Expand Down Expand Up @@ -6045,6 +6158,11 @@ mod tests {
#[test]
fn concurrency_control_new() {
let cc = ConcurrencyControl::new(10, 3);

let mut peer_bidi_avail_ids = ranges::RangeSet::default();
peer_bidi_avail_ids.insert(0..10);
let mut peer_uni_avail_ids = ranges::RangeSet::default();
peer_uni_avail_ids.insert(0..3);
assert_eq!(
cc,
ConcurrencyControl {
Expand All @@ -6060,6 +6178,9 @@ mod tests {
peer_opened_streams_uni: 0,
streams_blocked_at_bidi: None,
streams_blocked_at_uni: None,
peer_bidi_avail_ids,
peer_uni_avail_ids,
..ConcurrencyControl::default()
}
);
}
Expand Down
44 changes: 44 additions & 0 deletions src/ranges.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ impl RangeSet {
}
}

/// Remove `elem` from the set, i.e. remove range [elem, elem + 1) from the set.
pub fn remove_elem(&mut self, elem: u64) {
self.remove(elem..elem + 1);
}

/// Remove all ranges that are smaller or equal to `elem` from the set.
pub fn remove_until(&mut self, elem: u64) {
let ranges: Vec<Range<u64>> = self
Expand Down Expand Up @@ -227,6 +232,21 @@ impl RangeSet {
.next()
}

/// Check if the element exists or not
pub fn contains(&self, elem: u64) -> bool {
if let Some(prev) = self.prev_to(elem) {
if prev.contains(&elem) {
return true;
}
}
if let Some(next) = self.next_to(elem) {
if next.contains(&elem) {
return true;
}
}
false
}

/// Peek at the smallest range in the set.
pub fn peek_min(&self) -> Option<Range<u64>> {
let (&start, &end) = self.set.iter().next()?;
Expand Down Expand Up @@ -937,6 +957,30 @@ mod tests {
}
}

#[test]
fn contains() {
let mut r = RangeSet::default();
// Insert ranges: [2..6), [8, 13)
r.insert(2..6);
r.insert(8..13);

for i in [0, 1] {
assert_eq!(r.contains(i), false);
}
for i in 2..6 {
assert_eq!(r.contains(i), true);
}
for i in [6, 7] {
assert_eq!(r.contains(i), false);
}
for i in 8..13 {
assert_eq!(r.contains(i), true);
}
for i in 13..20 {
assert_eq!(r.contains(i), false);
}
}

#[test]
fn flatten() {
let mut r = RangeSet::default();
Expand Down

0 comments on commit f410451

Please sign in to comment.