Skip to content

Commit

Permalink
refactor: introduce state-object
Browse files Browse the repository at this point in the history
  • Loading branch information
betaboon committed Nov 17, 2024
1 parent d6b3bb1 commit cd59dbc
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 154 deletions.
5 changes: 2 additions & 3 deletions mocket/entry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections.abc

import mocket.state
from mocket.compat import encode_to_bytes


Expand Down Expand Up @@ -41,10 +42,8 @@ def can_handle(data):
return True

def collect(self, data):
from mocket import Mocket

req = self.request_cls(data)
Mocket.collect(req)
mocket.state.state.collect(req)

def get_response(self):
response = self.responses[self.response_index]
Expand Down
11 changes: 5 additions & 6 deletions mocket/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

import mocket.state

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
Expand Down Expand Up @@ -42,12 +44,11 @@ def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import MocketSocket, create_connection, socketpair
from mocket.ssl import FakeSSLContext

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir
mocket.state.state._namespace = namespace
mocket.state.state._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
Expand Down Expand Up @@ -91,8 +92,6 @@ def enable(


def disable() -> None:
from mocket.mocket import Mocket

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
Expand Down Expand Up @@ -122,7 +121,7 @@ def disable() -> None:
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
mocket.state.state.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
6 changes: 3 additions & 3 deletions mocket/io.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import io
import os

import mocket.state


class MocketSocketCore(io.BytesIO):
def __init__(self, address) -> None:
self._address = address
super().__init__()

def write(self, content):
from mocket import Mocket

super().write(content)

_, w_fd = Mocket.get_pair(self._address)
_, w_fd = mocket.state.state.get_pair(self._address)
if w_fd:
os.write(w_fd, content)
101 changes: 9 additions & 92 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,18 @@
import collections
import itertools
import os
from typing import Optional, Tuple
from typing import cast

import mocket.inject
import mocket.state

# NOTE this is here for backwards-compat to keep old import-paths working
from mocket.socket import MocketSocket as MocketSocket


class Mocket:
_socket_pairs = {}
_address = (None, None)
_entries = collections.defaultdict(list)
_requests = []
_namespace = str(id(_entries))
_truesocket_recording_dir = None
class _Mocket(mocket.state.MocketState):
def __init__(self) -> None:
self.enable = mocket.inject.enable
self.disable = mocket.inject.disable

@classmethod
def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
"""
Given the id() of the caller, return a pair of file descriptors
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
return cls._socket_pairs.get(address, (None, None))

@classmethod
def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
"""
Store a pair of file descriptors under the key `id_`
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
cls._socket_pairs[address] = pair

@classmethod
def register(cls, *entries):
for entry in entries:
cls._entries[entry.location].append(entry)

@classmethod
def get_entry(cls, host, port, data):
host = host or Mocket._address[0]
port = port or Mocket._address[1]
entries = cls._entries.get((host, port), [])
for entry in entries:
if entry.can_handle(data):
return entry

@classmethod
def collect(cls, data):
cls.request_list().append(data)

@classmethod
def reset(cls):
for r_fd, w_fd in cls._socket_pairs.values():
os.close(r_fd)
os.close(w_fd)
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []

@classmethod
def last_request(cls):
if cls.has_requests():
return cls.request_list()[-1]

@classmethod
def request_list(cls):
return cls._requests

@classmethod
def remove_last_request(cls):
if cls.has_requests():
del cls._requests[-1]

@classmethod
def has_requests(cls):
return bool(cls.request_list())

@classmethod
def get_namespace(cls):
return cls._namespace

@staticmethod
def enable(namespace=None, truesocket_recording_dir=None):
mocket.inject.enable(namespace, truesocket_recording_dir)

@staticmethod
def disable():
mocket.inject.disable()

@classmethod
def get_truesocket_recording_dir(cls):
return cls._truesocket_recording_dir

@classmethod
def assert_fail_if_entries_not_served(cls):
"""Mocket checks that all entries have been served at least once."""
if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
raise AssertionError("Some Mocket entries have not been served")
Mocket = cast(_Mocket, mocket.state.state)
Mocket.enable = mocket.inject.enable
Mocket.disable = mocket.inject.disable
8 changes: 3 additions & 5 deletions mocket/mocketizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mocket.inject
from mocket.mode import MocketMode
from mocket.utils import get_mocketize

Expand All @@ -23,9 +24,7 @@ def __init__(
)

def enter(self):
from mocket import Mocket

Mocket.enable(
mocket.inject.enable(
namespace=self.namespace,
truesocket_recording_dir=self.truesocket_recording_dir,
)
Expand All @@ -39,9 +38,8 @@ def __enter__(self):
def exit(self):
if self.instance:
self.check_and_call("mocketize_teardown")
from mocket import Mocket

Mocket.disable()
mocket.inject.disable()

def __exit__(self, type, value, tb):
self.exit()
Expand Down
10 changes: 5 additions & 5 deletions mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from h11 import SERVER, Connection, Data
from h11 import Request as H11Request

import mocket.state
from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes
from mocket.entry import MocketEntry
from mocket.mocket import Mocket

STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()}
CRLF = "\r\n"
Expand Down Expand Up @@ -165,7 +165,7 @@ def collect(self, data):

decoded_data = decode_from_bytes(data)
if not decoded_data.startswith(Entry.METHODS):
Mocket.remove_last_request()
mocket.state.state.remove_last_request()
self._sent_data += data
consume_response = False
else:
Expand All @@ -188,7 +188,7 @@ def can_handle(self, data):
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
method, path, _ = self._parse_requestline(requestline)
except ValueError:
return self is getattr(Mocket, "_last_entry", None)
return self is getattr(mocket.state.state, "_last_entry", None)

uri = urlsplit(path)
can_handle = uri.path == self.path and method == self.method
Expand All @@ -198,7 +198,7 @@ def can_handle(self, data):
self.query, **kw
)
if can_handle:
Mocket._last_entry = self
mocket.state.state._last_entry = self
return can_handle

