Skip to content

Commit d6d2c3c

Browse files
committed
Set types as attributes on the cstruct object
1 parent a315c88 commit d6d2c3c

File tree

4 files changed

+93
-36
lines changed

4 files changed

+93
-36
lines changed

dissect/cstruct/cstruct.py

+72-22
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ def __init__(self, endian: str = "<", pointer: str | None = None):
4848

4949
self.consts = {}
5050
self.lookups = {}
51+
self.types = {}
52+
self.typedefs = {}
5153
# fmt: off
52-
self.typedefs = {
54+
55+
initial_types = {
5356
# Internal types
5457
"int8": self._make_packed_type("int8", "b", int),
5558
"uint8": self._make_packed_type("uint8", "B", int),
@@ -93,6 +96,21 @@ def __init__(self, endian: str = "<", pointer: str | None = None):
9396
"signed long long": "long long",
9497
"unsigned long long": "uint64",
9598

99+
# Other convenience types
100+
"u1": "uint8",
101+
"u2": "uint16",
102+
"u4": "uint32",
103+
"u8": "uint64",
104+
"u16": "uint128",
105+
"__u8": "uint8",
106+
"__u16": "uint16",
107+
"__u32": "uint32",
108+
"__u64": "uint64",
109+
"uchar": "uint8",
110+
"ushort": "unsigned short",
111+
"uint": "unsigned int",
112+
"ulong": "unsigned long",
113+
96114
# Windows types
97115
"BYTE": "uint8",
98116
"CHAR": "char",
@@ -160,24 +178,12 @@ def __init__(self, endian: str = "<", pointer: str | None = None):
160178
"_DWORD": "uint32",
161179
"_QWORD": "uint64",
162180
"_OWORD": "uint128",
163-
164-
# Other convenience types
165-
"u1": "uint8",
166-
"u2": "uint16",
167-
"u4": "uint32",
168-
"u8": "uint64",
169-
"u16": "uint128",
170-
"__u8": "uint8",
171-
"__u16": "uint16",
172-
"__u32": "uint32",
173-
"__u64": "uint64",
174-
"uchar": "uint8",
175-
"ushort": "unsigned short",
176-
"uint": "unsigned int",
177-
"ulong": "unsigned long",
178181
}
179182
# fmt: on
180183

184+
for name, type_ in initial_types.items():
185+
self.add_type(name, type_)
186+
181187
pointer = pointer or ("uint64" if sys.maxsize > 2**32 else "uint32")
182188
self.pointer = self.resolve(pointer)
183189
self._anonymous_count = 0
@@ -188,37 +194,71 @@ def __getattr__(self, attr: str) -> Any:
188194
except KeyError:
189195
pass
190196

197+
try:
198+
return self.types[attr]
199+
except KeyError:
200+
pass
201+
191202
try:
192203
return self.resolve(self.typedefs[attr])
193204
except KeyError:
194205
pass
195206

196-
raise AttributeError(f"Invalid attribute: {attr}")
207+
return super().__getattribute__(attr)
197208

198209
def _next_anonymous(self) -> str:
199210
name = f"__anonymous_{self._anonymous_count}__"
200211
self._anonymous_count += 1
201212
return name
202213

214+
def _add_attr(self, name: str, value: Any, replace: bool = False) -> None:
215+
if not replace and (name in self.__dict__ and self.__dict__[name] != value):
216+
raise ValueError(f"Attribute already exists: {name}")
217+
setattr(self, name, value)
218+
203219
def add_type(self, name: str, type_: MetaType | str, replace: bool = False) -> None:
204220
"""Add a type or type reference.
205221
206222
Only use this method when creating type aliases or adding already bound types.
223+
All types will be resolved to their actual type objects prior to being added.
224+
Use :func:`add_typedef` to add type references.
207225
208226
Args:
209227
name: Name of the type to be added.
210228
type_: The type to be added. Can be a str reference to another type or a compatible type class.
229+
If a str is given, it will be resolved to the actual type object.
211230
212231
Raises:
213232
ValueError: If the type already exists.
214233
"""
215-
if not replace and (name in self.typedefs and self.resolve(self.typedefs[name]) != self.resolve(type_)):
234+
typeobj = self.resolve(type_)
235+
if not replace and (name in self.types and self.types[name] != typeobj):
216236
raise ValueError(f"Duplicate type: {name}")
217237

218-
self.typedefs[name] = type_
238+
self.types[name] = typeobj
239+
self._add_attr(name, typeobj, replace=replace)
219240

220241
addtype = add_type
221242

243+
def add_typedef(self, name: str, type_: str, replace: bool = False) -> None:
244+
"""Add a type reference.
245+
246+
Use this method to add type references to this cstruct instance. This is largely a convenience method for the
247+
internal :func:`add_type` method.
248+
249+
Args:
250+
name: Name of the type to be added.
251+
type_: The type reference to be added.
252+
replace: Whether to replace the type if it already exists.
253+
"""
254+
if not isinstance(type_, str):
255+
raise TypeError("Type reference must be a string")
256+
257+
if not replace and (name in self.typedefs and self.resolve(self.typedefs[name]) != self.resolve(type_)):
258+
raise ValueError(f"Duplicate type: {name}")
259+
260+
self.typedefs[name] = type_
261+
222262
def add_custom_type(
223263
self, name: str, type_: MetaType, size: int | None = None, alignment: int | None = None, **kwargs
224264
) -> None:
@@ -236,6 +276,16 @@ def add_custom_type(
236276
"""
237277
self.add_type(name, self._make_type(name, (type_,), size, alignment=alignment, attrs=kwargs))
238278

279+
def add_const(self, name: str, value: Any) -> None:
280+
"""Add a constant value.
281+
282+
Args:
283+
name: Name of the constant to be added.
284+
value: The value of the constant.
285+
"""
286+
self.consts[name] = value
287+
self._add_attr(name, value, replace=True)
288+
239289
def load(self, definition: str, deftype: int | None = None, **kwargs) -> cstruct:
240290
"""Parse structures from the given definitions using the given definition type.
241291
@@ -307,14 +357,14 @@ def resolve(self, name: str) -> MetaType:
307357
return type_name
308358

309359
for _ in range(10):
360+
if type_name in self.types:
361+
return self.types[type_name]
362+
310363
if type_name not in self.typedefs:
311364
raise ResolveError(f"Unknown type {name}")
312365

313366
type_name = self.typedefs[type_name]
314367

315-
if not isinstance(type_name, str):
316-
return type_name
317-
318368
raise ResolveError(f"Recursion limit exceeded while resolving type {name}")
319369

320370
def _make_type(

dissect/cstruct/parser.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
if TYPE_CHECKING:
1717
from dissect.cstruct import cstruct
18+
from dissect.cstruct.types.structure import Structure
1819

1920

2021
class Parser:
@@ -96,7 +97,7 @@ def _constant(self, tokens: TokenConsumer) -> None:
9697
except (ExpressionParserError, ExpressionTokenizerError):
9798
pass
9899

99-
self.cstruct.consts[match["name"]] = value
100+
self.cstruct.add_const(match["name"], value)
100101

101102
def _enum(self, tokens: TokenConsumer) -> None:
102103
# We cheat with enums because the entire enum is in the token
@@ -137,18 +138,21 @@ def _enum(self, tokens: TokenConsumer) -> None:
137138

138139
enum = factory(d["name"] or "", self.cstruct.resolve(d["type"]), values)
139140
if not enum.__name__:
140-
self.cstruct.consts.update(enum.__members__)
141+
for k, v in enum.__members__.items():
142+
self.cstruct.add_const(k, v)
141143
else:
142144
self.cstruct.add_type(enum.__name__, enum)
143145

144146
tokens.eol()
145147

146148
def _typedef(self, tokens: TokenConsumer) -> None:
147149
tokens.consume()
150+
type_name = None
148151
type_ = None
149152

150153
if tokens.next == self.TOK.IDENTIFIER:
151-
type_ = self.cstruct.resolve(self._identifier(tokens))
154+
type_name = self._identifier(tokens)
155+
type_ = self.cstruct.resolve(type_name)
152156
elif tokens.next == self.TOK.STRUCT:
153157
# The register thing is a bit dirty
154158
# Basically consumes all NAME tokens and
@@ -157,12 +161,15 @@ def _typedef(self, tokens: TokenConsumer) -> None:
157161

158162
names = self._names(tokens)
159163
for name in names:
160-
type_, name, bits = self._parse_field_type(type_, name)
164+
new_type, name, bits = self._parse_field_type(type_, name)
161165
if bits is not None:
162166
raise ParserError(f"line {self._lineno(tokens.previous)}: typedefs cannot have bitfields")
163-
self.cstruct.add_type(name, type_)
167+
if type_name is None or new_type is not type_:
168+
self.cstruct.add_type(name, new_type)
169+
else:
170+
self.cstruct.add_typedef(name, type_name)
164171

165-
def _struct(self, tokens: TokenConsumer, register: bool = False) -> None:
172+
def _struct(self, tokens: TokenConsumer, register: bool = False) -> type[Structure]:
166173
stype = tokens.consume()
167174

168175
factory = self.cstruct._make_union if stype.value.startswith("union") else self.cstruct._make_struct
@@ -399,7 +406,7 @@ def _constants(self, data: str) -> None:
399406
except (ValueError, SyntaxError):
400407
pass
401408

402-
self.cstruct.consts[d["name"]] = v
409+
self.cstruct.add_const(d["name"], v)
403410

404411
def _enums(self, data: str) -> None:
405412
r = re.finditer(
@@ -481,7 +488,7 @@ def _structs(self, data: str) -> None:
481488
if d["defs"]:
482489
for td in d["defs"].strip().split(","):
483490
td = td.strip()
484-
self.cstruct.add_type(td, st)
491+
self.cstruct.add_typedef(td, st)
485492

486493
def _parse_fields(self, data: str) -> None:
487494
fields = re.finditer(

dissect/cstruct/types/structure.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Struc
241241
sizes = {}
242242
for field in cls.__fields__:
243243
offset = stream.tell()
244-
field_type = cls.cs.resolve(field.type)
244+
field_type = field.type
245245

246246
if field.offset is not None and offset != struct_start + field.offset:
247247
# Field is at a specific offset, either aligned or added that way
@@ -295,7 +295,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int:
295295
num = 0
296296

297297
for field in cls.__fields__:
298-
field_type = cls.cs.resolve(field.type)
298+
field_type = field.type
299299

300300
bit_field_type = (
301301
(field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None
@@ -460,7 +460,7 @@ def _read_fields(
460460
buf = io.BytesIO(stream.read(cls.size))
461461

462462
for field in cls.__fields__:
463-
field_type = cls.cs.resolve(field.type)
463+
field_type = field.type
464464

465465
start = 0
466466
if field.offset is not None:

tests/test_basic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_load_file(cs: cstruct, compiled: bool) -> None:
3131
path = Path(__file__).parent / "data/testdef.txt"
3232

3333
cs.loadfile(path, compiled=compiled)
34-
assert "test" in cs.typedefs
34+
assert "test" in cs.types
3535

3636

3737
def test_read_type_name(cs: cstruct) -> None:
@@ -46,7 +46,7 @@ def test_type_resolve(cs: cstruct) -> None:
4646

4747
cs.add_type("ref0", "uint32")
4848
for i in range(1, 15): # Recursion limit is currently 10
49-
cs.add_type(f"ref{i}", f"ref{i - 1}")
49+
cs.add_typedef(f"ref{i}", f"ref{i - 1}")
5050

5151
with pytest.raises(ResolveError, match="Recursion limit exceeded"):
5252
cs.resolve("ref14")
@@ -404,7 +404,7 @@ def test_reserved_keyword(cs: cstruct, compiled: bool) -> None:
404404
cs.load(cdef, compiled=compiled)
405405

406406
for name in ["in", "class", "for"]:
407-
assert name in cs.typedefs
407+
assert name in cs.types
408408
assert verify_compiled(cs.resolve(name), compiled)
409409

410410
assert cs.resolve(name)(b"\x01").a == 1

0 commit comments

Comments
 (0)