Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce safer_eval #50

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions partd/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
description of the array's dtype.
"""
from __future__ import absolute_import

from toolz import identity, partial, valmap

import numpy as np
from toolz import valmap, identity, partial

from .compatibility import pickle
from .core import Interface
from .file import File
from .utils import frame, framesplit, suffix, ignoring
from .utils import frame, framesplit, ignoring, safer_eval, suffix


def serialize_dtype(dt):
Expand All @@ -34,7 +37,7 @@ def parse_dtype(s):
dtype([('a', '<i4')])
"""
if s.startswith(b'['):
return np.dtype(eval(s)) # Dangerous!
return np.dtype(safer_eval(s))
else:
return np.dtype(s)

Expand Down
49 changes: 46 additions & 3 deletions partd/tests/test_numpy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import absolute_import

import pickle

import pytest

import partd
from partd.numpy import Numpy, parse_dtype
from partd.utils import safer_eval

np = pytest.importorskip('numpy') # noqa

import pickle

import partd
from partd.numpy import Numpy


def test_numpy():
Expand Down Expand Up @@ -70,3 +74,42 @@ def test_non_utf8_bytes():
b'\xf0\x28\x8c\xbc'], dtype='O')
s = partd.numpy.serialize(a)
assert (partd.numpy.deserialize(s, 'O') == a).all()


@pytest.mark.parametrize('text,parsed', [
(b'[("a", "i4")]', [('a', '<i4')]), # Test different quotation mark types.
(b"[('a', 'i4')]", [('a', '<i4')]),
(b"[('b', 'i2')]", [('b', '<i2')]),
(b"[('c', 'f8')]", [('c', '<f8')]),
(b"[('x', 'i4'), ('y', 'i4')]", [('x', '<i4'), ('y', '<i4')]),
(
b"[('a', 'i4'), ('b', 'i2'), ('c', 'f8')]",
[('a', '<i4'), ('b', '<i2'), ('c', '<f8')],
),
])
def test_safer_eval_tuple(text, parsed):
assert np.dtype(safer_eval(text)) == np.dtype(parsed)


@pytest.mark.parametrize('text,parsed', [
(b'a', 'S'),
(b'b', 'int8'),
(b'c', 'S1'),
(b'i2', 'int16'),
(b'i4', 'int32'),
(b'f8', 'float64'),
(b'M8[us]', '<M8[us]'),
(b'M8[s]', '<M8[s]'),
(b'datetime64[D]', '<M8[D]'),
(b'timedelta64[25s]', '<m8[25s]'),
(
b"i4, (2,3)f8",
[('f0', '<i4'), ('f1', '<f8', (2, 3))],
),
(
b"[('a', 'i4'), ('b', 'i2'), ('c', 'f8')]",
[('a', '<i4'), ('b', '<i2'), ('c', '<f8')],
),
])
def test_parse_dtype(text, parsed):
assert parse_dtype(text) == np.dtype(parsed)
11 changes: 7 additions & 4 deletions partd/tests/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import absolute_import

import os

import pytest
pytest.importorskip('pandas') # noqa

import numpy as np
import pandas as pd
import pandas.util.testing as tm
import os
import pandas.testing as tm
from partd.pandas import PandasBlocks, PandasColumns, deserialize, serialize

pytest.importorskip('pandas') # noqa


from partd.pandas import PandasColumns, PandasBlocks, serialize, deserialize


df1 = pd.DataFrame({'a': [1, 2, 3],
Expand Down
20 changes: 19 additions & 1 deletion partd/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from partd.utils import frame, framesplit
import struct

import pytest

from partd.utils import frame, framesplit, safer_eval


def test_frame():
assert frame(b'Hello') == struct.pack('Q', 5) + b'Hello'
Expand All @@ -9,3 +12,18 @@ def test_frame():
def test_framesplit():
L = [b'Hello', b'World!', b'123']
assert list(framesplit(b''.join(map(frame, L)))) == L


def test_safer_eval_safe():
assert safer_eval("[1, 2, 3]") == [1, 2, 3]
assert safer_eval("['a', 'b', 'c']") == ['a', 'b', 'c']


def test_safer_eval_unsafe():
with pytest.raises(ValueError) as excinfo:
safer_eval("\xe1")
assert "non-printable" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
safer_eval("__import__('os').system('ls')")
assert "__" in str(excinfo.value)
44 changes: 41 additions & 3 deletions partd/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from contextlib import contextmanager
import os
import shutil
import tempfile
import struct
import tempfile
from contextlib import contextmanager
from string import printable as _printable

# Exclude newline and tab characters from consideration.
printable = _printable[:-5]


def raises(exc, lamda):
Expand Down Expand Up @@ -48,7 +52,6 @@ def framesplit(bytes):
[b'Hello', b'World']
"""
i = 0; n = len(bytes)
chunks = list()
while i < n:
nbytes = struct.unpack('Q', bytes[i:i+8])[0]
i += 8
Expand Down Expand Up @@ -174,3 +177,38 @@ def extend(key, term):
key = (key,)

return key + term


def safer_eval(source):
""" A safer alternative to the built-in ``eval``

The further safety is achieved via additional checks over the input.

Please, note that this is not 100% bullet-proof, as it still internally
relies on ``eval``.

Examples
--------

>>> safer_eval("1")
1
>>> safer_eval("[1, 2, 3]")
[1, 2, 3]
>>> safer_eval("['a', 'b', 'c']")
['a', 'b', 'c']
"""
# Preserve the original type, if it's not ``str``, but ensure that sanity
# checks are performed over a ``str`` representation of the input.
string = source if type(source) is str else str(source)

# Disallow evaluation of non-printable chracters.
if any(map(lambda c: c not in printable, string)):
raise ValueError("Cannot evaluate strings containing non-printable characters")

# Disallow evaluation of dunder/magic Python methods.
# Access to the latter may recover ``__builtins__``.
if '__' in string:
raise ValueError("Cannot evaluate strings containing '__'")

# Disallow ``__builtins__`` (e.g., ``__import__``, etc.).
return eval(source, {'__builtins__': {}})