def run_blocking(self, enable_tracing=False):
    """
    Run the client through an ``EventLoopRunner`` from the calling thread.

    :param enable_tracing: Forwarded to ``traceability.enable_traceable`` for
        the duration of the run.
    """
    # Same nesting order as before: the runner is entered first, then tracing.
    with EventLoopRunner() as runner, traceability.enable_traceable(enable_tracing):
        runner.run(self.run())


def run_detached(self, *args, **kwargs):
    """
    Run the client in a separate thread.

    Positional and keyword arguments are forwarded unchanged to
    ``run_blocking`` when the thread starts.
    """
    if self.instance_thread:
        self.logger.warning("Client instance already running - stopping old instance")
        self.stop()

    worker = threading.Thread(target=self.run_blocking, args=args, kwargs=kwargs)
    self.instance_thread = worker
    worker.start()
@property
@traceable(with_retval=True, with_stack=True)
def connected(self) -> bool:
    """
    Check if the client is connected to the server.

    The stored ``_connected`` flag is returned as-is while it is falsy;
    otherwise the answer is whatever ``is_external_connected()`` reports.
    """
    if not self._connected:
        return self._connected

    return self.is_external_connected()


@connected.setter
@traceable(with_args=True, with_stack=True)
def connected(self, value: bool):
    """ Store the connection flag that the ``connected`` getter reads. """
    self._connected = value


def is_external_connected(self):
    """
    Check if the client is connected to an external device.

    This implementation always reports ``True``.
    """
    return True
async def restart_lifetime(self, client: Client) -> None:
    """
    Restart the lifetime of ``client``.

    :param client: The client whose lifetime is stopped and then started again.
    """
    # Order matters: the current lifetime is stopped before a new one is started.
    await self.stop_lifetime(client)
    await self.start_lifetime(client)
@staticmethod
def strip_log_file(file: Path, max_size: int = 100 * 1024 * 1024):
    """Strip a log file in place so that at most ``max_size`` bytes remain.

    The most recent content (the end of the file) is kept. The file is
    handled as raw bytes: seeking to a computed offset is only well defined
    in binary mode, and a byte offset may fall inside a multi-byte character,
    which made the previous text-mode implementation fail while decoding.

    The kept tail is aligned to a line boundary: a partial first line is
    dropped, so the result may be slightly smaller than ``max_size``. If the
    tail contains no newline at all, exactly the last ``max_size`` bytes are
    kept.

    :param file: Path of the log file. A missing file is ignored.
    :param max_size: Maximum number of bytes to keep. Values below zero are
        treated as zero.
    """
    try:
        size = file.stat().st_size
    except FileNotFoundError:
        return

    keep = max(max_size, 0)

    if size <= keep:
        return

    with open(file, "rb+") as f:
        # Start one byte before the cut so we can tell whether the cut falls
        # exactly on the start of a line (size > keep, so the offset is >= 0).
        f.seek(size - keep - 1)
        data = f.read()

        if data[:1] == b"\n":
            # The cut is already on a line boundary; keep the full tail.
            tail = data[1:]
        else:
            newline = data.find(b"\n")
            # Drop the partial first line, or fall back to the raw tail when
            # there is no line break to align to.
            tail = data[newline + 1:] if newline != -1 else data[1:]

        f.seek(0)
        f.write(tail)
        f.truncate()
