Skip to content

Commit

Permalink
Rabbitmq fixes/enhancements for scheduler (#1976)
Browse files Browse the repository at this point in the history
Co-authored-by: ammar92 <[email protected]>
Co-authored-by: Jan Klopper <[email protected]>
  • Loading branch information
3 people authored Nov 17, 2023
1 parent 56c8004 commit e622b3f
Show file tree
Hide file tree
Showing 16 changed files with 2,253 additions and 2,045 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ repos:
rev: v1.2.0
hooks:
- id: mypy
additional_dependencies: ['types-PyYAML', 'types-requests', 'types-cachetools', 'pydantic', 'pynacl']
additional_dependencies: ['types-PyYAML', 'types-requests', 'types-cachetools', 'types-retry', 'pydantic', 'pynacl']
exclude: |
(?x)(
^boefjes/ |
Expand Down
2 changes: 1 addition & 1 deletion mula/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ENV PATH=/home/scheduler/.local/bin:${PATH}
# dependencies
ARG ENVIRONMENT

COPY requirements.txt requirements-dev.txt .
COPY requirements.txt requirements-dev.txt ./
RUN --mount=type=cache,target=/root/.cache pip install --upgrade pip \
&& pip install -r requirements.txt \
&& if [ "$ENVIRONMENT" = "dev" ]; then pip install -r requirements-dev.txt; fi
Expand Down
1,175 changes: 620 additions & 555 deletions mula/poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion mula/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ pydantic = "^2.1.1"
pydantic-settings= "^2.0.2"
python-dotenv = "^1.0.0"
requests = "^2.31.0"
retry = "^0.9.2"
sqlalchemy = "^2.0.19"
typing-extensions = "^4.5.0"
uvicorn = "^0.17.1"

# OpenTelemetry
opentelemetry-sdk = "^1.19.0"
Expand All @@ -40,6 +40,7 @@ psutil = "^5.9.5"
pytest = "^7.2.2"
pytest-cov = "^4.0.0"
httpx = "^0.23.3"
uvicorn = "^0.17.1"

[build-system]
requires = ["setuptools>=59", "wheel"]
Expand Down
1,092 changes: 564 additions & 528 deletions mula/requirements-dev.txt

Large diffs are not rendered by default.

820 changes: 423 additions & 397 deletions mula/requirements.txt

Large diffs are not rendered by default.

68 changes: 55 additions & 13 deletions mula/scheduler/connectors/listeners/listeners.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import functools
import json
import logging
import urllib.parse
from typing import Dict, Optional
from concurrent import futures
from typing import Callable, Dict, Optional

import pika
from retry import retry

from ..connector import Connector # noqa: TID252

Expand Down Expand Up @@ -37,6 +40,14 @@ class RabbitMQ(Listener):
channel and procedure that needs to be dispatched when receiving messages
from a RabbitMQ queue.
To combat potential dropping of AMQP/stream connections due to AMQP heartbeat
timeouts, due to long running tasks, we will delegate processing of the
incoming message to another thread, while the connection adapter’s thread
continues to service its I/O loop’s message pump, permitting AMQP heartbeats
and other I/O to be serviced in a timely fashion.
Source: https://pika.readthedocs.io/en/stable/modules/adapters/index.html#requesting-message-acknowledgements-from-another-thread
Attributes:
dsn:
A string defining the data source name of the RabbitMQ host to
Expand All @@ -45,9 +56,11 @@ class RabbitMQ(Listener):
A pika.BlockingConnection instance.
channel:
A pika.BlockingConnection.channel instance.
executor:
A concurrent.futures.ThreadPoolExecutor instance.
"""

def __init__(self, dsn: str):
def __init__(self, dsn: str, queue: str, func: Callable, durable: bool = True, prefetch_count: int = 1) -> None:
"""Initialize the RabbitMQ Listener
Args:
Expand All @@ -56,17 +69,35 @@ def __init__(self, dsn: str):
connect to.
"""
super().__init__()
self.dsn = dsn

self.dsn: str = dsn
self.queue: str = queue
self.durable: bool = durable
self.prefetch_count: int = prefetch_count
self.func: Callable = func

self.executor: futures.ThreadPoolExecutor = futures.ThreadPoolExecutor(max_workers=10)
self.connection = pika.BlockingConnection(pika.URLParameters(self.dsn))
self.channel = self.connection.channel()

def dispatch(self, body: bytes) -> None:
"""Dispatch a message without a return value"""
raise NotImplementedError
def listen(self) -> None:
self.basic_consume(self.queue, self.durable, self.prefetch_count)

def dispatch(self, channel, delivery_tag, body: bytes) -> None:
# Call the function
self.func(body)

# Acknowledge the message
self.connection.add_callback_threadsafe(functools.partial(self.ack_message, channel, delivery_tag))

@retry(
(pika.exceptions.AMQPConnectionError, pika.exceptions.ConnectionClosedByBroker), delay=5, jitter=(1, 3), tries=5
)
def basic_consume(self, queue: str, durable: bool, prefetch_count: int) -> None:
try:
if self.connection.is_closed:
self.connection = pika.BlockingConnection(pika.URLParameters(self.dsn))
self.channel = self.connection.channel()
self.channel.queue_declare(queue=queue, durable=durable)
except pika.exceptions.ChannelClosedByBroker as exc:
if "inequivalent arg 'durable'" in exc.reply_text:
Expand All @@ -79,9 +110,15 @@ def basic_consume(self, queue: str, durable: bool, prefetch_count: int) -> None:
self.channel.queue_declare(queue=queue, durable=durable)
else:
raise
self.channel.basic_qos(prefetch_count=prefetch_count)
self.channel.basic_consume(queue, on_message_callback=self.callback)
self.channel.start_consuming()

try:
self.channel.basic_qos(prefetch_count=prefetch_count)
self.channel.basic_consume(queue, on_message_callback=self.callback)
self.channel.start_consuming()
except pika.exceptions.AMQPChannelError as exc:
# Do not recover on channel errors
self.logger.error("AMQPChannelError: %s", exc)
raise exc

def get(self, queue: str) -> Optional[Dict[str, object]]:
method, properties, body = self.channel.basic_get(queue)
Expand All @@ -94,6 +131,13 @@ def get(self, queue: str) -> Optional[Dict[str, object]]:

return response

def ack_message(self, channel, delivery_tag):
if channel.is_open:
channel.basic_ack(delivery_tag)
else:
# Channel is already closed, so we can't ack this message
self.logger.debug("Channel already closed, cannot ack message!")

def callback(
self,
channel: pika.channel.Channel,
Expand All @@ -105,10 +149,7 @@ def callback(
queue.
"""
self.logger.debug("Received message on queue %s, message: %r", method.routing_key, body)

self.dispatch(body)

channel.basic_ack(method.delivery_tag)
self.executor.submit(self.dispatch, channel, method.delivery_tag, body)

def is_healthy(self) -> bool:
"""Check if the RabbitMQ connection is healthy"""
Expand All @@ -126,6 +167,7 @@ def is_healthy(self) -> bool:
def stop(self) -> None:
self.logger.debug("Stopping RabbitMQ connection")

self.executor.shutdown()
self.connection.add_callback_threadsafe(self._close_callback)

self.logger.debug("RabbitMQ connection closed")
Expand Down
38 changes: 1 addition & 37 deletions mula/scheduler/connectors/listeners/raw_data.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,9 @@
from typing import Callable

from scheduler.models import RawDataReceivedEvent

from .listeners import RabbitMQ


class RawData(RabbitMQ):
"""The RawData listener class that listens to the raw data queue and calls
the function passed to it. This is used within the NormalizerScheduler.
Attributes:
queue: A string representing the messaging queue name.
func: A python callable.
prefetch_count: An integer representing the prefetch count.
"""

def __init__(self, dsn: str, queue: str, func: Callable, prefetch_count: int):
"""Initialize the RawData listener.
Args:
dsn: A string representing the DSN.
queue: A string representing the messaging queue name.
func: A python callable.
prefetch_count: An integer representing the prefetch count.
"""
super().__init__(dsn)
self.queue = queue
self.func = func
self.prefetch_count = prefetch_count

def listen(self) -> None:
"""Listen to the messaging queue."""
self.basic_consume(self.queue, True, self.prefetch_count)

def dispatch(self, body: bytes) -> None:
"""Dispatch the message to the function.
body: A bytes object representing the body of the message.
"""
# Convert body into a RawDataReceivedEvent
model = RawDataReceivedEvent.parse_raw(body)

# Call the function
self.func(model)
pass
38 changes: 1 addition & 37 deletions mula/scheduler/connectors/listeners/scan_profile.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,10 @@
from typing import Callable

from scheduler.models import ScanProfileMutation as ScanProfileMutationModel

from .listeners import RabbitMQ


class ScanProfileMutation(RabbitMQ):
"""The ScanProfileMutation listener class that listens to the scan profile
mutation queue and calls the function passed to it. This is used within the
BoefjeScheduler.
Attributes:
queue: A string representing the messaging queue name.
func: A python callable.
prefetch_count: An integer representing the prefetch count.
"""

def __init__(self, dsn: str, queue: str, func: Callable, prefetch_count: int):
"""Initialize the RawData listener.
Args:
dsn: A string representing the DSN.
queue: A string representing the messaging queue name.
func: A python callable.
prefetch_count: An integer representing the prefetch count.
"""
super().__init__(dsn)
self.queue = queue
self.func = func
self.prefetch_count = prefetch_count

def listen(self) -> None:
"""Listen to the messaging queue."""
self.basic_consume(self.queue, True, self.prefetch_count)

def dispatch(self, body: bytes) -> None:
"""Dispatch the message to the function.
body: A bytes object representing the body of the message.
"""
# Convert body into a ScanProfileMutationModel
model = ScanProfileMutationModel.parse_raw(body)

# Call the function
self.func(model)
pass
7 changes: 5 additions & 2 deletions mula/scheduler/schedulers/boefje.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,16 @@ def run(self) -> None:
)

@tracer.start_as_current_span("boefje_push_tasks_for_scan_profile_mutations")
def push_tasks_for_scan_profile_mutations(self, mutation: ScanProfileMutation) -> None:
def push_tasks_for_scan_profile_mutations(self, body: bytes) -> None:
"""Create tasks for oois that have a scan level change.
Args:
mutation: The mutation that was received.
"""
self.logger.info(
# Convert body into a ScanProfileMutation
mutation = ScanProfileMutation.parse_raw(body)

self.logger.debug(
"Received scan level mutation %s for: %s [ooi_primary_key=%s, organisation_id=%s, scheduler_id=%s]",
mutation.operation,
mutation.primary_key,
Expand Down
16 changes: 14 additions & 2 deletions mula/scheduler/schedulers/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@

from scheduler import context, queues, rankers
from scheduler.connectors import listeners
from scheduler.models import Normalizer, NormalizerTask, Organisation, Plugin, PrioritizedItem, RawData, TaskStatus
from scheduler.models import (
Normalizer,
NormalizerTask,
Organisation,
Plugin,
PrioritizedItem,
RawData,
RawDataReceivedEvent,
TaskStatus,
)

from .scheduler import Scheduler

Expand Down Expand Up @@ -80,13 +89,16 @@ def run(self) -> None:
)

@tracer.start_as_current_span("normalizer_push_task_for_received_raw_data")
def push_tasks_for_received_raw_data(self, latest_raw_data: RawData) -> None:
def push_tasks_for_received_raw_data(self, body: bytes) -> None:
"""Create tasks for the received raw data.
Args:
latest_raw_data: A `RawData` object that was received from the
message queue.
"""
# Convert body into a RawDataReceivedEvent
latest_raw_data = RawDataReceivedEvent.parse_raw(body)

self.logger.debug(
"Received new raw data from message queue [raw_data_id=%s, organisation_id=%s, scheduler_id=%s]",
latest_raw_data.raw_data.id,
Expand Down
Loading

0 comments on commit e622b3f

Please sign in to comment.