Skip to content

Commit

Permalink
revert to singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
betaboon committed Nov 17, 2024
1 parent cd59dbc commit 5833f65
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 131 deletions.
4 changes: 2 additions & 2 deletions mocket/entry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections.abc

import mocket.state
from mocket.compat import encode_to_bytes
from mocket.mocket import Mocket


class MocketEntry:
Expand Down Expand Up @@ -43,7 +43,7 @@ def can_handle(data):

def collect(self, data):
req = self.request_cls(data)
mocket.state.state.collect(req)
Mocket.collect(req)

def get_response(self):
response = self.responses[self.response_index]
Expand Down
11 changes: 6 additions & 5 deletions mocket/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

import mocket.state

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
Expand Down Expand Up @@ -44,11 +42,12 @@ def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import MocketSocket, create_connection, socketpair
from mocket.ssl import FakeSSLContext

mocket.state.state._namespace = namespace
mocket.state.state._truesocket_recording_dir = truesocket_recording_dir
Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
Expand Down Expand Up @@ -92,6 +91,8 @@ def enable(


def disable() -> None:
from mocket.mocket import Mocket

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
Expand Down Expand Up @@ -121,7 +122,7 @@ def disable() -> None:
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
mocket.state.state.reset()
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
4 changes: 2 additions & 2 deletions mocket/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
import os

import mocket.state
from mocket.mocket import Mocket


class MocketSocketCore(io.BytesIO):
Expand All @@ -12,6 +12,6 @@ def __init__(self, address) -> None:
def write(self, content):
super().write(content)

_, w_fd = mocket.state.state.get_pair(self._address)
_, w_fd = Mocket.get_pair(self._address)
if w_fd:
os.write(w_fd, content)
105 changes: 95 additions & 10 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,103 @@
from typing import cast
from __future__ import annotations

import collections
import itertools
import os
from typing import TYPE_CHECKING, ClassVar

import mocket.inject
import mocket.state

# NOTE this is here for backwards-compat to keep old import-paths working
from mocket.socket import MocketSocket as MocketSocket
# from mocket.socket import MocketSocket as MocketSocket

if TYPE_CHECKING:
from mocket.entry import MocketEntry
from mocket.types import Address


class Mocket:
_socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {}
_address: ClassVar[Address] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
_requests: ClassVar[list] = []
_namespace: ClassVar[str] = str(id(_entries))
_truesocket_recording_dir: ClassVar[str | None] = None

enable = mocket.inject.enable
disable = mocket.inject.disable

@classmethod
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
"""
Given the id() of the caller, return a pair of file descriptors
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
return cls._socket_pairs.get(address, (None, None))

@classmethod
def set_pair(cls, address: Address, pair: tuple[int, int]) -> None:
"""
Store a pair of file descriptors under the key `id_`
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
cls._socket_pairs[address] = pair

@classmethod
def register(cls, *entries: MocketEntry) -> None:
for entry in entries:
cls._entries[entry.location].append(entry)

@classmethod
def get_entry(cls, host: str, port: int, data) -> MocketEntry | None:
host = host or cls._address[0]
port = port or cls._address[1]
entries = cls._entries.get((host, port), [])
for entry in entries:
if entry.can_handle(data):
return entry
return None

@classmethod
def collect(cls, data) -> None:
cls._requests.append(data)

@classmethod
def reset(cls) -> None:
for r_fd, w_fd in cls._socket_pairs.values():
os.close(r_fd)
os.close(w_fd)
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []

@classmethod
def last_request(cls):
if cls.has_requests():
return cls._requests[-1]

@classmethod
def request_list(cls):
return cls._requests

@classmethod
def remove_last_request(cls) -> None:
if cls.has_requests():
del cls._requests[-1]

@classmethod
def has_requests(cls) -> bool:
return bool(cls.request_list())

class _Mocket(mocket.state.MocketState):
def __init__(self) -> None:
self.enable = mocket.inject.enable
self.disable = mocket.inject.disable
@classmethod
def get_namespace(cls) -> str:
return cls._namespace

@classmethod
def get_truesocket_recording_dir(cls) -> str | None:
return cls._truesocket_recording_dir

Mocket = cast(_Mocket, mocket.state.state)
Mocket.enable = mocket.inject.enable
Mocket.disable = mocket.inject.disable
@classmethod
def assert_fail_if_entries_not_served(cls) -> None:
"""Mocket checks that all entries have been served at least once."""
if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
raise AssertionError("Some Mocket entries have not been served")
10 changes: 5 additions & 5 deletions mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from h11 import SERVER, Connection, Data
from h11 import Request as H11Request

import mocket.state
from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes
from mocket.entry import MocketEntry
from mocket.mocket import Mocket

STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()}
CRLF = "\r\n"
Expand Down Expand Up @@ -165,7 +165,7 @@ def collect(self, data):

