Skip to content

Commit 3b20fbd

Browse files
committed
Add type checking
1 parent 27b44fe commit 3b20fbd

21 files changed

+327
-156
lines changed

dissect/util/compression/__init__.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TYPE_CHECKING
2+
13
from dissect.util.compression import lz4 as lz4_python
24
from dissect.util.compression import lzo as lzo_python
35

@@ -16,8 +18,8 @@
1618
# Note that the pure Python implementation is not a full replacement of the
1719
# native lz4 Python package: only the decompress() function is implemented.
1820
try:
19-
import lz4.block as lz4
20-
import lz4.block as lz4_native
21+
import lz4.block as lz4 # type: ignore
22+
import lz4.block as lz4_native # type: ignore
2123
except ImportError:
2224
lz4 = lz4_python
2325
lz4_native = None
@@ -37,12 +39,19 @@
3739
# Note that the pure Python implementation is not a full replacement of the
3840
# native lzo Python package: only the decompress() function is implemented.
3941
try:
40-
import lzo
41-
import lzo as lzo_native
42+
import lzo # type: ignore
43+
import lzo as lzo_native # type: ignore
4244
except ImportError:
4345
lzo = lzo_python
4446
lzo_native = None
4547

48+
49+
from dissect.util.compression import lznt1, lzxpress, lzxpress_huffman, sevenbit
50+
51+
if TYPE_CHECKING:
52+
lzo = lzo_python
53+
lz4 = lz4_python
54+
4655
__all__ = [
4756
"lz4",
4857
"lz4_native",

dissect/util/compression/lz4.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import io
44
import struct
5-
from typing import BinaryIO
5+
from typing import BinaryIO, cast
66

77
from dissect.util.exceptions import CorruptDataError
88

@@ -25,12 +25,12 @@ def _get_length(src: BinaryIO, length: int) -> int:
2525

2626

2727
def decompress(
28-
src: bytes | BinaryIO,
28+
src: bytes | bytearray | memoryview | BinaryIO,
2929
uncompressed_size: int = -1,
3030
max_length: int = -1,
3131
return_bytearray: bool = False,
3232
return_bytes_read: bool = False,
33-
) -> bytes | tuple[bytes, int]:
33+
) -> bytes | bytearray | tuple[bytes | bytearray, int]:
3434
"""LZ4 decompress from a file-like object or bytes up to a certain length. Assumes no header.
3535
3636
Args:
@@ -44,7 +44,7 @@ def decompress(
4444
Returns:
4545
The decompressed data or a tuple of the decompressed data and the amount of bytes read.
4646
"""
47-
if not hasattr(src, "read"):
47+
if isinstance(src, (bytes, bytearray, memoryview)):
4848
src = io.BytesIO(src)
4949

5050
dst = bytearray()
@@ -78,7 +78,7 @@ def decompress(
7878
if len(read_buf) != 2:
7979
raise CorruptDataError("Premature EOF")
8080

81-
(offset,) = struct.unpack("<H", read_buf)
81+
(offset,) = cast(tuple[int], struct.unpack("<H", read_buf))
8282

8383
if offset == 0:
8484
raise CorruptDataError("Offset can't be 0")

dissect/util/compression/lznt1.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _get_displacement(offset: int) -> int:
2525
TAG_MASKS = [(1 << i) for i in range(8)]
2626

2727

28-
def decompress(src: bytes | BinaryIO) -> bytes:
28+
def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
2929
"""LZNT1 decompress from a file-like object or bytes.
3030
3131
Args:
@@ -34,7 +34,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
3434
Returns:
3535
The decompressed data.
3636
"""
37-
if not hasattr(src, "read"):
37+
if isinstance(src, (bytes, bytearray, memoryview)):
3838
src = io.BytesIO(src)
3939

4040
offset = src.tell()

dissect/util/compression/lzo.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _read_length(src: BinaryIO, val: int, mask: int) -> int:
2323
return length + mask + val
2424

2525

26-
def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) -> bytes:
26+
def decompress(src: bytes | bytearray | memoryview | BinaryIO, header: bool = True, buflen: int = -1) -> bytes:
2727
"""LZO decompress from a file-like object or bytes. Assumes no header.
2828
2929
Arguments are largely compatible with python-lzo API.
@@ -36,7 +36,7 @@ def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) ->
3636
Returns:
3737
The decompressed data.
3838
"""
39-
if not hasattr(src, "read"):
39+
if isinstance(src, (bytes, bytearray, memoryview)):
4040
src = io.BytesIO(src)
4141

4242
dst = bytearray()

dissect/util/compression/lzxpress.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import BinaryIO
77

88

9-
def decompress(src: bytes | BinaryIO) -> bytes:
9+
def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
1010
"""LZXPRESS decompress from a file-like object or bytes.
1111
1212
Args:
@@ -15,7 +15,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
1515
Returns:
1616
The decompressed data.
1717
"""
18-
if not hasattr(src, "read"):
18+
if isinstance(src, (bytes, bytearray, memoryview)):
1919
src = io.BytesIO(src)
2020

