diff --git a/dissect/util/stream.py b/dissect/util/stream.py index bb6fb4e..ab20341 100644 --- a/dissect/util/stream.py +++ b/dissect/util/stream.py @@ -12,22 +12,27 @@ class AlignedStream(io.RawIOBase): - """Basic buffered stream that provides easy aligned reads. + """Basic buffered stream that provides aligned reads. Must be subclassed for various stream implementations. Subclasses can implement: - - _read(offset, length) - - _seek(pos, whence=io.SEEK_SET) + - :meth:`~AlignedStream._read` + - :meth:`~AlignedStream._readinto` + - :meth:`~AlignedStream._seek` - The offset and length for _read are guaranteed to be aligned. The only time + The offset and length for ``_read`` and ``_readinto`` are guaranteed to be aligned. The only time that overriding _seek would make sense is if there's no known size of your stream, - but still want to provide SEEK_END functionality. + but still want to provide ``SEEK_END`` functionality. - Most subclasses of AlignedStream take one or more file-like objects as source. + Most subclasses of ``AlignedStream`` take one or more file-like objects as source. Operations on these subclasses, like reading, will modify the source file-like object as a side effect. Args: - size: The size of the stream. This is used in read and seek operations. None if unknown. + size: The size of the stream. This is used in read and seek operations. ``None`` if unknown. align: The alignment size. Read operations are aligned on this boundary. Also determines buffer size. + + .. automethod:: _read + .. automethod:: _readinto + .. automethod:: _seek """ def __init__(self, size: int | None = None, align: int = STREAM_BUFFER_SIZE): @@ -38,27 +43,28 @@ def __init__(self, size: int | None = None, align: int = STREAM_BUFFER_SIZE): self._pos = 0 self._pos_align = 0 - self._buf = None - self._seek_lock = Lock() - - def _set_pos(self, pos: int) -> None: - """Update the position and aligned position within the stream.""" - new_pos_align = pos - (pos % self.align) + self._buf = memoryview(bytearray(align)) + self._buf_size = 0 + self._read_lock = Lock() - if self._pos_align != new_pos_align: - self._pos_align = new_pos_align - self._buf = None + def readable(self) -> bool: + """Indicate that the stream is readable.""" + return True - self._pos = pos + def seekable(self) -> bool: + """Indicate that the stream is seekable.""" + return True - def _fill_buf(self) -> None: - """Fill the alignment buffer if we can.""" - if self._buf or self.size is not None and (self.size <= self._pos or self.size <= self._pos_align): - return + def seek(self, pos: int, whence: int = io.SEEK_SET) -> int: + """Seek the stream to the specified position.""" + with self._read_lock: + pos = self._seek(pos, whence) + self._set_pos(pos) - self._buf = self._read(self._pos_align, self.align) + return pos def _seek(self, pos: int, whence: int = io.SEEK_SET) -> int: + """Calculate and return the new stream position after a seek.""" if whence == io.SEEK_SET: if pos < 0: raise ValueError(f"negative seek position {pos}") @@ -73,112 +79,169 @@ def _seek(self, pos: int, whence: int = io.SEEK_SET) -> int: return pos - def seek(self, pos: int, whence: int = io.SEEK_SET) -> int: - """Seek the stream to the specified position.""" - with self._seek_lock: - pos = self._seek(pos, whence) - self._set_pos(pos) + def _set_pos(self, pos: int) -> None: + """Update the position and aligned position within the stream.""" + new_pos_align = pos - (pos % self.align) - return pos + if self._pos_align != new_pos_align: + self._pos_align = new_pos_align + self._buf_size = 0 - def read(self, n: int = -1) -> bytes: - """Read and return up to n bytes, or read to the end of the stream if n is -1. + self._pos = pos - Returns an empty bytes object on EOF. + def tell(self) -> int: + """Return current stream position.""" + return self._pos + + def _fill_buf(self) -> None: + """Fill the alignment buffer if we can.""" + if self._buf_size or self.size is not None and (self.size <= self._pos or self.size <= self._pos_align): + # Don't fill the buffer if: + # - We already have a buffer + # - The stream position is at the end (or beyond) the stream size + return + + self._buf_size = self._readinto(self._pos_align, self._buf) + + def readinto(self, b: bytearray) -> int: + """Read bytes into a pre-allocated bytes-like object b. + + Returns an int representing the number of bytes read (0 for EOF). """ - if n is not None and n < -1: - raise ValueError("invalid number of bytes to read") + with self._read_lock: + return self._readinto_unlocked(b) + + def _readinto_unlocked(self, b: bytearray) -> int: + if not isinstance(b, memoryview): + b = memoryview(b) + b = b.cast("B") - r = [] + n = len(b) size = self.size align = self.align + total_read = 0 - with self._seek_lock: - if size is None and n == -1: - r = [] - if self._buf: - buffer_pos = self._pos - self._pos_align - r.append(self._buf[buffer_pos:]) - self._set_pos(self._pos_align + align) - - r.append(self._read(self._pos_align, -1)) + # If we know the stream size, adjust n + if size is not None: + remaining = size - self._pos - buf = b"".join(r) - self._set_pos(self._pos + len(buf)) - return buf + n = remaining if n == -1 else min(n, remaining) - if size is not None: - remaining = size - self._pos - n = remaining if n == -1 else min(n, remaining) + # Short path for when it turns out we don't need to read anything + if n == 0 or size is not None and size <= self._pos: + return 0 - if n == 0 or size is not None and size <= self._pos: - return b"" + # Read misaligned start from buffer + if self._pos != self._pos_align: + self._fill_buf() - # Read misaligned start from buffer - if self._pos != self._pos_align: - self._fill_buf() + buffer_pos = self._pos - self._pos_align + buffer_remaining = max(0, self._buf_size - buffer_pos) + read_len = min(n, buffer_remaining) - buffer_pos = self._pos - self._pos_align - remaining = align - buffer_pos - buffer_len = min(n, remaining) + b[:read_len] = self._buf[buffer_pos : buffer_pos + read_len] + b = b[read_len:] - r.append(self._buf[buffer_pos : buffer_pos + buffer_len]) + n -= read_len + total_read += read_len + self._set_pos(self._pos + read_len) - n -= buffer_len - self._set_pos(self._pos + buffer_len) + # Aligned blocks + if n >= align: + count, n = divmod(n, align) - # Aligned blocks - if n >= align: - count, n = divmod(n, align) + read_len = count * align + actual_read = self._readinto(self._pos, b[:read_len]) + b = b[actual_read:] - read_len = count * align - r.append(self._read(self._pos, read_len)) + total_read += actual_read + self._set_pos(self._pos + read_len) - self._set_pos(self._pos + read_len) + # Misaligned remaining bytes + if n > 0: + self._fill_buf() - # Misaligned end - if n > 0: - self._fill_buf() - r.append(self._buf[:n]) - self._set_pos(self._pos + n) + buffer_pos = self._pos - self._pos_align + buffer_remaining = max(0, min(align, self._buf_size) - buffer_pos) + read_len = min(n, buffer_remaining) - return b"".join(r) + b[:read_len] = self._buf[:read_len] - def readinto(self, b: bytearray) -> int: - """Read bytes into a pre-allocated bytes-like object b. + total_read += read_len + self._set_pos(self._pos + read_len) - Returns an int representing the number of bytes read (0 for EOF). - """ - buf = self.read(len(b)) - length = len(buf) - b[:length] = buf - return length + return total_read def _read(self, offset: int, length: int) -> bytes: - """Read method that backs this aligned stream.""" + """Provide an aligned ``read`` implementation for this stream.""" raise NotImplementedError("_read needs to be implemented by subclass") + def _readinto(self, offset: int, buf: memoryview) -> int: + """Provide an aligned ``readinto`` implementation for this stream. + + For backwards compatibility, ``AlignedStream`` provides a default ``_readinto`` implementation, implemented + in ``_readinto_fallback``, that falls back on ``_read``. However, subclasses should override the ``_readinto`` + method instead of ``_readinto_fallback``. + """ + return self._readinto_fallback(offset, buf) + + def _readinto_fallback(self, offset: int, buf: bytearray) -> int: + """``_readinto`` fallback on ``_read``.""" + read_len = len(buf) + result = self._read(offset, read_len) + length = len(result) + + if length > read_len: + raise IOError(f"_read returned more bytes than requested, wanted {read_len} and returned {length}") + + buf[:length] = result + return length + def readall(self) -> bytes: """Read until end of stream.""" - return self.read() + if self.size is not None: + # If we have a known stream size, we can do a more optimized read + return self.read(self.size - self._pos) + + with self._read_lock: + result = bytearray() + + if self._buf: + # Drain the buffer first + buffer_pos = self._pos - self._pos_align + buffer_remaining = max(0, min(self.align, self._buf_size) - buffer_pos) + result += self._buf[buffer_pos : buffer_pos + buffer_remaining] + + self._set_pos(self._pos + buffer_remaining) + + # Read the remaining bytes + try: + # Check if our stream implementation has a _read we can use + result += self._read(self._pos, -1) + except NotImplementedError: + # Otherwise call _readinto a bunch of times + buf = bytearray(io.DEFAULT_BUFFER_SIZE) + + while n := self._readinto(self._pos, buf): + result += buf[:n] + self._set_pos(self._pos + n) + + return bytes(result) def readoffset(self, offset: int, length: int) -> bytes: - """Convenience method to read from a certain offset with 1 call.""" + """Convenience method to read from a given offset.""" self.seek(offset) return self.read(length) - def tell(self) -> int: - """Return current stream position.""" - return self._pos + def peek(self, n: int = 0) -> bytes: + """Convenience method to peek from the current offset without advancing the stream position.""" + pos = self._pos + data = self.read(n) + self._set_pos(pos) + return data def close(self) -> None: - pass - - def readable(self) -> bool: - return True - - def seekable(self) -> bool: - return True + """Close the stream. Does nothing by default.""" class RangeStream(AlignedStream): @@ -202,14 +265,30 @@ def __init__(self, fh: BinaryIO, offset: int, size: int, align: int = STREAM_BUF super().__init__(size, align) self._fh = fh self.offset = offset + self._has_readinto = hasattr(self._fh, "readinto") + + def _seek(self, pos: int, whence: int = io.SEEK_SET) -> int: + if self.size is None and whence == io.SEEK_END: + pos = self._fh.seek(pos, whence) + if pos is None: + pos = self._fh.tell() + return max(0, pos - self.offset) + return super()._seek(pos, whence) def _read(self, offset: int, length: int) -> bytes: - read_length = min(length, self.size - offset) + # We will generally only end up here from :func:`AlignedStream.readall` + read_length = min(length, self.size - offset) if self.size else length self._fh.seek(self.offset + offset) return self._fh.read(read_length) + def _readinto(self, offset: int, buf: memoryview) -> int: + if self._has_readinto: + self._fh.seek(self.offset + offset) + return self._fh.readinto(buf) + return self._readinto_fallback(offset, buf) + -class RelativeStream(AlignedStream): +class RelativeStream(RangeStream): """Create a relative stream from another file-like object. ASCII representation:: @@ -227,22 +306,7 @@ class RelativeStream(AlignedStream): """ def __init__(self, fh: BinaryIO, offset: int, size: int | None = None, align: int = STREAM_BUFFER_SIZE): - super().__init__(size, align) - self._fh = fh - self.offset = offset - - def _seek(self, pos: int, whence: int = io.SEEK_SET) -> int: - if whence == io.SEEK_END: - pos = self._fh.seek(pos, whence) - if pos is None: - pos = self._fh.tell() - return max(0, pos - self.offset) - return super()._seek(pos, whence) - - def _read(self, offset: int, length: int) -> bytes: - read_length = min(length, self.size - offset) if self.size else length - self._fh.seek(self.offset + offset) - return self._fh.read(read_length) + super().__init__(fh, offset, size, align) class BufferedStream(RelativeStream): @@ -284,7 +348,7 @@ def add(self, offset: int, size: int, fh: BinaryIO, file_offset: int = 0) -> Non """ self._runs.append((offset, size, fh, file_offset)) self._runs = sorted(self._runs) - self._buf = None + self._buf_size = 0 self.size = self._runs[-1][0] + self._runs[-1][1] def _get_run_idx(self, offset: int) -> tuple[int, int, BinaryIO, int]: @@ -305,19 +369,22 @@ def _get_run_idx(self, offset: int) -> tuple[int, int, BinaryIO, int]: raise EOFError(f"No mapping for offset {offset}") - def _read(self, offset: int, length: int) -> bytes: - result = [] + def _readinto(self, offset: int, buf: memoryview) -> int: + size = self.size + runs = self._runs run_idx = self._get_run_idx(offset) runlist_len = len(self._runs) - size = self.size + + n = 0 + length = len(buf) while length > 0: if run_idx >= runlist_len: # We somehow requested more data than we have runs for break - run_offset, run_size, run_fh, run_file_offset = self._runs[run_idx] + run_offset, run_size, run_fh, run_file_offset = runs[run_idx] if run_offset > offset: # We landed in a gap, stop reading @@ -332,13 +399,18 @@ def _read(self, offset: int, length: int) -> bytes: read_count = min(size - offset, min(run_remaining, length)) run_fh.seek(run_file_offset + run_pos) - result.append(run_fh.read(read_count)) + if hasattr(run_fh, "readinto"): + n += run_fh.readinto(buf[:read_count]) + else: + buf[:read_count] = run_fh.read(read_count) + n += read_count offset += read_count length -= read_count + buf = buf[read_count:] run_idx += 1 - return b"".join(result) + return n class RunlistStream(AlignedStream): @@ -358,7 +430,12 @@ class RunlistStream(AlignedStream): """ def __init__( - self, fh: BinaryIO, runlist: list[tuple[int, int]], size: int, block_size: int, align: int | None = None + self, + fh: BinaryIO, + runlist: list[tuple[int, int]], + size: int, + block_size: int, + align: int | None = None, ): super().__init__(size, align or block_size) @@ -372,6 +449,7 @@ def __init__( self.runlist = runlist self.block_size = block_size + self._has_readinto = hasattr(self._fh, "readinto") @property def runlist(self) -> list[tuple[int, int]]: @@ -389,16 +467,21 @@ def runlist(self, runlist: list[tuple[int, int]]) -> None: self._runlist_offsets.append(offset) offset += block_count - self._buf = None + self._buf_size = 0 - def _read(self, offset: int, length: int) -> bytes: - r = [] + def _readinto(self, offset: int, buf: memoryview) -> int: + fh = self._fh + size = self.size + runlist = self.runlist + runlist_offsets = self._runlist_offsets + block_size = self.block_size block_offset = offset // self.block_size - run_idx = bisect_right(self._runlist_offsets, block_offset) runlist_len = len(self.runlist) - size = self.size + + n = 0 + length = len(buf) while length > 0: if run_idx >= runlist_len: @@ -406,11 +489,11 @@ def _read(self, offset: int, length: int) -> bytes: break # If run_idx == 0, we only have a single run - run_block_pos = 0 if run_idx == 0 else self._runlist_offsets[run_idx - 1] - run_block_offset, run_block_count = self.runlist[run_idx] + run_block_pos = 0 if run_idx == 0 else runlist_offsets[run_idx - 1] + run_block_offset, run_block_count = runlist[run_idx] - run_size = run_block_count * self.block_size - run_pos = offset - run_block_pos * self.block_size + run_size = run_block_count * block_size + run_pos = offset - run_block_pos * block_size run_remaining = run_size - run_pos # Sometimes the self.size is way larger than what we actually have runs for? @@ -422,16 +505,22 @@ def _read(self, offset: int, length: int) -> bytes: # Sparse run if run_block_offset is None: - r.append(b"\x00" * read_count) + buf[:read_count] = b"\x00" * read_count + n += read_count else: - self._fh.seek(run_block_offset * self.block_size + run_pos) - r.append(self._fh.read(read_count)) + fh.seek(run_block_offset * block_size + run_pos) + if self._has_readinto: + n += fh.readinto(buf[:read_count]) + else: + buf[:read_count] = fh.read(read_count) + n += read_count offset += read_count length -= read_count + buf = buf[read_count:] run_idx += 1 - return b"".join(r) + return n class OverlayStream(AlignedStream): @@ -451,6 +540,7 @@ def __init__(self, fh: BinaryIO, size: int | None = None, align: int = STREAM_BU self._fh = fh self.overlays: dict[int, tuple[int, BinaryIO]] = {} self._lookup: list[int] = [] + self._has_readinto = hasattr(self._fh, "readinto") def add(self, offset: int, data: bytes | BinaryIO, size: int | None = None) -> None: """Add an overlay at the given offset. @@ -482,14 +572,12 @@ def add(self, offset: int, data: bytes | BinaryIO, size: int | None = None) -> N self._lookup.sort() # Clear the buffer if we add an overlay at our current position - if self._buf and (self._pos_align <= offset + size and offset <= self._pos_align + len(self._buf)): - self._buf = None + if self._buf_size and (self._pos_align <= offset + size and offset <= self._pos_align + self.align): + self._buf_size = 0 return self - def _read(self, offset: int, length: int) -> bytes: - result = [] - + def _readinto(self, offset: int, buf: memoryview) -> int: fh = self._fh overlays = self.overlays lookup = self._lookup @@ -497,6 +585,9 @@ def _read(self, offset: int, length: int) -> bytes: overlay_len = len(overlays) overlay_idx = bisect_left(lookup, offset) + n = 0 + length = len(buf) + while length > 0: prev_overlay_offset = None if overlay_idx == 0 else lookup[overlay_idx - 1] next_overlay_offset = None if overlay_idx >= overlay_len else lookup[overlay_idx] @@ -512,10 +603,12 @@ def _read(self, offset: int, length: int) -> bytes: prev_overlay_read_size = min(length, prev_overlay_remaining) prev_overlay_data.seek(offset_in_prev_overlay) - result.append(prev_overlay_data.read(prev_overlay_read_size)) + buf[:prev_overlay_read_size] = prev_overlay_data.read(prev_overlay_read_size) + n += prev_overlay_read_size offset += prev_overlay_read_size length -= prev_overlay_read_size + buf = buf[prev_overlay_read_size:] if length == 0: break @@ -527,29 +620,44 @@ def _read(self, offset: int, length: int) -> bytes: if 0 <= gap_to_next_overlay < length: if gap_to_next_overlay: fh.seek(offset) - result.append(fh.read(gap_to_next_overlay)) + if self._has_readinto: + n += fh.readinto(buf[:gap_to_next_overlay]) + else: + buf[:gap_to_next_overlay] = fh.read(gap_to_next_overlay) + n += gap_to_next_overlay + buf = buf[gap_to_next_overlay:] # read remaining from overlay next_overlay_read_size = min(next_overlay_size, length - gap_to_next_overlay) next_overlay_data.seek(0) - result.append(next_overlay_data.read(next_overlay_read_size)) + buf[:next_overlay_read_size] = next_overlay_data.read(next_overlay_read_size) + n += next_overlay_read_size offset += next_overlay_read_size + gap_to_next_overlay length -= next_overlay_read_size + gap_to_next_overlay + buf = buf[next_overlay_read_size + gap_to_next_overlay :] else: # Next overlay is too far away, complete read fh.seek(offset) - result.append(fh.read(length)) + if self._has_readinto: + n += fh.readinto(buf[:length]) + else: + buf[:length] = fh.read(length) + n += length break else: # No next overlay, complete read fh.seek(offset) - result.append(fh.read(length)) + if self._has_readinto: + n += fh.readinto(buf[:length]) + else: + buf[:length] = fh.read(length) + n += length break overlay_idx += 1 - return b"".join(result) + return n class ZlibStream(AlignedStream): diff --git a/tests/test_stream.py b/tests/test_stream.py index 083c85f..198bfc6 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,6 +1,6 @@ import io import zlib -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -70,7 +70,8 @@ def test_buffered_stream() -> None: fh = stream.BufferedStream(buf, size=None) assert fh.read(10) == b"\x01" * 10 - assert fh._buf == buf.getvalue() + assert fh._buf[: len(buf.getvalue())] == buf.getvalue() + assert fh._buf_size == 512 * 3 assert fh.read() == buf.getvalue()[10:] assert fh.read(1) == b"" @@ -124,7 +125,7 @@ def test_aligned_stream_buffer() -> None: # Read aligned blocks so we move past the offset from where the buffer was read fh.read(1024) # Buffer should be reset - assert fh._buf is None + assert fh._buf_size == 0 # Buffer should now be from the 3rd aligned block assert fh.read(256) == b"\x03" * 256 assert fh._buf == b"\x03" * 512 @@ -203,3 +204,17 @@ def test_zlib_stream() -> None: fh.seek(0) assert fh.read() == data + + +def test_layered_readinto() -> None: + size = 1024 * 64 + buf = io.BytesIO(b"\x42" * size) + mock_buffer = Mock(wraps=buf) + + fh = stream.BufferedStream(stream.BufferedStream(mock_buffer, 0, size), 0, size) + tmp = bytearray(1024) + fh.readinto(tmp) + + # Test the bottom layer buffer was called with the cache object of the top layer + mock_buffer.readinto.assert_called_once() + assert mock_buffer.readinto.call_args[0][0].obj is fh._buf.obj