diff --git a/.env.sample b/.env.sample index 7c1bd17..49b7fa2 100644 --- a/.env.sample +++ b/.env.sample @@ -1,4 +1,11 @@ +DUNE_API_KEY= + +# Slack Credentials SLACK_TOKEN= SLACK_ALERT_CHANNEL= -DUNE_API_KEY= \ No newline at end of file +# Twitter Credentials +CONSUMER_KEY= +CONSUMER_SECRET= +ACCESS_TOKEN= +ACCESS_TOKEN_SECRET= diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5019bd6 --- /dev/null +++ b/Makefile @@ -0,0 +1,37 @@ +VENV = venv +PYTHON = $(VENV)/bin/python3 +PIP = $(VENV)/bin/pip +PROJECT_ROOT = src + + +$(VENV)/bin/activate: requirements/dev.txt + python3 -m venv $(VENV) + $(PIP) install --upgrade pip + $(PIP) install -r requirements/dev.txt + + +install: + make $(VENV)/bin/activate + +clean: + rm -rf __pycache__ + +fmt: + black ./ + +lint: + pylint ${PROJECT_ROOT}/ + +types: + mypy ${PROJECT_ROOT}/ --strict + +check: + make fmt + make lint + make types + +test-unit: + python -m pytest tests/unit + +test-e2e: + python -m pytest tests/e2e \ No newline at end of file diff --git a/requirements/prod.txt b/requirements/prod.txt index 22c8f08..05df39d 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -5,4 +5,5 @@ types-python-dateutil==2.8.19 types-PyYAML==6.0.11 python-dateutil==2.8.2 python-dotenv==0.21.0 -certifi==2022.12.7 \ No newline at end of file +certifi==2022.12.7 +tweepy==4.13.0 diff --git a/src/post/__init__.py b/src/post/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/post/base.py b/src/post/base.py new file mode 100644 index 0000000..bc08c24 --- /dev/null +++ b/src/post/base.py @@ -0,0 +1,12 @@ +"""Abstraction for posting alerts""" +from abc import ABC, abstractmethod + + +class PostClient(ABC): + """ + Basic Post Client with message post functionality + """ + + @abstractmethod + def post(self, message: str) -> None: + """Posts `message` to `self.channel` excluding link previews.""" diff --git a/src/post/twitter.py b/src/post/twitter.py new file mode 100644 index 0000000..974c35b --- /dev/null +++ b/src/post/twitter.py @@ -0,0 +1,24 @@ +""" +Twitter Alert Client +""" +import tweepy # type:ignore + +from src.post.base import PostClient + + +class TwitterClient(PostClient): + """Forwards alerts to Twitter""" + + def __init__(self, credentials: dict[str, str]) -> None: + auth = tweepy.OAuthHandler( + consumer_key=credentials["consumer_key"], + consumer_secret=credentials["consumer_secret"], + ) + auth.set_access_token( + key=credentials["access_token"], + secret=credentials["access_token_secret"], + ) + self.api = tweepy.API(auth) + + def post(self, message: str) -> None: + self.api.update_status(status=message) diff --git a/src/query_monitor/factory.py b/src/query_monitor/factory.py index 23bab8a..1ddd8c6 100644 --- a/src/query_monitor/factory.py +++ b/src/query_monitor/factory.py @@ -2,13 +2,14 @@ Factory method to load QueryMonitor object from yaml configuration files """ from __future__ import annotations -import os -from dataclasses import dataclass import logging.config +from dataclasses import dataclass +from enum import Enum + import yaml -from dune_client.types import QueryParameter from dune_client.query import Query +from dune_client.types import QueryParameter from src.models import TimeWindow, LeftBound from src.query_monitor.base import QueryBase @@ -17,11 +18,22 @@ from src.query_monitor.result_threshold import ResultThresholdQuery from src.query_monitor.windowed import WindowedQueryMonitor - log = logging.getLogger(__name__) logging.config.fileConfig(fname="logging.conf", disable_existing_loggers=False) +class AlertType(Enum): + """Supported Alert Frameworks.""" + + SLACK = "slack" + TWITTER = "twitter" + + @classmethod + def from_str(cls, val: str) -> AlertType: + """From string constructor""" + return cls(val.lower()) + + @dataclass class Config: """ @@ -31,6 +43,7 @@ class Config: query: QueryBase ping_frequency: int alert_channel: str + alert_type: AlertType def load_config(config_yaml: str) -> Config: @@ -66,10 +79,11 @@ def load_config(config_yaml: str) -> Config: config_obj = Config( query=base_query, - # Use specified channel, or default to "global config" - alert_channel=cfg.get("alert_channel", os.environ["SLACK_ALERT_CHANNEL"]), + alert_channel=cfg.get("alert_channel"), # This is 4x the DuneClient default of 5 seconds ping_frequency=cfg.get("ping_frequency", 20), + # Slack is the default alert type. + alert_type=AlertType.from_str(cfg.get("alert_type", "slack")), ) log.debug(f"config parsed as {config_obj}") return config_obj diff --git a/src/runner.py b/src/runner.py index 181d2f9..2e0d778 100644 --- a/src/runner.py +++ b/src/runner.py @@ -10,8 +10,8 @@ from dune_client.client import DuneClient from src.alert import AlertLevel +from src.post.base import PostClient from src.query_monitor.base import QueryBase -from src.slack_client import BasicSlackClient log = logging.getLogger(__name__) logging.config.fileConfig(fname="logging.conf", disable_existing_loggers=False) @@ -26,12 +26,12 @@ def __init__( self, query: QueryBase, dune: DuneClient, - slack_client: BasicSlackClient, + alerter: PostClient, ping_frequency: int, ): self.query = query self.dune = dune - self.slack_client = slack_client + self.alerter = alerter self.ping_frequency = ping_frequency def run_loop(self) -> None: @@ -44,6 +44,6 @@ def run_loop(self) -> None: alert = query.get_alert(results) if alert.level == AlertLevel.SLACK: log.warning(alert.message) - self.slack_client.post(alert.message) + self.alerter.post(alert.message) elif alert.level == AlertLevel.LOG: log.info(alert.message) diff --git a/src/slack_client.py b/src/slack_client.py index 45e5a37..a2c5d76 100644 --- a/src/slack_client.py +++ b/src/slack_client.py @@ -10,11 +10,13 @@ from slack.errors import SlackApiError from slack.web.client import WebClient +from src.post.base import PostClient + log = logging.getLogger(__name__) logging.config.fileConfig(fname="logging.conf", disable_existing_loggers=False) -class BasicSlackClient: +class BasicSlackClient(PostClient): """ Basic Slack Client with message post functionality constructed from an API token and channel diff --git a/src/slackbot.py b/src/slackbot.py index 0dc7105..901da27 100644 --- a/src/slackbot.py +++ b/src/slackbot.py @@ -8,8 +8,10 @@ from dune_client.client import DuneClient +from src.post.base import PostClient +from src.post.twitter import TwitterClient from src.query_monitor.base import QueryBase -from src.query_monitor.factory import load_config +from src.query_monitor.factory import load_config, AlertType from src.runner import QueryRunner from src.slack_client import BasicSlackClient @@ -17,14 +19,14 @@ def run_slackbot( query: QueryBase, dune: DuneClient, - slack_client: BasicSlackClient, + alert_client: PostClient, ping_frequency: int, ) -> None: """ This is the main method of the program. Instantiate a query runner, and execute its run_loop """ - query_runner = QueryRunner(query, dune, slack_client, ping_frequency) + query_runner = QueryRunner(query, dune, alert_client, ping_frequency) query_runner.run_loop() @@ -39,11 +41,29 @@ def run_slackbot( args = parser.parse_args() dotenv.load_dotenv() config = load_config(args.query_config) + + alerter: PostClient + if config.alert_type == AlertType.SLACK: + alerter = BasicSlackClient( + token=os.environ["SLACK_TOKEN"], + # Use specified channel, or default to "global config" + channel=config.alert_channel or os.environ["SLACK_ALERT_CHANNEL"], + ) + elif config.alert_type == AlertType.TWITTER: + alerter = TwitterClient( + credentials={ + "consumer_key": os.environ["CONSUMER_KEY"], + "consumer_secret": os.environ["CONSUMER_SECRET"], + "access_token": os.environ["ACCESS_TOKEN"], + "access_token_secret": os.environ["ACCESS_TOKEN_SECRET"], + } + ) + else: + raise ValueError(f"Invalid or unsupported AlertType {config.alert_type}") + run_slackbot( query=config.query, dune=DuneClient(os.environ["DUNE_API_KEY"]), - slack_client=BasicSlackClient( - token=os.environ["SLACK_TOKEN"], channel=config.alert_channel - ), + alert_client=alerter, ping_frequency=config.ping_frequency, ) diff --git a/tests/e2e/test_twitter_post.py b/tests/e2e/test_twitter_post.py new file mode 100644 index 0000000..f029627 --- /dev/null +++ b/tests/e2e/test_twitter_post.py @@ -0,0 +1,26 @@ +import os +import unittest + +import dotenv +import pytest + +from src.post.twitter import TwitterClient + + +class TestTwitterPost(unittest.TestCase): + @pytest.mark.skip(reason="Don't want to make a post all the time.") + def test_twitter_post(self): + dotenv.load_dotenv() + client = TwitterClient( + credentials={ + "consumer_key": os.environ["CONSUMER_KEY"], + "consumer_secret": os.environ["CONSUMER_SECRET"], + "access_token": os.environ["ACCESS_TOKEN"], + "access_token_secret": os.environ["ACCESS_TOKEN_SECRET"], + } + ) + client.post("Hi Mom!") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_implementations.py b/tests/unit/test_implementations.py index 4d995da..fff1472 100644 --- a/tests/unit/test_implementations.py +++ b/tests/unit/test_implementations.py @@ -99,10 +99,6 @@ def test_load_from_config(self): self.assertTrue(isinstance(left_bounded_monitor, LeftBoundedQueryMonitor)) del os.environ["SLACK_ALERT_CHANNEL"] - def test_load_config_error(self): - with self.assertRaises(KeyError): - load_config(filepath("no-params.yaml")) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_load_config.py b/tests/unit/test_load_config.py index 655d641..0c84bb9 100644 --- a/tests/unit/test_load_config.py +++ b/tests/unit/test_load_config.py @@ -6,14 +6,10 @@ class TestConfigLoading(unittest.TestCase): - def setUp(self) -> None: - self.fallback_alert_channel = "Default" - os.environ["SLACK_ALERT_CHANNEL"] = self.fallback_alert_channel - def test_default_config(self): config = load_config(filepath("counter.yaml")) - self.assertEqual(config.alert_channel, self.fallback_alert_channel) + self.assertEqual(config.alert_channel, None) def test_specified_channel(self): config = load_config(filepath("alert-channel.yaml"))