Skip to content

Commit 82ebd46

Browse files
committed
Ensure non-empty buffers for large vectored I/O
`readv` and `writev` are constrained by a platform-specific upper bound on the number of buffers which can be passed. Currently, `read_vectored` and `write_vectored` implementations simply truncate to this limit when larger. However, when the only non-empty buffers are at indices above this limit, they will erroneously return `Ok(0)`. Instead, slice the buffers starting at the first non-empty buffer. This trades a conditional move for a branch, so it's barely a penalty in the common case. The new method `limit_slices` on `IoSlice` and `IoSliceMut` may be generally useful to users like `advance_slices` is, but I have left it as `pub(crate)` for now.
1 parent 5cc6072 commit 82ebd46

File tree

6 files changed

+125
-53
lines changed

6 files changed

+125
-53
lines changed

library/std/src/io/mod.rs

+37
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@
297297
#[cfg(test)]
298298
mod tests;
299299

300+
use core::intrinsics;
300301
#[unstable(feature = "read_buf", issue = "78485")]
301302
pub use core::io::{BorrowedBuf, BorrowedCursor};
302303
use core::slice::memchr;
@@ -1388,6 +1389,24 @@ impl<'a> IoSliceMut<'a> {
13881389
}
13891390
}
13901391

1392+
/// Limits a slice of buffers to at most `n` buffers.
1393+
///
1394+
/// When the slice contains over `n` buffers, ensure that at least one
1395+
/// non-empty buffer is in the truncated slice, if there is one.
1396+
#[inline]
1397+
pub(crate) fn limit_slices(bufs: &mut &mut [IoSliceMut<'a>], n: usize) {
1398+
if intrinsics::unlikely(bufs.len() > n) {
1399+
for (i, buf) in bufs.iter().enumerate() {
1400+
if !buf.is_empty() {
1401+
let len = cmp::min(bufs.len() - i, n);
1402+
*bufs = &mut take(bufs)[i..i + len];
1403+
return;
1404+
}
1405+
}
1406+
*bufs = &mut take(bufs)[..0];
1407+
}
1408+
}
1409+
13911410
/// Get the underlying bytes as a mutable slice with the original lifetime.
13921411
///
13931412
/// # Examples
@@ -1549,6 +1568,24 @@ impl<'a> IoSlice<'a> {
15491568
}
15501569
}
15511570

