diff --git a/nemo/lightning/io/deserialize.py b/nemo/lightning/io/deserialize.py new file mode 100644 index 000000000000..0017b3c7266f --- /dev/null +++ b/nemo/lightning/io/deserialize.py @@ -0,0 +1,558 @@ +import builtins +import logging +import re +import sys +import time +from pathlib import Path, PosixPath, WindowsPath +from typing import Any, Dict, Final, List, Optional, Union + +from fiddle._src import daglish +from fiddle._src.experimental.serialization import _VALUE_KEY, Deserialization, PyrefPolicy, register_node_traverser + +# Configure logging +logger = logging.getLogger(__name__) + +# Make critical sets immutable using frozenset +BLOCKED_MODULES: Final[frozenset] = frozenset( + { + # System and OS operations + "os", + "sys", + "subprocess", + "shutil", + # Serialization and code execution + "pickle", + "marshal", + "shelve", + "code", + "codeop", + # File and I/O operations + "io", + "tempfile", + "pathlib", + "zipfile", + "tarfile", + # Network and IPC + "socket", + "asyncio", + "multiprocessing", + "threading", + "http", + "http.server", + "urllib", + "urllib.request", + "wsgiref", + # System information and configuration + "platform", + "pwd", + "grp", + "resource", + # Package management and imports + "importlib", + "pkg_resources", + "setuptools", + "distutils", + # Low-level system access + "ctypes", + "mmap", + "fcntl", + "signal", + # Debug and development + "pdb", + "trace", + "gc", + "inspect", + "dis", + "ast", + # XML processing + "xml", + "xml.etree", + "xml.sax", + "xml.dom", + # Encoding and crypto + "base64", + "codecs", + "crypt", + # Terminal and process control + "pty", + "tty", + "termios", + "pipes", + # System logging + "syslog", + "logging.handlers", + # Additional dangerous modules + "commands", + "_thread", + "select", + "readline", + "spwd", + "grp", + "nis", + "site", + "winreg", + "msvcrt", + "winsound", + "venv", + "uuid", + } +) + +DANGEROUS_BUILTINS: Final[frozenset] = frozenset( + { + "eval", + "exec", + "compile", + "__import__", + "open", + "input", + "globals", + "locals", + "vars", + "getattr", + "setattr", + "delattr", + "breakpoint", + "memoryview", + "classmethod", + "staticmethod", + "property", + "dir", + "type", + "object", + "super", + "format", + "frozenset", + "help", + "copyright", + "credits", + "license", + "print", + "repr", + "ascii", + "hash", + "hex", + "oct", + "bin", + "id", + } +) + +TRUSTED_MODULES: Final[frozenset] = frozenset( + { + "nemo", + "nemo_run", + "nemo_alligner", + "nemo_curator", + "torch", + "pytorch_lightning", + "lightning", + "numpy", + "collections", + "typing", + "enum", + "dataclasses", + "pathlib", + } +) + +DANGEROUS_SPECIAL_METHODS: Final[frozenset] = frozenset( + { + "__call__", + "__new__", + "__init__", + "__del__", + "__getattr__", + "__setattr__", + "__delattr__", + "__class_getitem__", + "__get__", + "__set__", + "__delete__", + "__getattribute__", + "__slots__", + "__subclasses__", + "__bases__", + "__class__", + "__mro__", + "__reduce__", + "__reduce_ex__", + "__subclasshook__", + "__init_subclass__", + "__prepare__", + "__instancecheck__", + "__subclasscheck__", + "__descr_get__", + "__descr_set__", + "__descr_delete__", + "__delete__", + "__set_name__", + "__objclass__", + "__annotations__", + } +) + + +class DeserializationError(Exception): + """Custom exception for deserialization errors.""" + + pass + + +class SecurityViolationError(DeserializationError): + """Raised when a security violation is detected.""" + + pass + + +def path_flatten(value: Path) -> tuple[tuple, str]: + """Flatten a Path object into its string representation.""" + return ((), str(value)) + + +def path_unflatten(values: tuple, metadata: str) -> Path: + """Reconstruct a Path object from its string representation.""" + return Path(metadata) + + +def path_elements(value: Path) -> tuple: + """Return an empty tuple since Path has no traversable elements.""" + return () + + +# Register the traverser for Path objects +daglish.register_node_traverser( + Path, + flatten_fn=path_flatten, + unflatten_fn=path_unflatten, + path_elements_fn=path_elements, +) +register_node_traverser(Path, path_flatten, path_unflatten, path_elements) + + +class SafePyrefPolicy(PyrefPolicy): + """A security-enhanced version of PyrefPolicy that restricts module imports.""" + + # Class constants + SAFE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$") + SUSPICIOUS_PATTERNS: Final = tuple( + [ + re.compile(r"__[^_\W]+__"), # Magic methods + re.compile(r"\\x[0-9a-fA-F]{2}"), # Hex escapes + re.compile(r"\\[0-7]{1,3}"), # Octal escapes + re.compile(r"\\u[0-9a-fA-F]{4}"), # Unicode escape + ] + ) + DANGEROUS_SPECIAL_METHODS: Final[frozenset] = frozenset( + { + "__call__", + "__new__", + "__init__", + "__del__", + "__getattr__", + "__setattr__", + "__delattr__", + "__class_getitem__", + "__get__", + "__set__", + "__delete__", + "__getattribute__", + "__slots__", + "__subclasses__", + "__bases__", + "__class__", + "__mro__", + "__reduce__", + "__reduce_ex__", + "__subclasshook__", + "__init_subclass__", + "__prepare__", + "__instancecheck__", + "__subclasscheck__", + "__descr_get__", + "__descr_set__", + "__descr_delete__", + "__delete__", + "__set_name__", + "__objclass__", + "__annotations__", + } + ) + + def __init__( + self, + safe_remote_code: bool = False, + max_depth: int = 100, + max_string_length: int = 100_000, + max_collection_size: int = 10_000, + ): + """Initialize the policy with security settings.""" + super().__init__() + # Initialize all attributes at once using __dict__ + self.__dict__.update( + { + "safe_remote_code": safe_remote_code, + "max_depth": max_depth, + "max_string_length": max_string_length, + "max_collection_size": max_collection_size, + "_current_depth": 0, + "blocked_modules": BLOCKED_MODULES, + "trusted_modules": TRUSTED_MODULES, + # Define safe types as a list instead of a set + "safe_primitives": (int, float, str, bool, type(None)), + "safe_collections": (list, tuple, dict, set, frozenset), + "safe_path_types": (Path, PosixPath, WindowsPath), + } + ) + + def allows_value(self, value: Any) -> bool: + """Check if a value is allowed to be deserialized.""" + # Special case for Path class and instances + if value is Path or isinstance(value, (Path, PosixPath, WindowsPath)): + return True + + # Check for None + if value is None: + return True + + # Check collections first to avoid unhashable type error + if isinstance(value, (list, tuple, set)): + if len(value) > self.max_collection_size: + return False + return all(self.allows_value(item) for item in value) + + if isinstance(value, dict): + if len(value) > self.max_collection_size: + return False + return all(self.allows_value(k) and self.allows_value(v) for k, v in value.items()) + + if isinstance(value, str): + if len(value) > self.max_string_length: + return False + return not any(pattern.search(value) for pattern in self.SUSPICIOUS_PATTERNS) + + # For all other types, just return True if it's a primitive type + return isinstance(value, (int, float, bool, type(None))) + + def _check_depth(self) -> None: + """Check and increment the recursion depth.""" + current = self._current_depth + if current > self.max_depth: + raise ValueError(f"Maximum recursion depth {self.max_depth} exceeded") + # Use __dict__ to bypass __setattr__ + self.__dict__["_current_depth"] = current + 1 + + def _check_timeout(self) -> None: + """Check if execution time limit has been exceeded.""" + if time.time() - self._start_time > self.execution_timeout: + raise ResourceLimitError(f"Execution timeout ({self.execution_timeout}s) exceeded") + + def _validate_int(self, value: int) -> bool: + """Validate integer values.""" + return -sys.maxsize <= value <= sys.maxsize + + def _validate_float(self, value: float) -> bool: + """Validate float values.""" + return not (value in (float("inf"), float("-inf"), float("nan"))) + + def _validate_str(self, value: str) -> bool: + """Validate string content.""" + if len(value) > self.max_string_length: + return False + + # Check for suspicious patterns + for pattern in self.SUSPICIOUS_PATTERNS: + if pattern.search(value): + return False + + return True + + def _validate_bytes(self, value: bytes) -> bool: + """Validate bytes content.""" + return len(value) <= self.max_string_length + + def _validate_sequence(self, value: Union[List, tuple, set]) -> bool: + """Validate sequence types.""" + return len(value) <= self.max_collection_size + + def _validate_mapping(self, value: Dict) -> bool: + """Validate mapping types.""" + if len(value) > self.max_collection_size: + return False + + # Check key types (should only be strings or numbers) + return all(isinstance(k, (str, int, float)) for k in value.keys()) + + def _validate_name(self, name: str) -> bool: + """Validate that a name follows safe naming conventions.""" + if not name or len(name) > 64: # Reasonable maximum length + return False + return bool(self.SAFE_NAME_PATTERN.match(name)) + + def allows_import(self, module: str, symbol: str) -> bool: + """Check if importing the given symbol from the module is allowed. + + Args: + module: The module to import from + symbol: The symbol to import + + Returns: + bool: Whether the import is allowed + """ + # Special case for pathlib.Path + if module == "pathlib" and symbol == "Path": + return True + + # Get root module (e.g., 'os.path' -> 'os') + module_root = module.split(".")[0] + + # Block access to dangerous modules + if module_root in BLOCKED_MODULES: + return False + + # Check if it's a dangerous builtin + if module == "builtins" and symbol in DANGEROUS_BUILTINS: + logger.warning(f"Blocked import of dangerous builtin: {symbol}") + return False + + # In restricted mode, only allow trusted modules + if not self.safe_remote_code: + # Check if the module or any of its parents are in trusted_modules + module_parts = module.split(".") + for i in range(len(module_parts)): + current_module = ".".join(module_parts[: i + 1]) + if current_module in self.trusted_modules: + return True + + logger.warning(f"Blocked import from untrusted module in restricted mode: {module}") + return False + + # Validate names + if not self._validate_name(module_root) or not self._validate_name(symbol): + logger.warning(f"Blocked import due to invalid name: {module}.{symbol}") + return False + + return True + + def __enter__(self): + """Context manager to track recursion depth.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Reset recursion depth on exit.""" + self._current_depth -= 1 + + def _check_collection_size(self, value: Any) -> None: + """Check if a collection exceeds size limits. + + Args: + value: The collection to check + + Raises: + ValueError: If the collection size exceeds limits + """ + if isinstance(value, (str, bytes)): + if len(value) > self.max_string_length: + raise ValueError(f"String length {len(value)} exceeds limit {self.max_string_length}") + elif isinstance(value, (list, tuple, set, dict)): + if len(value) > self.max_collection_size: + raise ValueError(f"Collection size {len(value)} exceeds limit {self.max_collection_size}") + + +class SafeDeserialization(Deserialization): + """A security-enhanced version of Deserialization that restricts module imports.""" + + def __init__( + self, + serialized_value: Dict[str, Any], + safe_remote_code: bool = False, + pyref_policy: Optional[PyrefPolicy] = None, + max_depth: int = 100, + max_string_length: int = 100_000, + max_collection_size: int = 10_000, + ): + if not isinstance(serialized_value, dict): + raise TypeError("serialized_value must be a dictionary") + + # Store limits + self.max_string_length = max_string_length + self.max_collection_size = max_collection_size + + # Verify the integrity of our security measures + self._verify_security_integrity() + + try: + if pyref_policy is None: + pyref_policy = SafePyrefPolicy( + safe_remote_code=safe_remote_code, + max_depth=max_depth, + max_string_length=max_string_length, + max_collection_size=max_collection_size, + ) + super().__init__(serialized_value, pyref_policy=pyref_policy) + except Exception as e: + logger.error(f"Deserialization failed: {str(e)}") + raise DeserializationError(f"Failed to deserialize: {str(e)}") from e + + @staticmethod + def _verify_security_integrity(): + """Verify that security measures haven't been compromised.""" + # Check if critical sets are still frozen + if not isinstance(DANGEROUS_BUILTINS, frozenset): + raise SecurityViolationError("Security violation: DANGEROUS_BUILTINS has been modified") + if not isinstance(BLOCKED_MODULES, frozenset): + raise SecurityViolationError("Security violation: BLOCKED_MODULES has been modified") + if not isinstance(TRUSTED_MODULES, frozenset): + raise SecurityViolationError("Security violation: TRUSTED_MODULES has been modified") + + # Verify SafePyrefPolicy hasn't been tampered with + if not hasattr(SafePyrefPolicy, "__setattr__"): + raise SecurityViolationError("Security violation: SafePyrefPolicy protection removed") + + # Verify that built-in functions haven't been monkey-patched + for func_name in ("getattr", "setattr", "delattr", "eval", "exec"): + builtin_func = getattr(builtins, func_name) + builtins_func = ( + __builtins__[func_name] if isinstance(__builtins__, dict) else getattr(__builtins__, func_name) + ) + if builtin_func is not builtins_func: + raise SecurityViolationError(f"Security violation: built-in {func_name} has been modified") + + def _validate_value(self, value: Any) -> None: + """Recursively validate a value for security concerns.""" + # Check for dangerous types + if callable(value): + raise DeserializationError(f"Callable objects are not allowed: {value}") + + # Check collections recursively + if isinstance(value, dict): + for k, v in value.items(): + if not isinstance(k, str): + raise DeserializationError(f"Dictionary keys must be strings, got: {type(k)}") + self._validate_value(v) + elif isinstance(value, (list, tuple, set)): + for item in value: + self._validate_value(item) + + def _deserialize_leaf(self, leaf): + """Override leaf deserialization to add security checks.""" + value = leaf[_VALUE_KEY] + + # Check string length limits + if isinstance(value, str) and len(value) > self.max_string_length: + raise DeserializationError(f"String length {len(value)} exceeds limit {self.max_string_length}") + + # Validate nested structures + self._validate_value(value) + + return value + + def __setattr__(self, name, value): + if hasattr(self, "_initialized"): + raise AttributeError("Cannot modify SafeDeserialization attributes") + super().__setattr__(name, value) + if name == "_result": # Last attribute set in parent's __init__ + self._initialized = True + + def __delattr__(self, name): + raise AttributeError("Cannot delete SafeDeserialization attributes") diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 08768f54448c..b5a4ba2871a0 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -37,6 +37,7 @@ from nemo.lightning.io.artifact.base import Artifact from nemo.lightning.io.capture import IOProtocol from nemo.lightning.io.connector import ModelConnector +from nemo.lightning.io.deserialize import SafeDeserialization from nemo.lightning.io.fdl_torch import enable as _enable_ext from nemo.lightning.io.to_config import to_config from nemo.utils import logging @@ -741,7 +742,16 @@ def analyze(config: fdl.Config, prefix: str): return updated -def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = None, build: bool = True) -> CkptType: +def load( + path: Path, + output_type: Type[CkptType] = Any, + subpath: Optional[str] = None, + build: bool = True, + trust_remote_code: bool = False, + max_depth: int = 100, + max_string_length: int = 100_000, + max_collection_size: int = 10_000, +) -> CkptType: """ Loads a configuration from a pickle file and constructs an object of the specified type. @@ -750,6 +760,13 @@ def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = output_type (Type[CkptType]): The type of the object to be constructed from the loaded data. subpath (Optional[str]): Subpath to selectively load only specific objects inside the output_type. Defaults to None. + build (bool): Whether to build the config into an object. Defaults to True. + trust_remote_code (bool): Whether to allow custom code from untrusted sources to be loaded. Can be used to + load models containing custom code. SECURITY WARNING: This could execute arbitrary code. Only enable this if you + trust the source of the model. Defaults to False. + max_depth (int): Maximum recursion depth for nested structures. Defaults to 100. + max_string_length (int): Maximum allowed string length. Defaults to 100,000. + max_collection_size (int): Maximum allowed collection size. Defaults to 10,000. Returns ------- @@ -758,9 +775,16 @@ def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = Raises ------ FileNotFoundError: If the specified file does not exist. + DeserializationError: If deserialization fails due to security violations or resource limits. + SecurityViolationError: If a security violation is detected during deserialization. + ResourceLimitError: If resource limits are exceeded during deserialization. Example: + # Safe mode (default) - only allows trusted modules loaded_model = load("/path/to/model", output_type=MyModel) + + # Unrestricted mode - allows all modules (use with caution) + loaded_model = load("/path/to/model", output_type=MyModel, trust_remote_code=True) """ _path = Path(path) _thread_local.output_dir = _path @@ -803,7 +827,14 @@ def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = if root_key: json_config["root"]["key"] = root_key - config = serialization.Deserialization(json_config).result + config = SafeDeserialization( + json_config, + trust_remote_code=trust_remote_code, + max_depth=max_depth, + max_string_length=max_string_length, + max_collection_size=max_collection_size, + ).result + _artifact_transform_load(config, path) drop_unexpected_params(config) diff --git a/tests/lightning/_io/test_safe_serialization.py b/tests/lightning/_io/test_safe_serialization.py new file mode 100644 index 000000000000..965fe7b3376f --- /dev/null +++ b/tests/lightning/_io/test_safe_serialization.py @@ -0,0 +1,388 @@ +import sys +import threading +from pathlib import Path, PosixPath, WindowsPath + +import pytest + +from nemo.lightning.io.deserialize import DeserializationError, SafeDeserialization, SafePyrefPolicy + + +class TestSafePyrefPolicy: + """Tests for the SafePyrefPolicy class.""" + + def test_trusted_modules(self): + """Test that only trusted modules are allowed when trust_remote_code=False.""" + policy = SafePyrefPolicy(trust_remote_code=False) + + # Should allow trusted modules and their submodules + assert policy.allows_import("nemo", "Model") # Base module + assert policy.allows_import("nemo.core", "Model") # Submodule + assert policy.allows_import("torch", "nn") # Another trusted module + assert policy.allows_import("torch.nn", "Module") # Submodule + + # Should block untrusted modules + assert not policy.allows_import("os", "system") + assert not policy.allows_import("subprocess", "run") + assert not policy.allows_import("untrusted_module", "function") + + def test_blocked_modules(self): + """Test that blocked modules are always blocked, even in unsafe mode.""" + policy = SafePyrefPolicy(trust_remote_code=True) + + # System and OS operations + assert not policy.allows_import("os", "system") + assert not policy.allows_import("os.path", "exists") + assert not policy.allows_import("sys", "exit") + assert not policy.allows_import("subprocess", "run") + assert not policy.allows_import("subprocess", "Popen") + assert not policy.allows_import("shutil", "rmtree") + + # Serialization and code execution + assert not policy.allows_import("pickle", "loads") + assert not policy.allows_import("marshal", "loads") + assert not policy.allows_import("shelve", "open") + assert not policy.allows_import("code", "InteractiveInterpreter") + assert not policy.allows_import("codeop", "compile_command") + + # File and I/O operations + assert not policy.allows_import("io", "open") + assert not policy.allows_import("tempfile", "NamedTemporaryFile") + assert not policy.allows_import("pathlib", "Path") + assert not policy.allows_import("zipfile", "ZipFile") + assert not policy.allows_import("tarfile", "TarFile") + + # Network and IPC + assert not policy.allows_import("socket", "socket") + assert not policy.allows_import("asyncio", "create_task") + assert not policy.allows_import("multiprocessing", "Process") + assert not policy.allows_import("threading", "Thread") + + # System information and configuration + assert not policy.allows_import("platform", "system") + assert not policy.allows_import("pwd", "getpwnam") + assert not policy.allows_import("grp", "getgrnam") + assert not policy.allows_import("resource", "getrlimit") + + # Package management and imports + assert not policy.allows_import("importlib", "import_module") + assert not policy.allows_import("pkg_resources", "require") + assert not policy.allows_import("setuptools", "setup") + assert not policy.allows_import("distutils", "core") + + # Low-level system access + assert not policy.allows_import("ctypes", "CDLL") + assert not policy.allows_import("mmap", "mmap") + assert not policy.allows_import("fcntl", "fcntl") + assert not policy.allows_import("signal", "signal") + + # Debug and development + assert not policy.allows_import("pdb", "set_trace") + assert not policy.allows_import("trace", "Trace") + assert not policy.allows_import("gc", "collect") + + # XML processing (potential security risks) + assert not policy.allows_import("xml.etree.ElementTree", "parse") + assert not policy.allows_import("xml.sax", "parse") + assert not policy.allows_import("xml.dom", "parseString") + + # Encoding and crypto (potential security risks) + assert not policy.allows_import("base64", "b64decode") + assert not policy.allows_import("codecs", "encode") + assert not policy.allows_import("crypt", "crypt") + + # Terminal and process control + assert not policy.allows_import("pty", "spawn") + assert not policy.allows_import("tty", "setraw") + assert not policy.allows_import("termios", "tcgetattr") + assert not policy.allows_import("pipes", "Template") + + # System logging + assert not policy.allows_import("syslog", "syslog") + + # Web-related (potential security risks) + assert not policy.allows_import("wsgiref", "simple_server") + assert not policy.allows_import("http.server", "HTTPServer") + assert not policy.allows_import("urllib.request", "urlopen") + + # Module and class inspection + assert not policy.allows_import("inspect", "getsource") + assert not policy.allows_import("dis", "dis") + assert not policy.allows_import("ast", "parse") + + def test_dangerous_builtins(self): + """Test that dangerous builtins are always blocked.""" + policy = SafePyrefPolicy(trust_remote_code=True) + + # Should block dangerous builtins even in unsafe mode + assert not policy.allows_import("builtins", "eval") + assert not policy.allows_import("builtins", "exec") + assert not policy.allows_import("builtins", "__import__") + + def test_name_validation(self): + """Test validation of module and symbol names.""" + policy = SafePyrefPolicy() + + # Valid names + assert policy._validate_name("valid_name") + assert policy._validate_name("ValidName123") + + # Invalid names + assert not policy._validate_name("_hidden") + assert not policy._validate_name("1invalid") + assert not policy._validate_name("invalid-name") + assert not policy._validate_name("../path") + + +class TestSafeDeserialization: + """Tests for the SafeDeserialization class.""" + + def test_basic_deserialization(self): + """Test basic deserialization of safe types.""" + safe_data = { + "root": {"type": "leaf", "value": {"key": "value", "number": 42, "boolean": True}}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + result = SafeDeserialization(safe_data).result + assert result == {"key": "value", "number": 42, "boolean": True} + + def test_blocked_imports(self): + """Test that dangerous imports are blocked.""" + dangerous_data = { + "root": {"type": {"type": "pyref", "module": "os", "name": "system"}, "items": [], "metadata": None}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + with pytest.raises(DeserializationError) as exc_info: + SafeDeserialization(dangerous_data).result + + # Verify that the error message indicates a security violation + assert "not permitted by the active Python reference policy" in str(exc_info.value) + + def test_resource_limits(self): + """Test resource limits during deserialization.""" + # Create large nested structure + large_data = { + "root": {"type": "leaf", "value": "x" * 1_000_000}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + with pytest.raises(DeserializationError) as exc_info: + SafeDeserialization(large_data, max_string_length=100_000).result + + # Verify that the error message indicates a string length violation + assert "String length" in str(exc_info.value) + assert "exceeds limit" in str(exc_info.value) + + +class TestSecurityFeatures: + """Tests for various security features.""" + + def test_dangerous_special_methods(self): + """Test blocking of objects with dangerous special methods.""" + + class DangerousClass: + def __init__(self): + pass + + def __call__(self): + return "Dangerous" + + dangerous_obj = DangerousClass() + policy = SafePyrefPolicy(trust_remote_code=False) + + assert not policy.allows_value(dangerous_obj) + + def test_callable_objects(self): + """Test blocking of callable objects.""" + + def dangerous_function(): + return "Dangerous" + + policy = SafePyrefPolicy(trust_remote_code=False) + assert not policy.allows_value(dangerous_function) + + def test_module_references(self): + """Test blocking of objects with references to untrusted modules.""" + + class ModuleReferenceClass: + def __init__(self): + self.__module__ = "os" + + obj = ModuleReferenceClass() + policy = SafePyrefPolicy(trust_remote_code=False) + + assert not policy.allows_value(obj) + + def test_suspicious_strings(self): + """Test detection of suspicious strings.""" + policy = SafePyrefPolicy() + + # Should block strings with suspicious patterns + suspicious_strings = [ + "__dangerous__", + "\\x41\\x42\\x43", # Hex escape + "\\141\\142\\143", # Octal escape + "\\u0041\\u0042", # Unicode escape + ] + + for s in suspicious_strings: + assert not policy.allows_value(s) + + def test_nested_dangerous_objects(self): + """Test detection of dangerous objects nested in safe containers.""" + dangerous_cases = [ + # Lambda in list + { + "root": {"type": "leaf", "value": {"safe_key": [1, 2, {"unsafe": lambda x: x}]}}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + }, + # Function in dict + { + "root": {"type": "leaf", "value": {"unsafe": (lambda: None)}}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + }, + # Method in nested structure + { + "root": {"type": "leaf", "value": [{"nested": {"deep": {"unsafe": str.strip}}}]}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + }, + ] + + for case in dangerous_cases: + with pytest.raises(DeserializationError) as exc_info: + SafeDeserialization(case).result + assert "Callable objects are not allowed" in str(exc_info.value) + + def test_type_confusion(self): + """Test prevention of type confusion attacks.""" + confusing_data = { + "root": {"type": "leaf", "value": type("DynamicType", (), {"__call__": lambda self: None})}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + with pytest.raises(DeserializationError): + SafeDeserialization(confusing_data).result + + def test_non_string_dict_keys(self): + """Test that dictionary keys must be strings.""" + invalid_data = { + "root": {"type": "leaf", "value": {1: "value"}}, # numeric key + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + with pytest.raises(DeserializationError) as exc_info: + SafeDeserialization(invalid_data).result + assert "Dictionary keys must be strings" in str(exc_info.value) + + +class TestErrorHandling: + """Tests for error handling and reporting.""" + + def test_deserialization_error(self): + """Test proper error handling during deserialization.""" + invalid_data = { + "root": {"type": "invalid_type", "items": [], "metadata": None}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + with pytest.raises(DeserializationError): + SafeDeserialization(invalid_data).result + + def test_security_violation_error(self): + """Test security violation error handling.""" + policy = SafePyrefPolicy(trust_remote_code=False) + + # Should return False for dangerous imports + assert not policy.allows_import("os", "system") + + # The SecurityViolationError should be raised during actual deserialization + dangerous_data = { + "root": {"type": {"type": "pyref", "module": "os", "name": "system"}, "items": [], "metadata": None}, + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + with pytest.raises(DeserializationError) as exc_info: + SafeDeserialization(dangerous_data).result + + # Verify that the error message indicates a security violation + assert "not permitted" in str(exc_info.value).lower() + + def test_resource_limit_error(self): + """Test resource limit error handling.""" + policy = SafePyrefPolicy(max_collection_size=5) + + with pytest.raises(ValueError): + policy._check_collection_size([1, 2, 3, 4, 5, 6]) + + +class TestThreadSafety: + """Tests for thread safety.""" + + def test_concurrent_deserialization(self): + """Test concurrent deserialization operations.""" + safe_data = {"root": {"type": "leaf", "value": 42}, "objects": {}, "refcounts": {}, "version": "0.0.1"} + + results = [] + errors = [] + + def deserialize(): + try: + result = SafeDeserialization(safe_data).result + results.append(result) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=deserialize) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results) == 10 + assert all(r == 42 for r in results) + assert len(errors) == 0 + + +class TestPathSerialization: + """Tests for Path object serialization and deserialization.""" + + def test_path_platform_specific(self): + """Test that Path objects work correctly on different platforms.""" + path_data = { + "root": {"type": "leaf", "value": Path("/some/path")}, # Changed to leaf type # Direct Path object + "objects": {}, + "refcounts": {}, + "version": "0.0.1", + } + + result = SafeDeserialization(path_data).result + assert isinstance(result, Path) + + # Check that we get the right platform-specific path type + if sys.platform == 'win32': + assert isinstance(result, WindowsPath) + assert str(result) == "\\some\\path" + else: + assert isinstance(result, PosixPath) + assert str(result) == "/some/path"