+ """ + + def decorator(func): + if not callable(func): + raise ValueError("exception_as_value decorator must be used on a callable") + + @wraps(func) + def wrapper(*fargs, **fkwargs): + try: + return func(*fargs, **fkwargs) + except Exception as e: + return e if not return_none else None + + return wrapper + + if args and callable(args[0]): + return decorator(args[0]) + + return decorator + + +def traceable_location_from_func(func, *args, **kwargs): + # If the function is a method, we store it on the instance + # therefore we suffix the key with the function name. + if hasattr(func, "__self__"): + return func.__self__, f"__traceability__{func.__name__}", True + + # For property functions if the first argument is a class + # and that class has the property, we store it on the class + if args and hasattr(args[0], '__class__') and hasattr(args[0].__class__, func.__name__): + return args[0], f"__traceability__{func.__name__}", True + + # Otherwise, we store it on the function itself + return func, f"__traceability__", False + + +# Collects traceability information for a function +# Into an object that can be used to trace the function call +def traceable(*args, record_calls=False, with_stack=False, with_args=False, with_retval=False, record_count=10, + **kwargs): + """ + :param record_calls: Whether to record the number of calls to the function + :param with_stack: Whether to record the stack of the function + :param with_args: Whether to record the arguments of the function + :param with_retval: Whether to record the return value of the function + :param record_count: The number of records to keep + """ + + should_record_calls = record_calls or with_stack or with_args or with_retval + + def decorator(func): + if not callable(func): + raise ValueError("traceable decorator must be used on a callable") + + # All functions also get a static key + _, key, _ = traceable_location_from_func(func) + + setattr(func, "__traceability__", Traceability( + last_called=None, + 
call_record=deque(maxlen=10) if should_record_calls else None + )) + + @wraps(func) + def wrapper(*fargs, **fkwargs): + # If traceability is disabled, we just call the function + # Getting a smaller runtime overhead. + if not _traceability_enabled.get(): + return func(*fargs, **fkwargs) + + obj, trace_key, remove_first_arg = traceable_location_from_func(func, *fargs, **fkwargs) + + if hasattr(obj, trace_key): + traceability = getattr(obj, trace_key) + else: + traceability = Traceability( + last_called=None, + call_record=deque(maxlen=10) if should_record_calls else None + ) + + setattr(obj, trace_key, traceability) + + traceability.last_called = time.time() + + retval = None + + try: + retval = func(*fargs, **fkwargs) + return retval + finally: + if should_record_calls: + record = TraceabilityRecord( + called_at=traceability.last_called, + args=(fargs[1:] if remove_first_arg else fargs) if with_args else None, + kwargs=fkwargs if with_args else None, + retval=retval if with_retval else None, + stack=None, + ) + + if with_stack: + record.stack = traceback.format_stack() + + traceability.call_record.append(record) + + return wrapper + + if args and callable(args[0]): + return decorator(args[0]) + + return decorator + + +@exception_as_value(return_none=True) +def from_func(func): + obj, key, _ = traceable_location_from_func(func) + + if not hasattr(obj, key): + raise ValueError("Function does not have traceability information") + + traceability = getattr(obj, key) + + if not isinstance(traceability, Traceability): + raise ValueError("Traceability information is not of the correct type") + + return traceability + + +@exception_as_value(return_none=True) +def from_property(prop: property): + return from_func(prop.fget), from_func(prop.fset) + + +def from_class_instance(cls): + # Find all properties starting with __traceability__ + traces = { + name[len("__traceability__"):]: value for name, value in cls.__dict__.items() + if name.startswith("__traceability__") + } + + # 
Return name: Traceability + return { + name: value for name, value in traces.items() + if isinstance(value, Traceability) + } + + +def from_class_static(cls): + # Find all callables that have the property __traceability__ + traces = { + name: value for name, value in cls.__dict__.items() + if hasattr(value, "__traceability__") + } + + return { + name: from_func(value) for name, value in traces.items() + } + + +@exception_as_value(return_none=True) +def from_class(cls): + if isinstance(cls, type): + return from_class_static(cls) + + return from_class_instance(cls) + + +@dataclasses.dataclass(slots=True) +class TraceabilityRecord: + called_at: float + args: Optional[tuple] = None + kwargs: Optional[dict] = None + retval: Optional[object] = None + stack: Optional[List[str]] = None + + +@dataclasses.dataclass(slots=True) +class Traceability: + last_called: Optional[float] + call_record: Optional[deque[TraceabilityRecord]] = None + + def stats(self): + return { + "last_called": self.last_called, + "delta_called": time.time() - self.last_called + } + + def get_call_record(self): + return list(self.call_record) if self.call_record else [] + + +__all__ = [ + "traceable", + "enable_traceable", + "from_func", + "from_property", + "from_class", + "Traceability", + "TraceabilityRecord" +] diff --git a/tests/test_traceability.py b/tests/test_traceability.py new file mode 100644 index 0000000..8ac9853 --- /dev/null +++ b/tests/test_traceability.py @@ -0,0 +1,156 @@ +import unittest + +from simplyprint_ws_client.client import Client +from simplyprint_ws_client.client.config import Config +from simplyprint_ws_client.utils import traceability + + +class TestTraceability(unittest.IsolatedAsyncioTestCase): + async def test_client_traceability(self): + class TestClient(Client): + async def init(self): + pass + + async def tick(self): + pass + + async def stop(self): + pass + + with traceability.enable_traceable(): + client = TestClient(config=Config.get_new()) + + class_traces = 
class TestTraceability(unittest.IsolatedAsyncioTestCase):
    """ Tests for the traceability helpers, on plain functions, methods, properties and the client. """

    async def test_client_traceability(self):
        """ The client's ``connected`` property is traced per instance. """

        class TestClient(Client):
            async def init(self):
                pass

            async def tick(self):
                pass

            async def stop(self):
                pass

        with traceability.enable_traceable():
            client = TestClient(config=Config.get_new())

            traces_by_name = traceability.from_class(client)
            self.assertIn("connected", traces_by_name)

            connected_trace = traces_by_name["connected"]

            # Accessed once via the event initializer.
            self.assertEqual(len(connected_trace.get_call_record()), 1)

            # Purge it.
            connected_trace.call_record.pop()

            async with client:
                client.connected = True

            self.assertEqual(len(connected_trace.get_call_record()), 1)

            setter_record = connected_trace.call_record.pop()
            self.assertEqual(setter_record.args, (True,))

    def test_record_class_isolation(self):
        """ Records of a traced method are kept separately for each instance. """

        class TestClass:
            @traceability.traceable(with_retval=True)
            def traceable_function(self, a: int, b: int) -> int:
                return a + b

        first = TestClass()
        second = TestClass()

        with traceability.enable_traceable():
            first.traceable_function(1, 2)
            first.traceable_function(2, 3)
            second.traceable_function(3, 4)

        first_trace = traceability.from_class(first)['traceable_function']
        second_trace = traceability.from_class(second)['traceable_function']

        # The bound method resolves to the instance record, the class attribute does not.
        self.assertEqual(first_trace, traceability.from_func(first.traceable_function))
        self.assertNotEqual(first_trace, traceability.from_func(TestClass.traceable_function))
        self.assertNotEqual(first_trace, second_trace)

        self.assertEqual(len(first_trace.get_call_record()), 2)
        self.assertEqual(len(second_trace.get_call_record()), 1)

    def test_record_traceability_property_getter_and_setter(self):
        """ Getter and setter of one property share a single trace on the instance. """

        class TestClass:
            def __init__(self):
                self._value = 0

            @property
            @traceability.traceable(with_retval=True)
            def value(self):
                return self._value

            @value.setter
            @traceability.traceable(with_args=True)
            def value(self, value):
                self._value = value

        instance = TestClass()

        with traceability.enable_traceable():
            instance.value = 1
            self.assertEqual(instance.value, 1)

        value_trace = traceability.from_class(instance).get('value')
        self.assertEqual(len(value_trace.get_call_record()), 2)

        # Oldest first: the setter call, then the getter call.
        set_record = value_trace.call_record.popleft()
        self.assertEqual(set_record.args, (1,))

        get_record = value_trace.call_record.popleft()
        self.assertEqual(get_record.retval, 1)

    def test_record_traceability(self):
        """ Only the requested parts (args, retval, stack) of a call are recorded. """

        @traceability.traceable(with_retval=True)
        def traceable_function_a(a: int, b: int) -> int:
            return a + b

        @traceability.traceable(with_args=True, with_stack=True)
        def traceable_function_b(a: int, b: int, **kwargs) -> int:
            return a * b

        with traceability.enable_traceable():
            traceable_function_a(1, 2)
            traceable_function_b(1, 2)

        trace_a = traceability.from_func(traceable_function_a)
        trace_b = traceability.from_func(traceable_function_b)

        self.assertEqual(len(trace_a.get_call_record()), 1)
        self.assertEqual(len(trace_b.get_call_record()), 1)

        record_a = trace_a.call_record.pop()
        record_b = trace_b.call_record.pop()

        self.assertIsNone(record_a.args)
        self.assertEqual(record_a.retval, 3)

        self.assertEqual(record_b.args, (1, 2))
        self.assertIsNone(record_b.retval)

        with traceability.enable_traceable():
            traceable_function_b(1, 2, extra_custom_arg="test")

        record_b = trace_b.call_record.pop()

        self.assertEqual(record_b.args, (1, 2))
        self.assertIsNone(record_b.retval)
        self.assertEqual(record_b.kwargs, {"extra_custom_arg": "test"})
        self.assertTrue(any("traceable_function_b" in frame for frame in record_b.stack))

    def test_basic_traceability(self):
        """ Nothing is recorded until traceability is enabled. """

        @traceability.traceable
        def traceable_function(a: int, b: int) -> int:
            return a + b

        trace = traceability.from_func(traceable_function)

        self.assertEqual(trace.get_call_record(), [])
        self.assertIsNone(trace.last_called)

        # Called while disabled: the trace stays untouched.
        traceable_function(1, 2)

        self.assertEqual(trace.get_call_record(), [])
        self.assertIsNone(trace.last_called)

        with traceability.enable_traceable():
            traceable_function(1, 2)

        self.assertIsNotNone(trace.last_called)