1571+
/// Limits a slice of buffers to at most `n` buffers.
1572+
///
1573+
/// When the slice contains over `n` buffers, ensure that at least one
1574+
/// non-empty buffer is in the truncated slice, if there is one.
1575+
#[inline]
1576+
pub(crate) fn limit_slices(bufs: &mut &[IoSlice<'a>], n: usize) {
1577+
if intrinsics::unlikely(bufs.len() > n) {
1578+
for (i, buf) in bufs.iter().enumerate() {
1579+
if !buf.is_empty() {
1580+
let len = cmp::min(bufs.len() - i, n);
1581+
*bufs = &bufs[i..i + len];
1582+
return;
1583+
}
1584+
}
1585+
*bufs = &bufs[..0];
1586+
}
1587+
}
1588+
15521589
/// Get the underlying bytes as a slice with the original lifetime.
15531590
///
15541591
/// This doesn't borrow from `self`, so is less restrictive than calling

library/std/src/sys/net/connection/socket/solid.rs

+7-13
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::os::solid::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, Owne
99
use crate::sys::abi;
1010
use crate::sys_common::{FromInner, IntoInner};
1111
use crate::time::Duration;
12-
use crate::{cmp, mem, ptr, str};
12+
use crate::{mem, ptr, str};
1313

1414
pub(super) mod netc {
1515
pub use crate::sys::abi::sockets::*;
@@ -222,13 +222,10 @@ impl Socket {
222222
self.recv_with_flags(buf, 0)
223223
}
224224

225-
pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
225+
pub fn read_vectored(&self, mut bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
226+
IoSliceMut::limit_slices(&mut bufs, max_iov());
226227
let ret = cvt(unsafe {
227-
netc::readv(
228-
self.as_raw_fd(),
229-
bufs.as_ptr() as *const netc::iovec,
230-
cmp::min(bufs.len(), max_iov()) as c_int,
231-
)
228+
netc::readv(self.as_raw_fd(), bufs.as_ptr() as *const netc::iovec, bufs.len() as c_int)
232229
})?;
233230
Ok(ret as usize)
234231
}
@@ -267,13 +264,10 @@ impl Socket {
267264
self.recv_from_with_flags(buf, MSG_PEEK)
268265
}
269266

270-
pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
267+
pub fn write_vectored(&self, mut bufs: &[IoSlice<'_>]) -> io::Result<usize> {
268+
IoSlice::limit_slices(&mut bufs, max_iov());
271269
let ret = cvt(unsafe {
272-
netc::writev(
273-
self.as_raw_fd(),
274-
bufs.as_ptr() as *const netc::iovec,
275-
cmp::min(bufs.len(), max_iov()) as c_int,
276-
)
270+
netc::writev(self.as_raw_fd(), bufs.as_ptr() as *const netc::iovec, bufs.len() as c_int)
277271
})?;
278272
Ok(ret as usize)
279273
}

library/std/src/sys/net/connection/socket/windows.rs

+15-12
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,6 @@ impl Socket {
299299
}
300300

301301
fn recv_with_flags(&self, mut buf: BorrowedCursor<'_>, flags: c_int) -> io::Result<()> {
302-
// On unix when a socket is shut down all further reads return 0, so we
303-
// do the same on windows to map a shut down socket to returning EOF.
304302
let length = cmp::min(buf.capacity(), i32::MAX as usize) as i32;
305303
let result =
306304
unsafe { c::recv(self.as_raw(), buf.as_mut().as_mut_ptr() as *mut _, length, flags) };
@@ -309,6 +307,9 @@ impl Socket {
309307
c::SOCKET_ERROR => {
310308
let error = unsafe { c::WSAGetLastError() };
311309

310+
// On Unix when a socket is shut down, all further reads return
311+
// 0, so we do the same on Windows to map a shut down socket to
312+
// returning EOF.
312313
if error == c::WSAESHUTDOWN {
313314
Ok(())
314315
} else {
@@ -332,17 +333,15 @@ impl Socket {
332333
self.recv_with_flags(buf, 0)
333334
}
334335

335-
pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
336-
// On unix when a socket is shut down all further reads return 0, so we
337-
// do the same on windows to map a shut down socket to returning EOF.
338-
let length = cmp::min(bufs.len(), u32::MAX as usize) as u32;
336+
pub fn read_vectored(&self, mut bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
337+
IoSliceMut::limit_slices(&mut bufs, u32::MAX as usize);
339338
let mut nread = 0;
340339
let mut flags = 0;
341340
let result = unsafe {
342341
c::WSARecv(
343342
self.as_raw(),
344343
bufs.as_mut_ptr() as *mut c::WSABUF,
345-
length,
344+
bufs.len() as u32,
346345
&mut nread,
347346
&mut flags,
348347
ptr::null_mut(),
@@ -355,6 +354,9 @@ impl Socket {
355354
_ => {
356355
let error = unsafe { c::WSAGetLastError() };
357356

357+
// On Unix when a socket is shut down, all further reads return
358+
// 0, so we do the same on Windows to map a shut down socket to
359+
// returning EOF.
358360
if error == c::WSAESHUTDOWN {
359361
Ok(0)
360362
} else {
@@ -384,8 +386,6 @@ impl Socket {
384386
let mut addrlen = size_of_val(&storage) as netc::socklen_t;
385387
let length = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t;
386388

387-
// On unix when a socket is shut down all further reads return 0, so we
388-
// do the same on windows to map a shut down socket to returning EOF.
389389
let result = unsafe {
390390
c::recvfrom(
391391
self.as_raw(),
@@ -401,6 +401,9 @@ impl Socket {
401401
c::SOCKET_ERROR => {
402402
let error = unsafe { c::WSAGetLastError() };
403403

404+
// On Unix when a socket is shut down, all further reads return
405+
// 0, so we do the same on Windows to map a shut down socket to
406+
// returning EOF.
404407
if error == c::WSAESHUTDOWN {
405408
Ok((0, unsafe { socket_addr_from_c(&storage, addrlen as usize)? }))
406409
} else {
@@ -419,14 +422,14 @@ impl Socket {
419422
self.recv_from_with_flags(buf, c::MSG_PEEK)
420423
}
421424

422-
pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
423-
let length = cmp::min(bufs.len(), u32::MAX as usize) as u32;
425+
pub fn write_vectored(&self, mut bufs: &[IoSlice<'_>]) -> io::Result<usize> {
426+
IoSlice::limit_slices(&mut bufs, u32::MAX as usize);
424427
let mut nwritten = 0;
425428
let result = unsafe {
426429
c::WSASend(
427430
self.as_raw(),
428431
bufs.as_ptr() as *const c::WSABUF as *mut _,
429-
length,
432+
bufs.len() as u32,
430433
&mut nwritten,
431434
0,
432435
ptr::null_mut(),

library/std/src/sys/pal/hermit/fd.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#![unstable(reason = "not public", issue = "none", feature = "fd")]
22

33
use super::hermit_abi;
4-
use crate::cmp;
54
use crate::io::{self, BorrowedCursor, IoSlice, IoSliceMut, Read, SeekFrom};
65
use crate::os::hermit::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
76
use crate::sys::{cvt, unsupported};
@@ -38,12 +37,13 @@ impl FileDesc {
3837
Ok(())
3938
}
4039

41-
pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
40+
pub fn read_vectored(&self, mut bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
41+
IoSliceMut::limit_slices(&mut bufs, max_iov());
4242
let ret = cvt(unsafe {
4343
hermit_abi::readv(
4444
self.as_raw_fd(),
4545
bufs.as_mut_ptr() as *mut hermit_abi::iovec as *const hermit_abi::iovec,
46-
cmp::min(bufs.len(), max_iov()),
46+
bufs.len(),
4747
)
4848
})?;
4949
Ok(ret as usize)
@@ -65,12 +65,13 @@ impl FileDesc {
6565
Ok(result as usize)
6666
}
6767

68-
pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
68+
pub fn write_vectored(&self, mut bufs: &[IoSlice<'_>]) -> io::Result<usize> {
69+
IoSlice::limit_slices(&mut bufs, max_iov());
6970
let ret = cvt(unsafe {
7071
hermit_abi::writev(
7172
self.as_raw_fd(),
7273
bufs.as_ptr() as *const hermit_abi::iovec,
73-
cmp::min(bufs.len(), max_iov()),
74+
bufs.len(),
7475
)
7576
})?;
7677
Ok(ret as usize)

0 commit comments

Comments
 (0)