2121
offset = src.tell()
@@ -41,7 +41,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
4141
if src.tell() - offset == size:
4242
break
4343

44-
match = struct.unpack("<H", src.read(2))[0]
44+
match: int = struct.unpack("<H", src.read(2))[0]
4545
match_offset, match_length = divmod(match, 8)
4646
match_offset += 1
4747

dissect/util/compression/lzxpress_huffman.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,17 @@ def _read_16_bit(fh: BinaryIO) -> int:
1919
class Node:
2020
__slots__ = ("children", "is_leaf", "symbol")
2121

22-
def __init__(self, symbol: Symbol | None = None, is_leaf: bool = False):
22+
def __init__(self, symbol: int = 0, is_leaf: bool = False):
2323
self.symbol = symbol
2424
self.is_leaf = is_leaf
25-
self.children = [None, None]
25+
self.children: list[Node | None] = [None, None]
2626

2727

2828
def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int:
2929
node = nodes[0]
3030
i = idx + 1
3131

32-
while bits > 1:
32+
while node and bits > 1:
3333
bits -= 1
3434
childidx = (mask >> bits) & 1
3535
if node.children[childidx] is None:
@@ -38,6 +38,7 @@ def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int:
3838
i += 1
3939
node = node.children[childidx]
4040

41+
assert node
4142
node.children[mask & 1] = nodes[idx]
4243
return i
4344

@@ -84,8 +85,9 @@ def _build_tree(buf: bytes) -> Node:
8485

8586

8687
class BitString:
88+
source: BinaryIO
89+
8790
def __init__(self):
88-
self.source = None
8991
self.mask = 0
9092
self.bits = 0
9193

@@ -114,16 +116,18 @@ def skip(self, n: int) -> None:
114116
self.mask += _read_16_bit(self.source) << (16 - self.bits)
115117
self.bits += 16
116118

117-
def decode(self, root: Node) -> Symbol:
119+
def decode(self, root: Node) -> int:
118120
node = root
119-
while not node.is_leaf:
121+
while node and not node.is_leaf:
120122
bit = self.lookup(1)
121123
self.skip(1)
122124
node = node.children[bit]
125+
126+
assert node
123127
return node.symbol
124128

125129

126-
def decompress(src: bytes | BinaryIO) -> bytes:
130+
def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
127131
"""LZXPRESS Huffman decompress from a file-like object or bytes.
128132
129133
Decompresses until EOF of the input data.
@@ -134,7 +138,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
134138
Returns:
135139
The decompressed data.
136140
"""
137-
if not hasattr(src, "read"):
141+
if isinstance(src, (bytes, bytearray, memoryview)):
138142
src = io.BytesIO(src)
139143

140144
dst = bytearray()

dissect/util/compression/sevenbit.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
from io import BytesIO
3+
import io
44
from typing import BinaryIO
55

66

7-
def compress(src: bytes | BinaryIO) -> bytes:
7+
def compress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes:
88
"""Sevenbit compress from a file-like object or bytes.
99
1010
Args:
@@ -13,8 +13,8 @@ def compress(src: bytes | BinaryIO) -> bytes:
1313
Returns:
1414
The compressed data.
1515
"""
16-
if not hasattr(src, "read"):
17-
src = BytesIO(src)
16+
if isinstance(src, (bytes, bytearray, memoryview)):
17+
src = io.BytesIO(src)
1818

1919
dst = bytearray()
2020

@@ -39,7 +39,7 @@ def compress(src: bytes | BinaryIO) -> bytes:
3939
return bytes(dst)
4040

4141

42-
def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes:
42+
def decompress(src: bytes | bytearray | memoryview | BinaryIO, wide: bool = False) -> bytes:
4343
"""Sevenbit decompress from a file-like object or bytes.
4444
4545
Args:
@@ -48,8 +48,8 @@ def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes:
4848
Returns:
4949
The decompressed data.
5050
"""
51-
if not hasattr(src, "read"):
52-
src = BytesIO(src)
51+
if isinstance(src, (bytes, bytearray, memoryview)):
52+
src = io.BytesIO(src)
5353

5454
dst = bytearray()
5555

dissect/util/compression/xz.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
CRC_SIZE = 4
99

1010

11-
def repair_checksum(fh: BinaryIO) -> BinaryIO:
11+
def repair_checksum(fh: BinaryIO) -> OverlayStream:
1212
"""Repair CRC32 checksums for all headers in an XZ stream.
1313
1414
FortiOS XZ files have (on purpose) corrupt streams which they read using a modified ``xz`` binary.
@@ -55,7 +55,7 @@ def repair_checksum(fh: BinaryIO) -> BinaryIO:
5555
# Parse the index
5656
isize, num_records = _mbi(index[1:])
5757
index = index[1 + isize : -4]
58-
records = []
58+
records: list[tuple[int, int]] = []
5959
for _ in range(num_records):
6060
if not index:
6161
raise ValueError("Missing index size")

0 commit comments

Comments
 (0)