@staticmethod
Expand Down Expand Up @@ -234,7 +234,7 @@ def register(cls, method, uri, *responses, **config):
if config["add_trailing_slash"] and not urlsplit(uri).path:
uri += "/"

Mocket.register(
mocket.state.state.register(
cls(uri, method, responses, match_querystring=config["match_querystring"])
)

Expand Down
4 changes: 2 additions & 2 deletions mocket/mockredis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from itertools import chain

import mocket.state
from mocket.compat import (
decode_from_bytes,
encode_to_bytes,
shsplit,
)
from mocket.entry import MocketEntry
from mocket.mocket import Mocket


class Request:
Expand Down Expand Up @@ -80,7 +80,7 @@ def register(cls, addr, command, *responses):
r if isinstance(r, BaseException) else cls.response_cls(r)
for r in responses
]
Mocket.register(cls(addr, command, responses))
mocket.state.state.register(cls(addr, command, responses))

@classmethod
def register_response(cls, command, response, addr=None):
Expand Down
5 changes: 2 additions & 3 deletions mocket/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Any, ClassVar

import mocket.state
from mocket.exceptions import StrictMocketException

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -31,11 +32,9 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool:

@staticmethod
def raise_not_allowed() -> NoReturn:
from mocket.mocket import Mocket

current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in Mocket._entries.items()
for location, entries in mocket.state.state._entries.items()
]
formatted_entries = "\n".join(
[f" {location}:\n {entries}" for location, entries in current_entries]
Expand Down
14 changes: 8 additions & 6 deletions mocket/plugins/httpretty/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from mocket import Mocket, mocketize
import mocket.inject
import mocket.state
from mocket import mocketize
from mocket.async_mocket import async_mocketize
from mocket.compat import ENCODING
from mocket.mockhttp import Entry as MocketHttpEntry
Expand Down Expand Up @@ -45,9 +47,9 @@ class Entry(MocketHttpEntry):
httprettified = mocketize
async_httprettified = async_mocketize

enable = Mocket.enable
disable = Mocket.disable
reset = Mocket.reset
enable = mocket.inject.enable
disable = mocket.inject.disable
reset = mocket.state.state.reset

GET = Entry.GET
PUT = Entry.PUT
Expand Down Expand Up @@ -102,9 +104,9 @@ class MocketHTTPretty:

def __getattr__(self, name):
if name == "last_request":
return Mocket.last_request()
return mocket.state.state.last_request()
if name == "latest_requests":
return Mocket.request_list()
return mocket.state.state.request_list()
return getattr(Entry, name)


Expand Down
11 changes: 6 additions & 5 deletions mocket/plugins/pook_mock_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
except ModuleNotFoundError:
MockEngine = object

from mocket.mocket import Mocket
import mocket.inject
import mocket.state
from mocket.mockhttp import Entry, Response


Expand Down Expand Up @@ -35,7 +36,7 @@ def single_register(
[Response(body=body, status=status, headers=headers)],
match_querystring=match_querystring,
)
Mocket.register(entry)
mocket.state.state.register(entry)
return entry


Expand Down Expand Up @@ -64,12 +65,12 @@ def mocket_mock_fun(*args, **kwargs):
class MocketInterceptor(BaseInterceptor):
@staticmethod
def activate():
Mocket.disable()
Mocket.enable()
mocket.inject.disable()
mocket.inject.enable()

@staticmethod
def disable():
Mocket.disable()
mocket.inject.disable()

# Store plugins engine
self.engine = engine
Expand Down
Loading

0 comments on commit cd59dbc

Please sign in to comment.