@@ -19,17 +19,17 @@ def _read_16_bit(fh: BinaryIO) -> int:
19
19
class Node :
20
20
__slots__ = ("children" , "is_leaf" , "symbol" )
21
21
22
- def __init__ (self , symbol : Symbol | None = None , is_leaf : bool = False ):
22
+ def __init__ (self , symbol : int = 0 , is_leaf : bool = False ):
23
23
self .symbol = symbol
24
24
self .is_leaf = is_leaf
25
- self .children = [None , None ]
25
+ self .children : list [ Node | None ] = [None , None ]
26
26
27
27
28
28
def _add_leaf (nodes : list [Node ], idx : int , mask : int , bits : int ) -> int :
29
29
node = nodes [0 ]
30
30
i = idx + 1
31
31
32
- while bits > 1 :
32
+ while node and bits > 1 :
33
33
bits -= 1
34
34
childidx = (mask >> bits ) & 1
35
35
if node .children [childidx ] is None :
@@ -38,6 +38,7 @@ def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int:
38
38
i += 1
39
39
node = node .children [childidx ]
40
40
41
+ assert node
41
42
node .children [mask & 1 ] = nodes [idx ]
42
43
return i
43
44
@@ -84,8 +85,9 @@ def _build_tree(buf: bytes) -> Node:
84
85
85
86
86
87
class BitString :
88
+ source : BinaryIO
89
+
87
90
def __init__ (self ):
88
- self .source = None
89
91
self .mask = 0
90
92
self .bits = 0
91
93
@@ -114,16 +116,18 @@ def skip(self, n: int) -> None:
114
116
self .mask += _read_16_bit (self .source ) << (16 - self .bits )
115
117
self .bits += 16
116
118
117
- def decode (self , root : Node ) -> Symbol :
119
+ def decode (self , root : Node ) -> int :
118
120
node = root
119
- while not node .is_leaf :
121
+ while node and not node .is_leaf :
120
122
bit = self .lookup (1 )
121
123
self .skip (1 )
122
124
node = node .children [bit ]
125
+
126
+ assert node
123
127
return node .symbol
124
128
125
129
126
- def decompress (src : bytes | BinaryIO ) -> bytes :
130
+ def decompress (src : bytes | bytearray | memoryview | BinaryIO ) -> bytes :
127
131
"""LZXPRESS decompress from a file-like object or bytes.
128
132
129
133
Decompresses until EOF of the input data.
@@ -134,7 +138,7 @@ def decompress(src: bytes | BinaryIO) -> bytes:
134
138
Returns:
135
139
The decompressed data.
136
140
"""
137
- if not hasattr (src , "read" ):
141
+ if isinstance (src , ( bytes , bytearray , memoryview ) ):
138
142
src = io .BytesIO (src )
139
143
140
144
dst = bytearray ()
0 commit comments