decoded_data = decode_from_bytes(data)
if not decoded_data.startswith(Entry.METHODS):
mocket.state.state.remove_last_request()
Mocket.remove_last_request()
self._sent_data += data
consume_response = False
else:
Expand All @@ -188,7 +188,7 @@ def can_handle(self, data):
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
method, path, _ = self._parse_requestline(requestline)
except ValueError:
return self is getattr(mocket.state.state, "_last_entry", None)
return self is getattr(Mocket, "_last_entry", None)

uri = urlsplit(path)
can_handle = uri.path == self.path and method == self.method
Expand All @@ -198,7 +198,7 @@ def can_handle(self, data):
self.query, **kw
)
if can_handle:
mocket.state.state._last_entry = self
Mocket._last_entry = self
return can_handle

@staticmethod
Expand Down Expand Up @@ -234,7 +234,7 @@ def register(cls, method, uri, *responses, **config):
if config["add_trailing_slash"] and not urlsplit(uri).path:
uri += "/"

mocket.state.state.register(
Mocket.register(
cls(uri, method, responses, match_querystring=config["match_querystring"])
)

Expand Down
4 changes: 2 additions & 2 deletions mocket/mockredis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from itertools import chain

import mocket.state
from mocket.compat import (
decode_from_bytes,
encode_to_bytes,
shsplit,
)
from mocket.entry import MocketEntry
from mocket.mocket import Mocket


class Request:
Expand Down Expand Up @@ -80,7 +80,7 @@ def register(cls, addr, command, *responses):
r if isinstance(r, BaseException) else cls.response_cls(r)
for r in responses
]
mocket.state.state.register(cls(addr, command, responses))
Mocket.register(cls(addr, command, responses))

@classmethod
def register_response(cls, command, response, addr=None):
Expand Down
4 changes: 2 additions & 2 deletions mocket/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from typing import TYPE_CHECKING, Any, ClassVar

import mocket.state
from mocket.exceptions import StrictMocketException
from mocket.mocket import Mocket

if TYPE_CHECKING: # pragma: no cover
from typing import NoReturn
Expand Down Expand Up @@ -34,7 +34,7 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool:
def raise_not_allowed() -> NoReturn:
current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in mocket.state.state._entries.items()
for location, entries in Mocket._entries.items()
]
formatted_entries = "\n".join(
[f" {location}:\n {entries}" for location, entries in current_entries]
Expand Down
8 changes: 4 additions & 4 deletions mocket/plugins/httpretty/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import mocket.inject
import mocket.state
from mocket import mocketize
from mocket.async_mocket import async_mocketize
from mocket.compat import ENCODING
from mocket.mocket import Mocket
from mocket.mockhttp import Entry as MocketHttpEntry
from mocket.mockhttp import Request as MocketHttpRequest
from mocket.mockhttp import Response as MocketHttpResponse
Expand Down Expand Up @@ -49,7 +49,7 @@ class Entry(MocketHttpEntry):

enable = mocket.inject.enable
disable = mocket.inject.disable
reset = mocket.state.state.reset
reset = Mocket.reset

GET = Entry.GET
PUT = Entry.PUT
Expand Down Expand Up @@ -104,9 +104,9 @@ class MocketHTTPretty:

def __getattr__(self, name):
if name == "last_request":
return mocket.state.state.last_request()
return Mocket.last_request()
if name == "latest_requests":
return mocket.state.state.request_list()
return Mocket.request_list()
return getattr(Entry, name)


Expand Down
4 changes: 2 additions & 2 deletions mocket/plugins/pook_mock_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
MockEngine = object

import mocket.inject
import mocket.state
from mocket.mocket import Mocket
from mocket.mockhttp import Entry, Response


Expand Down Expand Up @@ -36,7 +36,7 @@ def single_register(
[Response(body=body, status=status, headers=headers)],
match_querystring=match_querystring,
)
mocket.state.state.register(entry)
Mocket.register(entry)
return entry


Expand Down
Loading

0 comments on commit 5833f65

Please sign in to comment.