Skip to content

Commit

Permalink
refactor: avoid convoluted code
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Mar 28, 2024
1 parent f5edabf commit e10a0f5
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 27 deletions.
54 changes: 31 additions & 23 deletions src/picklescan/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def merge(self, sr: "ScanResult"):


class GenOpsError(Exception):
def __init__(self, msg: str):
def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]):
self.msg = msg
self.globals = globals
super().__init__()

def __str__(self) -> str:
Expand Down Expand Up @@ -177,16 +178,10 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]
try:
ops = list(pickletools.genops(data))
except Exception as e:
# XXX: pickle will happily load files that contain arbitrarily placed new lines whereas pickletools errors in such cases.
# below is code to circumvent or skip these newlines while succeeding at parsing the opcodes.
err = str(e)
if "opcode b'\\n' unknown" not in err:
raise GenOpsError(err)
else:
pos = int(err.split(",")[0].replace("at position ", ""))
data.seek(-(pos + 1), 1)
ops = list(pickletools.genops(data.read(pos)))
data.seek(1, 1)
# XXX: given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle.
# Thus return the already found globals in the error & let the caller decide what to do.
globals_opt = globals if len(globals) > 0 else None
raise GenOpsError(str(e), globals_opt)

last_byte = data.read(1)
data.seek(-1, 1)
Expand Down Expand Up @@ -241,18 +236,12 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]
return globals


def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult:
"""Disassemble a Pickle stream and report issues"""

def _build_scan_result_from_raw_globals(
raw_globals: Set[Tuple[str, str]],
file_id,
scan_err=False,
) -> ScanResult:
globals = []
try:
raw_globals = _list_globals(data, multiple_pickles)
except GenOpsError as e:
_log.error(f"ERROR: parsing pickle in {file_id}: {e}")
return ScanResult(globals, scan_err=True)

_log.debug("Global imports in %s: %s", file_id, raw_globals)

issues_count = 0
for rg in raw_globals:
g = Global(rg[0], rg[1], SafetyLevel.Dangerous)
Expand All @@ -278,7 +267,26 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanRe
g.safety = SafetyLevel.Suspicious
globals.append(g)

return ScanResult(globals, 1, issues_count, 1 if issues_count > 0 else 0, False)
return ScanResult(globals, 1, issues_count, 1 if issues_count > 0 else 0, scan_err)


def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult:
    """Scan a pickle stream and report the globals it references.

    The stream is passed to ``_list_globals``; the globals it returns are
    turned into a ``ScanResult`` by ``_build_scan_result_from_raw_globals``.

    If ``_list_globals`` raises ``GenOpsError``, the error is logged and a
    result flagged with ``scan_err=True`` is returned instead of propagating
    the exception. When the error carries globals (``e.globals`` is not
    ``None``) they are still reported in that result; otherwise the result
    holds no globals.

    Args:
        data: Binary stream positioned at the start of the pickle data.
        file_id: Identifier of the scanned file, used in log messages.
        multiple_pickles: Forwarded unchanged to ``_list_globals``.

    Returns:
        The ``ScanResult`` describing the stream.
    """
    try:
        found = _list_globals(data, multiple_pickles)
    except GenOpsError as err:
        _log.error(f"ERROR: parsing pickle in {file_id}: {err}")
        partial = err.globals
        # Guard clause: nothing was collected before the failure.
        if partial is None:
            return ScanResult([], scan_err=True)
        # Report what was collected, but keep the scan marked as errored.
        return _build_scan_result_from_raw_globals(partial, file_id, scan_err=True)

    _log.debug("Global imports in %s: %s", file_id, found)
    return _build_scan_result_from_raw_globals(found, file_id)


def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
Expand Down
14 changes: 14 additions & 0 deletions tests/data/malicious-invalid-bytes.pkl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Vos
p2
0Vsystem
p3
0Vtorch
p0
0VLongStorage
p1
0g2
g3
�(Vcat flag.txt
tR.


49 changes: 45 additions & 4 deletions tests/test_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,35 @@ def initialize_pickle_files():
),
)

initialize_data_file(
f"{_root_path}/data/malicious-invalid-bytes.pkl",
b"".join(
[
pickle.UNICODE + b"os\n",
pickle.PUT + b"2\n",
pickle.POP,
pickle.UNICODE + b"system\n",
pickle.PUT + b"3\n",
pickle.POP,
pickle.UNICODE + b"torch\n",
pickle.PUT + b"0\n",
pickle.POP,
pickle.UNICODE + b"LongStorage\n",
pickle.PUT + b"1\n",
pickle.POP,
pickle.GET + b"2\n",
pickle.GET + b"3\n",
pickle.STACK_GLOBAL,
pickle.MARK,
pickle.UNICODE + b"cat flag.txt\n",
pickle.TUPLE,
pickle.REDUCE,
pickle.STOP,
b"\n\n\t\t",
]
),
)

# Code which created malicious12.pkl using pickleassem (see https://github.com/gousaiyang/pickleassem)
#
# p = PickleAssembler(proto=4)
Expand Down Expand Up @@ -351,7 +380,6 @@ def test_scan_pickle_bytes():


def test_scan_zip_bytes():

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zip:
zip.writestr("data.pkl", pickle.dumps(Malicious1()))
Expand Down Expand Up @@ -559,15 +587,17 @@ def test_scan_directory_path():
Global("torch", "_utils", SafetyLevel.Suspicious),
Global("__builtin__", "exec", SafetyLevel.Dangerous),
Global("os", "system", SafetyLevel.Dangerous),
Global("os", "system", SafetyLevel.Dangerous),
Global("operator", "attrgetter", SafetyLevel.Dangerous),
Global("builtins", "__import__", SafetyLevel.Suspicious),
Global("pickle", "loads", SafetyLevel.Dangerous),
Global("_pickle", "loads", SafetyLevel.Dangerous),
Global("_codecs", "encode", SafetyLevel.Suspicious),
],
scanned_files=26,
issues_count=24,
infected_files=21,
scanned_files=27,
issues_count=25,
infected_files=22,
scan_err=True,
)
compare_scan_results(scan_directory_path(f"{_root_path}/data/"), sr)

Expand Down Expand Up @@ -610,3 +640,14 @@ def test_pickle_files():
assert pickle.load(file) == 12345
with open(f"{_root_path}/data/malicious13b.pkl", "rb") as file:
assert pickle.load(file) == 12345


def test_invalid_bytes_err():
    """Scanning malicious-invalid-bytes.pkl reports os.system with scan_err set.

    The expected result is a single Dangerous ``os.system`` global with one
    scanned file, one issue, one infected file and ``scan_err=True``.
    """
    pkl_path = f"{_root_path}/data/malicious-invalid-bytes.pkl"
    expected = ScanResult(
        [Global("os", "system", SafetyLevel.Dangerous)],
        1,
        1,
        1,
        True,
    )
    with open(pkl_path, "rb") as file:
        actual = scan_pickle_bytes(file, pkl_path)
        compare_scan_results(actual, expected)

0 comments on commit e10a0f5

Please sign in to comment.