Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility decorator to log time spent in function calls #317

Merged
merged 2 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/zino/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from zino.config.models import DEFAULT_INTERVAL_MINUTES, PollDevice
from zino.config.polldevs import InvalidConfiguration, read_polldevs
from zino.tasks import run_all_tasks
from zino.utils import log_time_spent

_log = logging.getLogger(__name__)
_scheduler = None
Expand All @@ -38,6 +39,7 @@ def get_scheduler() -> AsyncIOScheduler:
return _scheduler


@log_time_spent()
def load_polldevs(polldevs_conf: str) -> Tuple[Set, Set]:
"""Loads pollfile into process state.

Expand Down
3 changes: 3 additions & 0 deletions src/zino/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from zino.flaps import FlappingStates
from zino.planned_maintenance import PlannedMaintenances
from zino.statemodels import DeviceStates
from zino.utils import log_time_spent

_log = logging.getLogger(__name__)

Expand All @@ -34,13 +35,15 @@ class ZinoState(BaseModel):
planned_maintenances: PlannedMaintenances = Field(default_factory=PlannedMaintenances)
flapping: FlappingStates = Field(default_factory=FlappingStates)

@log_time_spent()
def dump_state_to_file(self, filename: str):
"""Dumps the full state to a file in JSON format"""
_log.debug("dumping state to %s", filename)
with open(filename, "w") as statefile:
statefile.write(self.model_dump_json(exclude_none=True, indent=2))

@classmethod
@log_time_spent()
def load_state_from_file(cls, filename: str) -> Optional["ZinoState"]:
"""Loads and returns a previously persisted ZinoState from a JSON file dump.

Expand Down
26 changes: 25 additions & 1 deletion src/zino/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import logging
from functools import wraps
from ipaddress import ip_address
from typing import Optional
from time import time
from typing import Optional, Union

import aiodns
from pyasn1.type.univ import OctetString
Expand Down Expand Up @@ -39,3 +42,24 @@ async def reverse_dns(ip: str) -> Optional[str]:
return response.name
except aiodns.error.DNSError:
return None


def log_time_spent(logger: Union[logging.Logger, str] = __name__, level: int = logging.DEBUG):
"""Decorator that logs the time taken for a function to execute. Not suitable for use with async functions"""
if isinstance(logger, str):
logger = logging.getLogger(logger)

def actual_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
start = time()
try:
result = func(*args, **kwargs)
finally:
end = time()
logger.log(level, "%s took %s seconds", func.__name__, end - start)
return result

return wrapper

return actual_decorator
21 changes: 20 additions & 1 deletion tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from ipaddress import IPv4Address, IPv6Address
from unittest.mock import AsyncMock, MagicMock

import aiodns
import pytest

from zino.utils import parse_ip, reverse_dns
from zino.utils import log_time_spent, parse_ip, reverse_dns


class TestParseIP:
Expand Down Expand Up @@ -78,6 +79,24 @@ async def test_should_return_none_for_invalid_ip(self, mock_dnsresolver):
assert result is None


class TestLogTimeSpent:
def test_when_logger_is_specified_it_should_log_time_spent_using_that_logger(self, caplog):
@log_time_spent(logger="test_logger", level=logging.DEBUG)
def test_function():
pass

with caplog.at_level(logging.DEBUG):
test_function()

assert any(
record.name == "test_logger"
and "took" in record.msg
and "seconds" in record.msg
and "test_function" in record.args
for record in caplog.records
)


@pytest.fixture
def mock_dnsresolver(monkeypatch) -> AsyncMock:
mock_dnsresolver = AsyncMock()
Expand Down
Loading