diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..d216af26 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,81 @@ +version: 2.1 + +workflows: + version: 2 + workflow: + jobs: + - lint-flake8 + - test-python37 + - test-python38 + +commands: + tox: + description: "Execute tox env" + parameters: + env: + type: "string" + default: "py37" + steps: + - restore_cache: + keys: + - venv-{{ .Environment.CIRCLE_STAGE }}-{{ .Environment.cacheVer }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} + - run: + name: "setup up test environment" + command: | + mkdir ./test-reports + mkdir ./test-reports/coverage + test ! -d venv && pip install virtualenv && virtualenv venv + source venv/bin/activate + pip install -U setuptools tox codecov + - save_cache: + key: venv-{{ .Environment.CIRCLE_STAGE }}-{{ .Environment.cacheVer }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} + paths: + - venv + - run: + name: "tox: << parameters.env >>" + command: | + source venv/bin/activate + tox -e << parameters.env >> + - store_artifacts: + path: test-reports + - store_test_results: + path: test-reports + +executors: + python: + parameters: + version: + type: "string" + default: "3.8" + docker: + - image: circleci/python:<< parameters.version >> + environment: + PYTHON_VERSION: "<< parameters.version >>" + +jobs: + test-python37: + executor: + name: "python" + version: "3.7" + steps: + - checkout + - tox: + env: py37 + + test-python38: + executor: + name: "python" + version: "3.8" + steps: + - checkout + - tox: + env: py38 + + lint-flake8: + executor: + name: "python" + version: "3.8" + steps: + - checkout + - tox: + env: flake8 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..fccf3a12 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,48 @@ +default_language_version: + python: python3.7 + +fail_fast: true + +repos: +- repo: https://github.com/ambv/black + rev: 19.3b0 + 
hooks: + - id: black + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.2.3 + hooks: + - id: check-added-large-files + - id: check-ast + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + args: + - --unsafe + - id: debug-statements + - id: end-of-file-fixer + - id: fix-encoding-pragma + - id: flake8 + args: + - --exclude=docs/*,tests/* + - --max-line-length=131 + - --ignore=E203 + - id: forbid-new-submodules + - id: mixed-line-ending + args: + - --fix=lf + - id: no-commit-to-branch + args: + - -b master + - id: trailing-whitespace + +- repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.1.6 + hooks: + - id: remove-tabs + +- repo: https://github.com/Lucas-C/pre-commit-hooks-safety + rev: v1.1.0 + hooks: + - id: python-safety-dependencies-check diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 4359f334..00000000 --- a/.travis.yml +++ /dev/null @@ -1,23 +0,0 @@ -sudo: false -language: python -matrix: - include: - - python: 3.6 - env: TOX_ENV=py36 - - python: 3.7 - dist: xenial - sudo: true - env: TOX_ENV=py37 - - python: 3.7 - dist: xenial - sudo: true - env: TOX_ENV=flake8 -cache: pip -install: - - "travis_retry pip install setuptools --upgrade" - - "travis_retry pip install tox" - - "travis_retry pip install codecov" -script: - - tox -e $TOX_ENV -after_success: - - test $TOX_ENV == "py37" && codecov \ No newline at end of file diff --git a/machine/__about__.py b/machine/__about__.py index 79cb8b4f..bbaeab30 100644 --- a/machine/__about__.py +++ b/machine/__about__.py @@ -1,14 +1,21 @@ +# -*- coding: utf-8 -*- __all__ = [ - '__title__', '__description__', '__uri__', '__version__', '__author__', - '__email__', '__license__', '__copyright__' + "__title__", + "__description__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] -__title__ = "slack-machine" +__title__ = "aio-slack-machine" __description__ = "A sexy, simple, yet 
powerful and extendable Slack bot" __uri__ = "https://github.com/DandyDev/slack-machine" -__version_info__ = (0, 18, 0) -__version__ = '.'.join(map(str, __version_info__)) +__version_info__ = (0, 19, 0) +__version__ = ".".join(map(str, __version_info__)) __author__ = "Daan Debie" __email__ = "debie.daan@gmail.com" diff --git a/machine/bin/run.py b/machine/bin/run.py index 52877e80..43e0bf2d 100644 --- a/machine/bin/run.py +++ b/machine/bin/run.py @@ -1,17 +1,56 @@ +# -*- coding: utf-8 -*- + +import asyncio +import signal import sys import os +from functools import partial + +from loguru import logger from machine import Machine -from machine.utils.text import announce def main(): # When running this function as console entry point, the current working dir is not in the # Python path, so we have to add it sys.path.insert(0, os.getcwd()) - bot = Machine() - try: - bot.run() - except KeyboardInterrupt: - announce("Thanks for playing!") - sys.exit(0) + + loop = asyncio.new_event_loop() + + # Handle INT and TERM by gracefully halting the event loop + for s in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(s, partial(prepare_to_stop, loop)) + + bot = Machine(loop=loop) + loop.run_until_complete(bot.run()) + + # Once `prepare_to_stop` returns, `loop` is guaranteed + # to be stopped. + prepare_to_stop(loop) + + loop.close() + sys.exit(0) + + +def prepare_to_stop(loop: asyncio.AbstractEventLoop): + # Calling `cancel` on each task in the active loop causes + # a `CancelledError` to be thrown into each wrapped coroutine during + # the next iteration, which _should_ halt execution of the coroutine + # if the coroutine does not explicitly suppress the `CancelledError`. + for task in asyncio.Task.all_tasks(loop=loop): + task.cancel() + + # Using `ensure_future` on the `stop` coroutine here ensures that + # the `stop` coroutine is the last thing to run after each cancelled + # task has its `CancelledError` emitted. 
+ asyncio.ensure_future(stop(), loop=loop) + loop.run_forever() + + +# @asyncio.coroutine +async def stop(): + loop = asyncio.get_event_loop() + logger.info("Thanks for playing!") + + loop.stop() diff --git a/machine/core.py b/machine/core.py index c56565e0..ddcc07f3 100644 --- a/machine/core.py +++ b/machine/core.py @@ -1,12 +1,13 @@ +# -*- coding: utf-8 -*- + +import asyncio import inspect -import logging import sys -import time -from threading import Thread +from typing import Mapping, Optional import dill -from clint.textui import puts, indent, colored -import bottle +from aiohttp.web import Application, AppRunner, TCPSite +from loguru import logger from machine.dispatch import EventDispatcher from machine.plugins.base import MachineBasePlugin @@ -14,81 +15,152 @@ from machine.singletons import Slack, Scheduler, Storage from machine.slack import MessagingClient from machine.storage import PluginStorage +from machine.utils import aio, collections, log_propagate from machine.utils.module_loading import import_string -from machine.utils.text import show_valid, show_invalid, warn, error, announce - -logger = logging.getLogger(__name__) class Machine: - def __init__(self, settings=None): - announce("Initializing Slack Machine:") - - with indent(4): - puts("Loading settings...") - if settings: - self._settings = settings - found_local_settings = True - else: - self._settings, found_local_settings = import_settings() - fmt = '[%(asctime)s][%(levelname)s] %(name)s %(filename)s:%(funcName)s:%(lineno)d |' \ - ' %(message)s' - date_fmt = '%Y-%m-%d %H:%M:%S' - log_level = self._settings.get('LOGLEVEL', logging.ERROR) - logging.basicConfig( - level=log_level, - format=fmt, - datefmt=date_fmt, + _client: Slack + _dispatcher: EventDispatcher + _loop: asyncio.AbstractEventLoop + _settings: collections.CaseInsensitiveDict + _storage: Storage + + _help: Mapping[str, dict] = {"human": {}, "robot": {}} + _http_app: Optional[Application] = None + _http_runner: Optional[AppRunner] = 
None + _plugin_actions: Mapping[str, dict] = { + "process": {}, + "listen_to": {}, + "respond_to": {}, + } + + def __init__(self, loop=None, settings=None): + log_propagate.install() + logger.info("Initializing Slack Machine") + + self._loop = loop or asyncio.get_event_loop() + + logger.debug("Loading settings...") + if settings: + self._settings = settings + found_local_settings = True + else: + self._settings, found_local_settings = import_settings() + + if not found_local_settings: + logger.warning( + "No local_settings found! Are you sure this is what you want?" ) - if not found_local_settings: - warn("No local_settings found! Are you sure this is what you want?") - if 'SLACK_API_TOKEN' not in self._settings: - error("No SLACK_API_TOKEN found in settings! I need that to work...") - sys.exit(1) - self._client = Slack() - puts("Initializing storage using backend: {}".format(self._settings['STORAGE_BACKEND'])) - self._storage = Storage.get_instance() - logger.debug("Storage initialized!") - - self._plugin_actions = { - 'process': {}, - 'listen_to': {}, - 'respond_to': {}, - 'catch_all': {} - } - self._help = { - 'human': {}, - 'robot': {} - } - puts("Loading plugins...") - self.load_plugins() - logger.debug("The following plugin actions were registered: %s", self._plugin_actions) - self._dispatcher = EventDispatcher(self._plugin_actions, self._settings) + + if "SLACK_API_TOKEN" not in self._settings: + logger.error("No SLACK_API_TOKEN found in settings! 
I need that to work...") + sys.exit(1) + + self._client = Slack(loop=self._loop) + + logger.info( + "Initializing storage using backend: {}".format( + self._settings["STORAGE_BACKEND"] + ) + ) + self._storage = Storage() + logger.debug("Storage initialized!") + + self._loop.run_until_complete(self._storage.connect()) + + if not self._settings.get("DISABLE_HTTP", False): + self._http_app = Application() + else: + self._http_app = None + + logger.debug("Loading plugins...") + self.load_plugins() + logger.debug( + f"The following plugin actions were registered: {self._plugin_actions}" + ) + + self._dispatcher = EventDispatcher(self._plugin_actions, self._settings) def load_plugins(self): - with indent(4): - logger.debug("PLUGINS: %s", self._settings['PLUGINS']) - for plugin in self._settings['PLUGINS']: - for class_name, cls in import_string(plugin): - if issubclass(cls, MachineBasePlugin) and cls is not MachineBasePlugin: - logger.debug("Found a Machine plugin: {}".format(plugin)) - storage = PluginStorage(class_name) - instance = cls(self._settings, MessagingClient(), - storage) - missing_settings = self._register_plugin(class_name, instance) - if missing_settings: - show_invalid(class_name) - with indent(4): - error_msg = "The following settings are missing: {}".format( - ", ".join(missing_settings) - ) - puts(colored.red(error_msg)) - puts(colored.red("This plugin will not be loaded!")) - del instance + for plugin in self._settings["PLUGINS"]: + for class_name, cls in import_string(plugin): + if issubclass(cls, MachineBasePlugin) and cls is not MachineBasePlugin: + logger.debug("Found a Machine plugin: {}".format(plugin)) + storage = PluginStorage(class_name) + instance = cls(self._settings, MessagingClient(), storage) + + missing_settings = self._register_plugin(class_name, instance) + if missing_settings: + error_msg = "The following settings are missing: {}".format( + ", ".join(missing_settings) + ) + logger.error( + f"{class_name}: {error_msg}. 
This plugin will not be loaded!" + ) + del instance + else: + if inspect.iscoroutinefunction(instance.init): + self._loop.run_until_complete(instance.init(self._http_app)) else: - instance.init() - show_valid(class_name) - self._storage.set('manual', dill.dumps(self._help)) + instance.init(self._http_app) + + logger.info(f"Loaded plugin: {class_name}") + + self._loop.run_until_complete( + self._storage.set("manual", dill.dumps(self._help)) + ) + + async def run(self): + logger.info("Starting Slack Machine") + self._dispatcher.start() + + keepaliver: Optional[asyncio.Task] = None + try: + await aio.join([self._start_scheduler(), self._start_http_server()]) + # Launch the keepaliver task, keeping a handle to it + # in the current context so it can be cancelled later. + keepaliver = await self._start_keepaliver() + # `rtm.start()` will be continuously waited on and will not + # return unless the connection is closed. + logger.info("Connecting to Slack...") + await self._client.rtm.start() + except (KeyboardInterrupt, SystemExit): + logger.info("Slack Machine shutting down...") + + # Halt the keepaliver task + if keepaliver and not keepaliver.cancelled(): + keepaliver.cancel() + + # Clean up/shut down the aiohttp AppRunner + if self._http_runner is not None: + await self._http_runner.cleanup() + + async def _start_scheduler(self): + logger.info("Starting scheduler...") + Scheduler(loop=self._loop).start() + + async def _start_http_server(self): + if self._http_app is not None: + http_host = self._settings.get("HTTP_SERVER_HOST", "127.0.0.1") + http_port = int(self._settings.get("HTTP_SERVER_PORT", 3000)) + logger.info(f"Starting web server on {http_host}:{http_port}...") + + self._http_runner = AppRunner(self._http_app) + await self._http_runner.setup() + + site = TCPSite(self._http_runner, http_host, http_port) + await site.start() + + logger.debug("Started web server!") + + async def _start_keepaliver(self): + interval = self._settings["KEEP_ALIVE"] + if interval: + 
logger.info(f"Starting keepaliver... [Interval: {interval}s]") + return asyncio.create_task(self._keepaliver(interval)) + + return None def _register_plugin(self, plugin_class, cls_instance): missing_settings = [] @@ -99,121 +171,94 @@ def _register_plugin(self, plugin_class, cls_instance): if missing_settings: return missing_settings - if hasattr(cls_instance, 'catch_all'): - self._plugin_actions['catch_all'][plugin_class] = { - 'class': cls_instance, - 'class_name': plugin_class, - 'function': getattr(cls_instance, 'catch_all') + if hasattr(cls_instance, "catch_all"): + self._plugin_actions["catch_all"][plugin_class] = { + "class": cls_instance, + "class_name": plugin_class, + "function": getattr(cls_instance, "catch_all"), } if cls_instance.__doc__: class_help = cls_instance.__doc__.splitlines()[0] else: class_help = plugin_class - self._help['human'][class_help] = self._help['human'].get(class_help, {}) - self._help['robot'][class_help] = self._help['robot'].get(class_help, []) + self._help["human"][class_help] = self._help["human"].get(class_help, {}) + self._help["robot"][class_help] = self._help["robot"].get(class_help, []) for name, fn in methods: - if hasattr(fn, 'metadata'): - self._register_plugin_actions(plugin_class, fn.metadata, cls_instance, name, fn, - class_help) + if hasattr(fn, "metadata"): + self._register_plugin_actions( + plugin_class, fn.metadata, cls_instance, name, fn, class_help + ) - def _check_missing_settings(self, fn_or_class): + def _check_missing_settings(self, item): missing_settings = [] - if hasattr(fn_or_class, 'metadata') and 'required_settings' in fn_or_class.metadata: - for setting in fn_or_class.metadata['required_settings']: + if hasattr(item, "metadata") and "required_settings" in item.metadata: + for setting in item.metadata["required_settings"]: if setting not in self._settings: missing_settings.append(setting.upper()) + return missing_settings - def _register_plugin_actions(self, plugin_class, metadata, cls_instance, 
fn_name, fn, - class_help): + def _register_plugin_actions( + self, plugin_class, metadata, cls_instance, fn_name, fn, class_help + ): fq_fn_name = "{}.{}".format(plugin_class, fn_name) if fn.__doc__: - self._help['human'][class_help][fq_fn_name] = self._parse_human_help(fn.__doc__) - for action, config in metadata['plugin_actions'].items(): - if action == 'process': - event_type = config['event_type'] - event_handlers = self._plugin_actions['process'].get(event_type, {}) + self._help["human"][class_help][fq_fn_name] = self._parse_human_help( + fn.__doc__ + ) + for action, config in metadata["plugin_actions"].items(): + if action == "process": + event_type = config["event_type"] + event_handlers = self._plugin_actions["process"].get(event_type, {}) event_handlers[fq_fn_name] = { - 'class': cls_instance, - 'class_name': plugin_class, - 'function': fn + "class": cls_instance, + "class_name": plugin_class, + "function": fn, } - self._plugin_actions['process'][event_type] = event_handlers - if action == 'respond_to' or action == 'listen_to': - for regex in config['regex']: + self._plugin_actions["process"][event_type] = event_handlers + elif action == "respond_to" or action == "listen_to": + for regex in config["regex"]: event_handler = { - 'class': cls_instance, - 'class_name': plugin_class, - 'function': fn, - 'regex': regex + "class": cls_instance, + "class_name": plugin_class, + "function": fn, + "regex": regex, } key = "{}-{}".format(fq_fn_name, regex.pattern) self._plugin_actions[action][key] = event_handler - self._help['robot'][class_help].append(self._parse_robot_help(regex, action)) - if action == 'schedule': - Scheduler.get_instance().add_job(fq_fn_name, trigger='cron', args=[cls_instance], - id=fq_fn_name, replace_existing=True, **config) - if action == 'route': - for route_config in config: - bottle.route(**route_config)(fn) + self._help["robot"][class_help].append( + self._parse_robot_help(regex, action) + ) + elif action == "schedule": + 
Scheduler.get_instance().add_job( + fq_fn_name, + trigger="cron", + args=[cls_instance], + id=fq_fn_name, + replace_existing=True, + **config, + ) @staticmethod def _parse_human_help(doc): - summary = doc.splitlines()[0].split(':') + summary = doc.splitlines()[0].split(":") if len(summary) > 1: command = summary[0].strip() cmd_help = summary[1].strip() else: command = "??" cmd_help = summary[0].strip() - return { - 'command': command, - 'help': cmd_help - } + return {"command": command, "help": cmd_help} @staticmethod def _parse_robot_help(regex, action): - if action == 'respond_to': + if action == "respond_to": return "@botname {}".format(regex.pattern) else: return regex.pattern - def _keepalive(self): + async def _keepaliver(self, interval): while True: - time.sleep(self._settings['KEEP_ALIVE']) - self._client.server.send_to_websocket({'type': 'ping'}) + await asyncio.sleep(interval) + await self._client.rtm.ping() logger.debug("Client Ping!") - - def run(self): - announce("\nStarting Slack Machine:") - with indent(4): - connected = self._client.rtm_connect() - if not connected: - logger.error("Could not connect to Slack! 
Aborting...") - sys.exit(1) - show_valid("Connected to Slack") - Scheduler.get_instance().start() - show_valid("Scheduler started") - if not self._settings['DISABLE_HTTP']: - self._bottle_thread = Thread( - target=bottle.run, - kwargs=dict( - host=self._settings['HTTP_SERVER_HOST'], - port=self._settings['HTTP_SERVER_PORT'], - server=self._settings['HTTP_SERVER_BACKEND'], - ) - ) - self._bottle_thread.daemon = True - self._bottle_thread.start() - show_valid("Web server started") - - if self._settings['KEEP_ALIVE']: - self._keep_alive_thread = Thread(target=self._keepalive) - self._keep_alive_thread.daemon = True - self._keep_alive_thread.start() - show_valid( - "Keepalive thread started [Interval: %ss]" % self._settings['KEEP_ALIVE'] - ) - - show_valid("Dispatcher started") - self._dispatcher.start() diff --git a/machine/dispatch.py b/machine/dispatch.py index eeb9489f..b37f7811 100644 --- a/machine/dispatch.py +++ b/machine/dispatch.py @@ -1,26 +1,25 @@ -import time +# -*- coding: utf-8 -*- + +import asyncio import re -import logging + +from loguru import logger from machine.singletons import Slack -from machine.utils.pool import ThreadPool -from machine.plugins.base import Message +from machine.message import Message from machine.slack import MessagingClient -logger = logging.getLogger(__name__) - class EventDispatcher: - def __init__(self, plugin_actions, settings=None): - self._client = Slack() + self._client = Slack.get_instance() self._plugin_actions = plugin_actions - self._pool = ThreadPool() - alias_regex = '' + alias_regex = "" if settings and "ALIASES" in settings: - logger.info("Setting aliases to {}".format(settings['ALIASES'])) - alias_regex = '|(?P{})'.format( - '|'.join([re.escape(s) for s in settings['ALIASES'].split(',')])) + logger.info("Setting aliases to {}".format(settings["ALIASES"])) + alias_regex = "|(?P{})".format( + "|".join([re.escape(s) for s in settings["ALIASES"].split(",")]) + ) self.RESPOND_MATCHER = re.compile( 
r"^(?:<@(?P\w+)>:?|(?P\w+):{}) ?(?P.*)$".format( alias_regex @@ -28,31 +27,52 @@ def __init__(self, plugin_actions, settings=None): re.DOTALL, ) + def _event_callback(self, event_type: str): + """ Returns a closured coroutine that dispatches the + given event type to the event handler function + """ + + async def dispatch(*, data: dict, **kwargs): + return await self.handle_event(event_type, data=data, **kwargs) + + return dispatch + def start(self): - while True: - for event in self._client.rtm_read(): - self._pool.add_task(self.handle_event, event) - time.sleep(.1) - - def handle_event(self, event): - # Gotta catch 'em all! - for action in self._plugin_actions['catch_all'].values(): - action['function'](event) + default_events = {"message", "pong"} + # `python-slackclient` no longer allows us to inject the firehose of events - + # we have to register a "type" of event we want to process to receive it. + registered_events = set(self._plugin_actions.get("process", []).keys()) + events = default_events | registered_events + + logger.debug(f"Registering for events: {events}") + + for event in events: + self._client.rtm.on(event=event, callback=self._event_callback(event)) + + async def handle_event(self, event_type: str, *, data: dict, **kwargs): + # The bot should never react to an event generated by itself + if "user" in data and data["user"] == self._get_bot_id(): + return + # Basic dispatch based on event type - if 'type' in event: - if event['type'] in self._plugin_actions['process']: - for action in self._plugin_actions['process'][event['type']].values(): - action['function'](event) + if event_type in self._plugin_actions["process"]: + handlers = [] + for action in self._plugin_actions["process"][event_type].values(): + handlers.append(action["function"](data)) + + await asyncio.gather(*handlers) + # Handle message listeners - if 'type' in event and event['type'] == 'message': - respond_to_msg = self._check_bot_mention(event) + if event_type == "message": + 
respond_to_msg = self._check_bot_mention(data) if respond_to_msg: - listeners = self._find_listeners('respond_to') - self._dispatch_listeners(listeners, respond_to_msg) + listeners = self._find_listeners("respond_to") + await self._dispatch_listeners(listeners, respond_to_msg) else: - listeners = self._find_listeners('listen_to') - self._dispatch_listeners(listeners, event) - if 'type' in event and event['type'] == 'pong': + listeners = self._find_listeners("listen_to") + await self._dispatch_listeners(listeners, data) + + elif event_type == "pong": logger.debug("Server Pong!") def _find_listeners(self, type): @@ -63,28 +83,29 @@ def _gen_message(event, plugin_class_name): return Message(MessagingClient(), event, plugin_class_name) def _get_bot_id(self): - return self._client.server.login_data['self']['id'] + return self._client.login_data["self"]["id"] def _get_bot_name(self): - return self._client.server.login_data['self']['name'] + return self._client.login_data["self"]["name"] def _check_bot_mention(self, event): - full_text = event.get('text', '') - channel = event['channel'] + full_text = event.get("text", "") + channel = event["channel"] bot_name = self._get_bot_name() bot_id = self._get_bot_id() - m = self.RESPOND_MATCHER.match(full_text) - if channel[0] == 'C' or channel[0] == 'G': - if not m: + at_response = self.RESPOND_MATCHER.match(full_text) + + if channel[0] == "C" or channel[0] == "G": + if not at_response: return None - matches = m.groupdict() + matches = at_response.groupdict() - atuser = matches.get('atuser') - username = matches.get('username') - text = matches.get('text') - alias = matches.get('alias') + atuser = matches.get("atuser") + username = matches.get("username") + text = matches.get("text") + alias = matches.get("alias") if alias: atuser = bot_id @@ -93,16 +114,21 @@ def _check_bot_mention(self, event): # a channel message at other user return None - event['text'] = text + event["text"] = text else: - if m: - event['text'] = 
m.groupdict().get('text', None) + if at_response: + event["text"] = at_response.groupdict().get("text", None) + return event - def _dispatch_listeners(self, listeners, event): + async def _dispatch_listeners(self, listeners, event): + handlers = [] for l in listeners: - matcher = l['regex'] - match = matcher.search(event.get('text', '')) + matcher = l["regex"] + match = matcher.search(event.get("text", "")) if match: - message = self._gen_message(event, l['class_name']) - l['function'](message, **match.groupdict()) + message = self._gen_message(event, l["class_name"]) + handlers.append(l["function"](message, **match.groupdict())) + + if handlers: + await asyncio.gather(*handlers) diff --git a/machine/message.py b/machine/message.py new file mode 100644 index 00000000..1b22b2d8 --- /dev/null +++ b/machine/message.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- + +from async_lru import alru_cache + + +class Message: + """A message that was received by the bot + + This class represents a message that was received by the bot and passed to one or more + plugins. It contains the message (text) itself, and metadata about the message, such as the + sender of the message, the channel the message was sent to. + + The ``Message`` class also contains convenience methods for replying to the message in the + right channel, replying to the sender, etc. 
+ """ + + def __init__(self, client, msg_event, plugin_class_name): + self._client = client + self._msg_event = msg_event + self._fq_plugin_name = plugin_class_name + + @property + def user_id(self) -> str: + return self._msg_event["user"] + + @property + def channel_id(self) -> str: + return self._msg_event["channel"] + + @property + def thread_ts(self): + """The timestamp of the original message + + :return: the timestamp of the original message + """ + try: + thread_ts = self._msg_event["thread_ts"] + except KeyError: + thread_ts = self._msg_event["ts"] + + return thread_ts + + @property + def is_dm(self) -> bool: + chan = self.channel_id + return not (chan.startswith("C") or chan.startswith("G")) + + @property + def at_sender(self): + """The sender of the message formatted as mention + + :return: a string representation of the sender of the message, formatted as `mention`_, + to be used in messages + + .. _mention: https://api.slack.com/docs/message-formatting#linking_to_channels_and_users + """ + return self._client.fmt_mention({"id": self.user_id}) + + @property + def text(self): + """The body of the actual message + + :return: the body (text) of the actual message + """ + return self._msg_event["text"] + + @alru_cache(maxsize=8) + async def get_sender(self) -> dict: + """The sender of the message + + :return: dictionary describing the user the message was sent by + """ + return await self._client.find_user_by_id(self.user_id) + + @alru_cache(maxsize=8) + async def get_channel(self) -> dict: + """The channel the message was sent to + + :return: dictionary describing the channel the message was sent to + """ + return await self._client.find_channel_by_id(self.channel_id) + + async def say(self, text, **kwargs): + """Send a new message using the WebAPI to the channel the original message was received in + + Send a new message to the channel the original message was received in, using the WebAPI. + Allows for rich formatting using `attachments`_. 
Can also reply to a thread and send an + ephemeral message only visible to the sender of the original message. + Ephemeral messages and threaded messages are mutually exclusive, and ``ephemeral`` + takes precedence over ``thread_ts`` + + .. _attachments: https://api.slack.com/docs/message-attachments + + :param text: message text + :param attachments: optional attachments (see `attachments`_) + :param thread_ts: optional timestamp of thread, to send a message in that thread + :param ephemeral: ``True/False`` whether to send the message as an ephemeral message, only + visible to the sender of the original message + :return: Dictionary deserialized from `chat.postMessage`_ request, or `chat.postEphemeral`_ + if `ephemeral` is True. + + .. _chat.postMessage: https://api.slack.com/methods/chat.postMessage + .. _chat.postEphemeral: https://api.slack.com/methods/chat.postEphemeral + """ + + return await self._client.send( + self.channel_id, text, **self._handle_context_args(**kwargs) + ) + + def say_scheduled(self, when, text, **kwargs): + """Schedule a message and send it using the WebAPI + + This is the scheduled version of :py:meth:`~machine.plugins.base.Message.say_webapi`. + It behaves the same, but will send the DM at the scheduled time. + + .. _attachments: https://api.slack.com/docs/message-attachments + + :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance + :param text: message text + :param attachments: optional attachments (see `attachments`_) + :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only + visible to the sender of the original message + :return: None + """ + + self._client.send_scheduled( + when, self.channel_id, text, **self._handle_context_args(**kwargs) + ) + + async def reply(self, text, **kwargs): + """Reply to the sender of the original message using the WebAPI + + Reply to the sender of the original message with a new message, mentioning that user. 
Uses + the WebAPI, so rich formatting using `attachments`_ is possible. Can also reply to a thread + and send an ephemeral message only visible to the sender of the original message. + Ephemeral messages and threaded messages are mutually exclusive, and ``ephemeral`` + takes precedence over ``thread_ts`` + + .. _attachments: https://api.slack.com/docs/message-attachments + + :param text: message text + :param attachments: optional attachments (see `attachments`_) + :param in_thread: ``True/False`` wether to reply to the original message in-thread + :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only + visible to the sender of the original message + :return: Dictionary deserialized from `chat.postMessage`_ request, or `chat.postEphemeral`_ + if `ephemeral` is True. + + .. _chat.postMessage: https://api.slack.com/methods/chat.postMessage + .. _chat.postEphemeral: https://api.slack.com/methods/chat.postEphemeral + """ + in_thread = kwargs.get("in_thread", False) + ephemeral = kwargs.get("ephemeral", False) + if in_thread and not ephemeral: + text = self._create_reply(text) + + return await self.say(text, **self._handle_context_args(**kwargs)) + + def reply_scheduled(self, when, text, **kwargs): + """Schedule a reply and send it using the WebAPI + + This is the scheduled version of :py:meth:`~machine.plugins.base.Message.reply_webapi`. + It behaves the same, but will send the reply at the scheduled time. + + .. 
_attachments: https://api.slack.com/docs/message-attachments + + :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance + :param attachments: optional attachments (see `attachments`_) + :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only + visible to the sender of the original message + :return: None + """ + in_thread = kwargs.get("in_thread", False) + ephemeral = kwargs.get("ephemeral", False) + if in_thread and not ephemeral: + text = self._create_reply(text) + + self.say_scheduled(when, text, **self._handle_context_args(**kwargs)) + + async def reply_dm(self, text, **kwargs): + """Reply to the sender of the original message with a DM using the WebAPI + + Reply in a Direct Message to the sender of the original message by opening a DM channel and + sending a message to it via the WebAPI. Allows for rich formatting using + `attachments`_. + + .. _attachments: https://api.slack.com/docs/message-attachments + + :param text: message text + :param attachments: optional attachments (see `attachments`_) + :return: Dictionary deserialized from `chat.postMessage`_ request. + + .. _chat.postMessage: https://api.slack.com/methods/chat.postMessage + """ + return await self._client.send_dm( + self.user_id, text, **self._handle_context_args(**kwargs) + ) + + def reply_dm_scheduled(self, when, text, **kwargs): + """Schedule a DM reply and send it using the WebAPI + + This is the scheduled version of :py:meth:`~machine.plugins.base.Message.reply_dm_webapi`. + It behaves the same, but will send the DM at the scheduled time. + + .. 
_attachments: https://api.slack.com/docs/message-attachments + + :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance + :param text: message text + :param attachments: optional attachments (see `attachments`_) + :return: None + """ + self._client.send_dm_scheduled( + when, self.user_id, text, **self._handle_context_args(**kwargs) + ) + + async def react(self, emoji): + """React to the original message + + Add a reaction to the original message + + :param emoji: what emoji to react with (should be a string, like 'angel', 'thumbsup', etc.) + :return: Dictionary deserialized from `reactions.add`_ request. + + .. _reactions.add: https://api.slack.com/methods/reactions.add + """ + return await self._client.react(self.channel_id, self.thread_ts, emoji) + + def _create_reply(self, text): + if not self.is_dm: + return "{}: {}".format(self.at_sender, text) + else: + return text + + def _handle_context_args(self, **kwargs): + """ Given **kwargs from `say` and friends, turn certain contextual + arguments (ie., `ephemeral`, `in_thread`) into context-free args + to send to Slack. 
+ """ + + next_kwargs = {} + + if kwargs.pop("ephemeral", False): + next_kwargs["ephemeral_user"] = self.user_id + + if kwargs.pop("in_thread", False): + next_kwargs["thread_ts"] = self.thread_ts + + if "ephemeral_user" in next_kwargs and "thread_ts" in next_kwargs: + raise ValueError("Messages may be in-thread or ephemeral, not both") + + next_kwargs.update(kwargs) + return next_kwargs + + def __str__(self): + return "Message '{}', sent by user @{} in channel #{}".format( + self.text, self.user_id, self.channel_id + ) + + def __repr__(self): + return "Message(text={}, sender={}, channel={})".format( + repr(self.text), repr(self.user_id), repr(self.channel_id) + ) diff --git a/machine/plugins/base.py b/machine/plugins/base.py index 02ea49a6..768c0d62 100644 --- a/machine/plugins/base.py +++ b/machine/plugins/base.py @@ -1,3 +1,6 @@ +# -*- coding: utf-8 -*- + +from aiohttp.web import Application from blinker import signal @@ -22,7 +25,7 @@ def __init__(self, settings, client, storage): self.settings = settings self._fq_name = "{}.{}".format(self.__module__, self.__class__.__name__) - def init(self): + def init(self, http_app: Application): """Initialize plugin This method can be implemented by concrete plugin classes. It will be called **once** @@ -30,12 +33,14 @@ def init(self): ``self.settings``, and access storage through ``self.storage``, but the Slack client has not been initialized yet, so you cannot send or process messages during initialization. + This method can be specified as either synchronous or asynchronous, depending on the needs + of the plugin. + :return: None """ pass - @property - def users(self): + async def get_users(self): """Dictionary of all users in the Slack workspace :return: a dictionary of all users in the Slack workspace, where the key is the user id and @@ -44,10 +49,9 @@ def users(self): .. 
_User: https://github.com/slackapi/python-slackclient/blob/master/slackclient/user.py """ - return self._client.users + return await self._client.get_users() - @property - def channels(self): + async def get_channels(self): """List of all channels in the Slack workspace This is a list of all channels in the Slack workspace that the bot is aware of. This @@ -59,7 +63,7 @@ def channels(self): .. _Channel: https://github.com/slackapi/python-slackclient/blob/master/slackclient/channel.py # NOQA """ - return self._client.channels + return await self._client.get_channels() def retrieve_bot_info(self): """Information about the bot user in Slack @@ -71,7 +75,7 @@ def retrieve_bot_info(self): """ return self._client.retrieve_bot_info() - def at(self, user): + def at(self, user: dict): """Create a mention of the provided user Create a mention of the provided user in the form of ``<@[user_id]>``. This method is @@ -82,42 +86,11 @@ def at(self, user): :param user: user your want to mention :return: user mention """ - return self._client.fmt_mention(user) - - def say(self, channel, text, thread_ts=None): - """Send a message to a channel - - Send a message to a channel using the RTM API. Only `basic Slack formatting`_ allowed. - For richer formatting using attachments, use - :py:meth:`~machine.plugins.base.MachineBasePlugin.say_webapi` - - .. _basic Slack formatting: https://api.slack.com/docs/message-formatting - - :param channel: id or name of channel to send message to. Can be public or private (group) - channel, or DM channel. - :param text: message text - :param thread_ts: optional timestamp of thread, to send a message in that thread - :return: None - """ - self._client.send(channel, text, thread_ts) - - def say_scheduled(self, when, channel, text): - """Schedule a message to a channel - - This is the scheduled version of - :py:meth:`~machine.plugins.base.MachineBasePlugin.say`. - It behaves the same, but will send the message at the scheduled time. 
- - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param channel: id or name of channel to send message to. Can be public or private (group) - channel, or DM channel. - :param text: message text - :return: None - """ - self._client.send_scheduled(when, channel, text) + return self._client.fmt_mention(user["id"]) - def say_webapi(self, channel, text, attachments=None, thread_ts=None, ephemeral_user=None): + async def say( + self, channel, text, attachments=None, thread_ts=None, ephemeral_user=None + ): """Send a message to a channel using the WebAPI Send a message to a channel using the WebAPI. Allows for rich formatting using @@ -141,9 +114,11 @@ def say_webapi(self, channel, text, attachments=None, thread_ts=None, ephemeral_ .. _chat.postMessage: https://api.slack.com/methods/chat.postMessage .. _chat.postEphemeral: https://api.slack.com/methods/chat.postEphemeral """ - return self._client.send_webapi(channel, text, attachments, thread_ts, ephemeral_user) + return await self._client.send_webapi( + channel, text, attachments, thread_ts, ephemeral_user + ) - def say_webapi_scheduled(self, when, channel, text, attachments, ephemeral_user): + def say_scheduled(self, when, channel, text, attachments, ephemeral_user): """Schedule a message to a channel and send it using the WebAPI This is the scheduled version of @@ -157,9 +132,9 @@ def say_webapi_scheduled(self, when, channel, text, attachments, ephemeral_user) to a specific user only :return: None """ - self._client.send_webapi_scheduled(when, channel, text, attachments, ephemeral_user) + self._client.send_scheduled(when, channel, text, attachments, ephemeral_user) - def react(self, channel, ts, emoji): + async def react(self, channel, ts, emoji): """React to a message in a channel Add a reaction to a message in a channel. What message to react to, is determined by the @@ -173,38 +148,9 @@ def react(self, channel, ts, emoji): .. 
_reactions.add: https://api.slack.com/methods/reactions.add """ - return self._client.react(channel, ts, emoji) - - def send_dm(self, user, text): - """Send a Direct Message - - Send a Direct Message to a user by opening a DM channel and sending a message to it. - Only `basic Slack formatting`_ allowed. For richer formatting using attachments, use - :py:meth:`~machine.plugins.base.MachineBasePlugin.send_dm_webapi` - - .. _basic Slack formatting: https://api.slack.com/docs/message-formatting - - :param user: id or name of the user to send direct message to - :param text: message text - :return: None - """ - self._client.send_dm(user, text) - - def send_dm_scheduled(self, when, user, text): - """Schedule a Direct Message - - This is the scheduled version of - :py:meth:`~machine.plugins.base.MachineBasePlugin.send_dm`. It behaves the same, but will - send the DM at the scheduled time. - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param user: id or name of the user to send direct message to - :param text: message text - :return: None - """ - self._client.send_dm_scheduled(when, user, text) + return await self._client.react(channel, ts, emoji) - def send_dm_webapi(self, user, text, attachments=None): + async def send_dm(self, user, text, attachments=None): """Send a Direct Message through the WebAPI Send a Direct Message to a user by opening a DM channel and sending a message to it via @@ -220,9 +166,9 @@ def send_dm_webapi(self, user, text, attachments=None): .. 
_chat.postMessage: https://api.slack.com/methods/chat.postMessage """ - return self._client.send_dm_webapi(user, text, attachments) + return self._client.send_dm(user, text, attachments) - def send_dm_webapi_scheduled(self, when, user, text, attachments=None): + def send_dm_scheduled(self, when, user, text, attachments=None): """Schedule a Direct Message and send it using the WebAPI This is the scheduled version of @@ -236,7 +182,7 @@ def send_dm_webapi_scheduled(self, when, user, text, attachments=None): :param attachments: optional attachments (see `attachments`_) :return: None """ - self._client.send_dm_webapi_scheduled(when, user, text, attachments) + self._client.send_dm_scheduled(when, user, text, attachments) def emit(self, event, **kwargs): """Emit an event @@ -250,315 +196,3 @@ def emit(self, event, **kwargs): """ e = signal(event) e.send(self, **kwargs) - - -class Message: - """A message that was received by the bot - - This class represents a message that was received by the bot and passed to one or more - plugins. It contains the message (text) itself, and metadata about the message, such as the - sender of the message, the channel the message was sent to. - - The ``Message`` class also contains convenience methods for replying to the message in the - right channel, replying to the sender, etc. 
- """ - - def __init__(self, client, msg_event, plugin_class_name): - self._client = client - self._msg_event = msg_event - self._fq_plugin_name = plugin_class_name - - @property - def sender(self): - """The sender of the message - - :return: the User the message was sent by - """ - return self._client.users.find(self._msg_event['user']) - - @property - def channel(self): - """The channel the message was sent to - - :return: the Channel the message was sent to - """ - return self._client.channels.find(self._msg_event['channel']) - - @property - def is_dm(self): - chan = self._msg_event['channel'] - return not (chan.startswith('C') or chan.startswith('G')) - - @property - def text(self): - """The body of the actual message - - :return: the body (text) of the actual message - """ - return self._msg_event['text'] - - @property - def at_sender(self): - """The sender of the message formatted as mention - - :return: a string representation of the sender of the message, formatted as `mention`_, - to be used in messages - - .. _mention: https://api.slack.com/docs/message-formatting#linking_to_channels_and_users - """ - return self._client.fmt_mention(self.sender) - - def say(self, text, thread_ts=None): - """Send a new message to the channel the original message was received in - - Send a new message to the channel the original message was received in, using the RTM API. - Only `basic Slack formatting`_ allowed. For richer formatting using attachments, use - :py:meth:`~machine.plugins.base.Message.say_webapi` - - .. _basic Slack formatting: https://api.slack.com/docs/message-formatting - - :param text: message text - :param thread_ts: optional timestamp of thread, to send a message in that thread - :return: None - """ - self._client.send(self.channel.id, text, thread_ts) - - def say_scheduled(self, when, text): - """Schedule a message - - This is the scheduled version of :py:meth:`~machine.plugins.base.Message.say`. 
- It behaves the same, but will send the message at the scheduled time. - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param text: message text - :return: None - """ - self._client.send_scheduled(when, self.channel.id, text) - - def say_webapi(self, text, attachments=None, thread_ts=None, ephemeral=False): - """Send a new message using the WebAPI to the channel the original message was received in - - Send a new message to the channel the original message was received in, using the WebAPI. - Allows for rich formatting using `attachments`_. Can also reply to a thread and send an - ephemeral message only visible to the sender of the original message. - Ephemeral messages and threaded messages are mutually exclusive, and ``ephemeral`` - takes precedence over ``thread_ts`` - - .. _attachments: https://api.slack.com/docs/message-attachments - - :param text: message text - :param attachments: optional attachments (see `attachments`_) - :param thread_ts: optional timestamp of thread, to send a message in that thread - :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only - visible to the sender of the original message - :return: Dictionary deserialized from `chat.postMessage`_ request, or `chat.postEphemeral`_ - if `ephemeral` is True. - - .. _chat.postMessage: https://api.slack.com/methods/chat.postMessage - .. _chat.postEphemeral: https://api.slack.com/methods/chat.postEphemeral - """ - if ephemeral: - ephemeral_user = self.sender.id - else: - ephemeral_user = None - - return self._client.send_webapi( - self.channel.id, - text, - attachments, - thread_ts, - ephemeral_user, - ) - - def say_webapi_scheduled(self, when, text, attachments=None, ephemeral=False): - """Schedule a message and send it using the WebAPI - - This is the scheduled version of :py:meth:`~machine.plugins.base.Message.say_webapi`. - It behaves the same, but will send the DM at the scheduled time. - - .. 
_attachments: https://api.slack.com/docs/message-attachments - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param text: message text - :param attachments: optional attachments (see `attachments`_) - :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only - visible to the sender of the original message - :return: None - """ - if ephemeral: - ephemeral_user = self.sender.id - else: - ephemeral_user = None - - self._client.send_webapi_scheduled(when, self.channel.id, text, attachments, ephemeral_user) - - def reply(self, text, in_thread=False): - """Reply to the sender of the original message - - Reply to the sender of the original message with a new message, mentioning that user. - - :param text: message text - :param in_thread: ``True/False`` wether to reply to the original message in-thread - :return: None - """ - if in_thread: - self.say(text, thread_ts=self.thread_ts) - else: - text = self._create_reply(text) - self.say(text) - - def reply_scheduled(self, when, text): - """Schedule a reply - - This is the scheduled version of :py:meth:`~machine.plugins.base.Message.reply`. - It behaves the same, but will send the reply at the scheduled time. - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param text: message text - :return: None - """ - self.say_scheduled(when, self._create_reply(text)) - - def reply_webapi(self, text, attachments=None, in_thread=False, ephemeral=False): - """Reply to the sender of the original message using the WebAPI - - Reply to the sender of the original message with a new message, mentioning that user. Uses - the WebAPI, so rich formatting using `attachments`_ is possible. Can also reply to a thread - and send an ephemeral message only visible to the sender of the original message. 
- Ephemeral messages and threaded messages are mutually exclusive, and ``ephemeral`` - takes precedence over ``thread_ts`` - - .. _attachments: https://api.slack.com/docs/message-attachments - - :param text: message text - :param attachments: optional attachments (see `attachments`_) - :param in_thread: ``True/False`` wether to reply to the original message in-thread - :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only - visible to the sender of the original message - :return: Dictionary deserialized from `chat.postMessage`_ request, or `chat.postEphemeral`_ - if `ephemeral` is True. - - .. _chat.postMessage: https://api.slack.com/methods/chat.postMessage - .. _chat.postEphemeral: https://api.slack.com/methods/chat.postEphemeral - """ - if in_thread and not ephemeral: - return self.say_webapi(text, attachments=attachments, thread_ts=self.thread_ts) - else: - text = self._create_reply(text) - return self.say_webapi(text, attachments=attachments, ephemeral=ephemeral) - - def reply_webapi_scheduled(self, when, text, attachments=None, ephemeral=False): - """Schedule a reply and send it using the WebAPI - - This is the scheduled version of :py:meth:`~machine.plugins.base.Message.reply_webapi`. - It behaves the same, but will send the reply at the scheduled time. - - .. 
_attachments: https://api.slack.com/docs/message-attachments - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param attachments: optional attachments (see `attachments`_) - :param ephemeral: ``True/False`` wether to send the message as an ephemeral message, only - visible to the sender of the original message - :return: None - """ - self.say_webapi_scheduled(when, self._create_reply(text), attachments, ephemeral) - - def reply_dm(self, text): - """Reply to the sender of the original message with a DM - - Reply in a Direct Message to the sender of the original message by opening a DM channel and - sending a message to it. - - :param text: message text - :return: None - """ - self._client.send_dm(self.sender.id, text) - - def reply_dm_scheduled(self, when, text): - """Schedule a DM reply - - This is the scheduled version of :py:meth:`~machine.plugins.base.Message.reply_dm`. It - behaves the same, but will send the DM at the scheduled time. - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param text: message text - :return: None - """ - self._client.send_dm_scheduled(when, self.sender.id, text) - - def reply_dm_webapi(self, text, attachments=None): - """Reply to the sender of the original message with a DM using the WebAPI - - Reply in a Direct Message to the sender of the original message by opening a DM channel and - sending a message to it via the WebAPI. Allows for rich formatting using - `attachments`_. - - .. _attachments: https://api.slack.com/docs/message-attachments - - :param text: message text - :param attachments: optional attachments (see `attachments`_) - :return: Dictionary deserialized from `chat.postMessage`_ request. - - .. 
_chat.postMessage: https://api.slack.com/methods/chat.postMessage - """ - return self._client.send_dm_webapi(self.sender.id, text, attachments) - - def reply_dm_webapi_scheduled(self, when, text, attachments=None): - """Schedule a DM reply and send it using the WebAPI - - This is the scheduled version of :py:meth:`~machine.plugins.base.Message.reply_dm_webapi`. - It behaves the same, but will send the DM at the scheduled time. - - .. _attachments: https://api.slack.com/docs/message-attachments - - :param when: when you want the message to be sent, as :py:class:`datetime.datetime` instance - :param text: message text - :param attachments: optional attachments (see `attachments`_) - :return: None - """ - self._client.send_dm_webapi_scheduled(when, self.sender.id, text, attachments) - - def react(self, emoji): - """React to the original message - - Add a reaction to the original message - - :param emoji: what emoji to react with (should be a string, like 'angel', 'thumbsup', etc.) - :return: Dictionary deserialized from `reactions.add`_ request. - - .. 
_reactions.add: https://api.slack.com/methods/reactions.add - """ - return self._client.react(self.channel.id, self._msg_event['ts'], emoji) - - def _create_reply(self, text): - if not self.is_dm: - return "{}: {}".format(self.at_sender, text) - else: - return text - - @property - def thread_ts(self): - """The timestamp of the original message - - :return: the timestamp of the original message - """ - try: - thread_ts = self._msg_event['thread_ts'] - except KeyError: - thread_ts = self._msg_event['ts'] - - return thread_ts - - def __str__(self): - return "Message '{}', sent by user @{} in channel #{}".format( - self.text, - self.sender.name, - self.channel.name - ) - - def __repr__(self): - return "Message(text={}, sender={}, channel={})".format( - repr(self.text), - repr(self.sender.name), - repr(self.channel.name) - ) diff --git a/machine/plugins/builtin/debug.py b/machine/plugins/builtin/debug.py index f8d20e92..09a23846 100644 --- a/machine/plugins/builtin/debug.py +++ b/machine/plugins/builtin/debug.py @@ -1,19 +1,18 @@ -import logging +# -*- coding: utf-8 -*- + +from loguru import logger + from machine.plugins.base import MachineBasePlugin from machine.plugins.decorators import process -logger = logging.getLogger(__name__) - class EventLoggerPlugin(MachineBasePlugin): - - def catch_all(self, event): + async def catch_all(self, event): logger.debug("Event received: %s", event) class EchoPlugin(MachineBasePlugin): - - @process(slack_event_type='message') - def echo_message(self, event): + @process(slack_event_type="message") + async def echo_message(self, event): logger.debug("Message received: %s", event) - self.say(event['channel'], event['text']) + await self.say(event["channel"], event["text"]) diff --git a/machine/plugins/builtin/fun/images.py b/machine/plugins/builtin/fun/images.py deleted file mode 100644 index c1dd8307..00000000 --- a/machine/plugins/builtin/fun/images.py +++ /dev/null @@ -1,59 +0,0 @@ -import random -import requests -import logging 
-from machine.plugins.base import MachineBasePlugin -from machine.plugins.decorators import respond_to, required_settings - -logger = logging.getLogger(__name__) - - -@required_settings(['GOOGLE_CSE_ID', 'GOOGLE_API_KEY']) -class ImageSearchPlugin(MachineBasePlugin): - """Images""" - - @respond_to(r'(?:image|img)(?: me)? (?P.+)') - def image_me(self, msg, query): - """image/img (me) : find a random image""" - results = self._search(query.strip()) - if results: - url = random.choice(results) - msg.say(url) - else: - msg.say("Couldn't find any results for ''! :cry:".format(query)) - - @respond_to(r'animate(?: me)? (?P.+)') - def animate_me(self, msg, query): - """animate (me) : find a random gif""" - results = self._search(query.strip(), animated=True) - if results: - url = random.choice(results) - msg.say(url) - else: - msg.say("Couldn't find any results for ''! :cry:".format(query)) - - def _search(self, query, animated=False): - query_params = { - 'cx': self.settings['GOOGLE_CSE_ID'], - 'key': self.settings['GOOGLE_API_KEY'], - 'q': query, - 'searchType': 'image', - 'fields': 'items(link)', - 'safe': 'high' - } - - if animated: - query_params.update({ - 'fileType': 'gif', - 'hq': 'animated', - 'tbs': 'itp:animated' - }) - r = requests.get('https://www.googleapis.com/customsearch/v1', params=query_params) - if r.ok: - response = r.json() - results = [result["link"] for result in response["items"] if "items" in response] - else: - logger.warning("An error occurred while searching! 
Status code: %s, response: %s" % ( - r.status_code, r.text - )) - results = [] - return results diff --git a/machine/plugins/builtin/fun/memes.py b/machine/plugins/builtin/fun/memes.py index 0212fd49..526b3f96 100644 --- a/machine/plugins/builtin/fun/memes.py +++ b/machine/plugins/builtin/fun/memes.py @@ -1,4 +1,6 @@ -import requests +# -*- coding: utf-8 -*- + +from aiohttp.client import ClientSession from machine.plugins.base import MachineBasePlugin from machine.plugins.decorators import respond_to from machine.plugins.builtin.fun.regexes import url_regex @@ -7,65 +9,69 @@ class MemePlugin(MachineBasePlugin): """Images""" - @respond_to(r'meme (?P\S+) (?P.+);(?P.+)') - def meme(self, msg, meme, top, bottom): + @respond_to(r"meme (?P\S+) (?P.+);(?P.+)") + async def meme(self, msg, meme, top, bottom): """meme ;: generate a meme""" - character_replacements = { - '?': '~q', - '&': '~p', - '#': '~h', - '/': '~s', - "''": '"' - } - query_string = '?font={}'.format(self._font) + character_replacements = {"?": "~q", "&": "~p", "#": "~h", "/": "~s", "''": '"'} + query_string = "?font={}".format(self._font) for original, replacement in character_replacements.items(): - top = top.replace(original, replacement) - bottom = bottom.replace(original, replacement) + top = top.replace(original, replacement) if top else "" + bottom = bottom.replace(original, replacement) if bottom else "" + match = url_regex.match(meme) if match: - query_string = query_string + '&alt=' + match.group('url') - meme = 'custom' - path = '{}/{}/{}/{}.jpg{}'.format(self._base_url, meme, top.strip(), - bottom.strip(), query_string).replace(' ', '-') - msg.say(path) + query_string = query_string + "&alt=" + match.group("url") + meme = "custom" + path = "{}/{}/{}/{}.jpg{}".format( + self._base_url, meme, top.strip(), bottom.strip(), query_string + ).replace(" ", "-") + await msg.say(path) else: - path = '/{}/{}/{}'.format(meme, top.strip(), bottom.strip()).replace(' ', '-') - status, meme_info = 
self._memegen_api_request(path) + path = "/{}/{}/{}".format( + meme.strip(), top.strip(), bottom.strip() + ).replace(" ", "-") + status, meme_info = await self._memegen_api_request(path) if 200 <= status < 400: - msg.say(meme_info['direct']['masked'] + query_string) + await msg.say(meme_info["direct"]["masked"] + query_string) elif status == 404: - msg.say( - "I don't know that meme. Use `list memes` to see what memes I have available") + await msg.say( + "I don't know that meme. Use `list memes` to see what memes I have available" + ) else: - msg.say("Ooooops! Something went wrong :cry:") + await msg.say("Ooooops! Something went wrong :cry:") - @respond_to(r'list memes') - def list_memes(self, msg): + @respond_to(r"list (dank )?(memes|maymays)") + async def list_memes(self, msg): """list memes: list all the available meme templates""" ephemeral = not msg.is_dm - status, templates = self._memegen_api_request('/api/templates/') + status, templates = await self._memegen_api_request("/api/templates/") if 200 <= status < 400: message = "*You can choose from these memes:*\n\n" + "\n".join( - ["\t_{}_: '{}'".format(url.rsplit('/', 1)[1], description) for description, url in - templates.items()] + [ + "\t_{}_: '{}'".format(url.rsplit("/", 1)[1], description) + for description, url in templates.items() + ] ) - msg.say_webapi(message, ephemeral=ephemeral) + await msg.say(message, ephemeral=ephemeral) else: - msg.say_webapi("It seems I cannot find the memes you're looking for :cry:", - ephemeral=ephemeral) + await msg.reply( + "It seems I cannot find the memes you're looking for :cry:", + in_thread=True, + ) - def _memegen_api_request(self, path): + async def _memegen_api_request(self, path): url = self._base_url + path.lower() - r = requests.get(url) - if r.ok: - return r.status_code, r.json() - else: - return r.status_code, None + async with ClientSession() as session: + async with session.get(url) as resp: + if resp.reason == "OK": + return resp.status, await resp.json() 
+ else: + return resp.status, None @property def _base_url(self): - return self.settings.get('MEMEGEN_URL', 'https://memegen.link') + return self.settings.get("MEMEGEN_URL", "https://memegen.link") @property def _font(self): - return self.settings.get('MEMEGEN_FONT', 'impact') + return self.settings.get("MEMEGEN_FONT", "impact") diff --git a/machine/plugins/builtin/general.py b/machine/plugins/builtin/general.py index 06783a5c..9d527b7c 100644 --- a/machine/plugins/builtin/general.py +++ b/machine/plugins/builtin/general.py @@ -1,31 +1,32 @@ -import logging +# -*- coding: utf-8 -*- + +from loguru import logger + from machine.plugins.base import MachineBasePlugin from machine.plugins.decorators import listen_to, respond_to -logger = logging.getLogger(__name__) - class PingPongPlugin(MachineBasePlugin): """Playing Ping Pong""" - @listen_to(r'^ping$') - def listen_to_ping(self, msg): + @listen_to(r"^ping$") + async def listen_to_ping(self, msg): """ping: serving the ball""" - logger.debug("Ping received with msg: %s", msg) - msg.say("pong") + logger.debug("Ping received with msg: {}", msg) + await msg.say("pong") - @listen_to(r'^pong$') - def listen_to_pong(self, msg): + @listen_to(r"^pong$") + async def listen_to_pong(self, msg): """pong: returning the ball""" - logger.debug("Pong received with msg: %s", msg) - msg.say("ping") + logger.debug("Pong received with msg: {}", msg) + await msg.say("ping") class HelloPlugin(MachineBasePlugin): """Greetings""" - @respond_to(r'^(?Phi|hello)') - def greet(self, msg, greeting): + @respond_to(r"^(?Phi|hello)") + async def greet(self, msg, greeting): """hi/hello: say hello to the little guy""" - logger.debug("Greeting '%s' received", greeting) - msg.say("{}, {}!".format(greeting.title(), msg.at_sender)) + logger.debug("Greeting '{}' received", greeting) + await msg.say("{}, {}!".format(greeting.title(), msg.at_sender)) diff --git a/machine/plugins/builtin/help.py b/machine/plugins/builtin/help.py index 9e085749..388778a1 100644 
--- a/machine/plugins/builtin/help.py +++ b/machine/plugins/builtin/help.py @@ -1,39 +1,45 @@ -import logging +# -*- coding: utf-8 -*- + from machine.plugins.base import MachineBasePlugin from machine.plugins.decorators import respond_to -logger = logging.getLogger(__name__) - class HelpPlugin(MachineBasePlugin): """Getting Help""" - @respond_to(r'^help$') - def help(self, msg): + @respond_to(r"^help$") + async def help(self, msg): """help: display this help text""" - manual = self.storage.get('manual', shared=True)['human'] + manual = (await self.storage.get("manual", shared=True))["human"] help_text = "This is what I can respond to:\n\n" - help_text += "\n\n".join([self._gen_class_help_text(cls, fn) - for cls, fn in manual.items() if fn]) - msg.say(help_text) + help_text += "\n\n".join( + [self._gen_class_help_text(cls, fn) for cls, fn in manual.items() if fn] + ) + await msg.say(help_text) - @respond_to(r'^robot help$') - def robot_help(self, msg): + @respond_to(r"^robot help$") + async def robot_help(self, msg): "robot help: display regular expressions that the bot responds to" - robot_manual = self.storage.get('manual', shared=True)['robot'] + robot_manual = (await self.storage.get("manual", shared=True))["robot"] help_text = "This is what triggers me:\n\n" - help_text += "\n\n".join([self._gen_class_robot_help(cls, regexes) - for cls, regexes in robot_manual.items()]) - msg.say(help_text) + help_text += "\n\n".join( + [ + self._gen_class_robot_help(cls, regexes) + for cls, regexes in robot_manual.items() + ] + ) + await msg.say(help_text) def _gen_class_help_text(self, class_help, fn_helps): help_text = "*{}:*\n".format(class_help) - fn_help_texts = "\n".join([self._gen_help_text(fn_help) for fn_help in fn_helps.values()]) + fn_help_texts = "\n".join( + [self._gen_help_text(fn_help) for fn_help in fn_helps.values()] + ) help_text += fn_help_texts return help_text def _gen_help_text(self, fn_help): - return "\t*{}:* {}".format(fn_help['command'], 
fn_help['help']) + return "\t*{}:* {}".format(fn_help["command"], fn_help["help"]) def _gen_class_robot_help(self, class_help, regexes): help_text = "*{}:*\n".format(class_help) @@ -42,5 +48,5 @@ def _gen_class_robot_help(self, class_help, regexes): return help_text def _gen_bot_regex(self, regex): - bot_name = self.retrieve_bot_info()['name'] + bot_name = self.retrieve_bot_info()["name"] return "\t`{}`".format(regex.replace("@botname", "@" + bot_name)) diff --git a/machine/settings.py b/machine/settings.py index 345c0e93..4b2022c4 100644 --- a/machine/settings.py +++ b/machine/settings.py @@ -1,25 +1,27 @@ +# -*- coding: utf-8 -*- + import os -import logging from importlib import import_module -from machine.utils.collections import CaseInsensitiveDict -logger = logging.getLogger(__name__) +from machine.utils.collections import CaseInsensitiveDict -def import_settings(settings_module='local_settings'): +def import_settings(settings_module="local_settings"): default_settings = { - 'PLUGINS': ['machine.plugins.builtin.general.PingPongPlugin', - 'machine.plugins.builtin.general.HelloPlugin', - 'machine.plugins.builtin.help.HelpPlugin', - 'machine.plugins.builtin.fun.memes.MemePlugin'], - 'STORAGE_BACKEND': 'machine.storage.backends.memory.MemoryStorage', - 'DISABLE_HTTP': False, - 'HTTP_SERVER_HOST': '0.0.0.0', - 'HTTP_SERVER_PORT': 8080, - 'HTTP_SERVER_BACKEND': 'wsgiref', - 'HTTP_PROXY': '', - 'HTTPS_PROXY': '', - 'KEEP_ALIVE': None + "PLUGINS": [ + "machine.plugins.builtin.general.PingPongPlugin", + "machine.plugins.builtin.general.HelloPlugin", + "machine.plugins.builtin.help.HelpPlugin", + "machine.plugins.builtin.fun.memes.MemePlugin", + ], + "STORAGE_BACKEND": "machine.storage.backends.memory.MemoryStorage", + "DISABLE_HTTP": False, + "HTTP_SERVER_HOST": "0.0.0.0", + "HTTP_SERVER_PORT": 8080, + "HTTP_SERVER_BACKEND": "wsgiref", + "HTTP_PROXY": "", + "HTTPS_PROXY": "", + "KEEP_ALIVE": None, } settings = CaseInsensitiveDict(default_settings) try: @@ -29,11 +31,11 
@@ def import_settings(settings_module='local_settings'): found_local_settings = False else: for k in dir(local_settings): - if not k.startswith('_'): + if not k.startswith("_"): settings[k] = getattr(local_settings, k) for k, v in os.environ.items(): - if k[:3] == 'SM_': + if k[:3] == "SM_": k = k[3:] settings[k] = v diff --git a/machine/singletons.py b/machine/singletons.py index 4dedecc9..f6b61d5e 100644 --- a/machine/singletons.py +++ b/machine/singletons.py @@ -1,49 +1,88 @@ -from apscheduler.schedulers.background import BackgroundScheduler -from slackclient import SlackClient +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import asyncio + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from slack import RTMClient, WebClient + from machine.settings import import_settings from machine.utils import Singleton from machine.utils.module_loading import import_string +from machine.utils.readonly_proxy import ReadonlyProxy from machine.utils.redis import gen_config_dict class Slack(metaclass=Singleton): - def __init__(self): + __slots__ = "_login_data", "_rtm_client", "_web_client" + + def __init__(self, loop=None): + if not loop: + loop = asyncio.get_event_loop() + _settings, _ = import_settings() - slack_api_token = _settings.get('SLACK_API_TOKEN', None) - http_proxy = _settings.get('HTTP_PROXY', None) - https_proxy = _settings.get('HTTPS_PROXY', None) - proxies = {'http': http_proxy, 'https': https_proxy} + slack_api_token = _settings.get("SLACK_API_TOKEN", None) + http_proxy = _settings.get("HTTP_PROXY", None) + https_proxy = _settings.get("HTTPS_PROXY", None) - self._client = SlackClient(slack_api_token, proxies=proxies) if slack_api_token else None + self._login_data = None + self._rtm_client = RTMClient( + token=slack_api_token, + run_async=True, + auto_reconnect=True, + proxy=https_proxy or http_proxy or None, + loop=loop, + ) + self._web_client = WebClient(slack_api_token, run_async=True, loop=loop) - def __getattr__(self, 
item): - return getattr(self._client, item) + @RTMClient.run_on(event="open") + def _store_login_data(**payload): + self._login_data = payload["data"] + + @property + def login_data(self) -> ReadonlyProxy[dict]: + return ReadonlyProxy(self._login_data or {}) + + @property + def rtm(self) -> ReadonlyProxy[RTMClient]: + return ReadonlyProxy(self._rtm_client) + + @property + def web(self) -> ReadonlyProxy[WebClient]: + return ReadonlyProxy(self._web_client) @staticmethod - def get_instance(): + def get_instance() -> Slack: return Slack() class Scheduler(metaclass=Singleton): - def __init__(self): + """ Configures an `asyncio`-compatible scheduler instance. + """ + + def __init__(self, loop=None): + if not loop: + loop = asyncio.get_event_loop() + _settings, _ = import_settings() - self._scheduler = BackgroundScheduler() - if 'REDIS_URL' in _settings: + self._scheduler = AsyncIOScheduler(event_loop=loop) + if "REDIS_URL" in _settings: redis_config = gen_config_dict(_settings) - self._scheduler.add_jobstore('redis', **redis_config) + self._scheduler.add_jobstore("redis", **redis_config) def __getattr__(self, item): return getattr(self._scheduler, item) @staticmethod - def get_instance(): + def get_instance() -> Scheduler: return Scheduler() class Storage(metaclass=Singleton): def __init__(self): _settings, _ = import_settings() - _, cls = import_string(_settings['STORAGE_BACKEND'])[0] + _, cls = import_string(_settings["STORAGE_BACKEND"])[0] self._storage = cls(_settings) def __getattr__(self, item): diff --git a/machine/slack.py b/machine/slack.py index 15480640..1e9d23b4 100644 --- a/machine/slack.py +++ b/machine/slack.py @@ -1,113 +1,111 @@ -from machine.singletons import Slack, Scheduler +# -*- coding: utf-8 -*- +from datetime import datetime +from typing import Optional, Sequence -class MessagingClient: - @property - def users(self): - return Slack.get_instance().server.users - - @property - def channels(self): - return Slack.get_instance().server.channels +from 
async_lru import alru_cache +from slack.web.slack_response import SlackResponse - @staticmethod - def retrieve_bot_info(): - return Slack.get_instance().server.login_data['self'] +from machine.singletons import Scheduler, Slack +from machine.utils.aio import run_coro_until_complete - def fmt_mention(self, user): - u = self.users.find(user) - return "<@{}>".format(u.id) +class MessagingClient: @staticmethod - def send(channel, text, thread_ts=None): - Slack.get_instance().rtm_send_message(channel, text, thread_ts) - - def send_scheduled(self, when, channel, text): - args = [self, channel, text] - kwargs = {'thread_ts': None} - - Scheduler.get_instance().add_job(trigger='date', args=args, - kwargs=kwargs, run_date=when) + def retrieve_bot_info() -> Optional[dict]: + return Slack.get_instance().login_data.get("self") @staticmethod - def send_webapi(channel, text, attachments=None, thread_ts=None, ephemeral_user=None): - method = 'chat.postMessage' - - # This is the only way to conditionally add thread_ts - kwargs = { - 'channel': channel, - 'text': text, - 'attachments': attachments, - 'as_user': True + async def send( + channel_id: str, + text: str, + *, + attachments: Optional[Sequence[dict]] = None, + thread_ts: Optional[str] = None, + ephemeral_user: Optional[str] = None, + ) -> SlackResponse: + method = "chat.postEphemeral" if ephemeral_user else "chat.postMessage" + payload = { + "channel": channel_id, + "text": text, + "blocks": attachments, + "as_user": True, } if ephemeral_user: - method = 'chat.postEphemeral' - kwargs['user'] = ephemeral_user + payload["user"] = ephemeral_user else: if thread_ts: - kwargs['thread_ts'] = thread_ts - - return Slack.get_instance().api_call( - method, - **kwargs - ) - - def send_webapi_scheduled(self, when, channel, text, attachments=None, ephemeral_user=None): - args = [self, channel, text] - kwargs = { - 'attachments': attachments, - 'thread_ts': None, - 'ephemeral_user': ephemeral_user - } + payload["thread_ts"] = thread_ts - 
Scheduler.get_instance().add_job(trigger='date', args=args, - kwargs=kwargs, run_date=when) + return await Slack.get_instance().web.api_call(method, json=payload) @staticmethod - def react(channel, ts, emoji): - return Slack.get_instance().api_call( - 'reactions.add', - name=emoji, - channel=channel, - timestamp=ts - ) + async def react(channel_id: str, ts: str, emoji: str) -> SlackResponse: + payload = {"name": emoji, "channel": channel_id, "timestamp": ts} + + return await Slack.get_instance().web.api_call("reactions.add", json=payload) @staticmethod - def open_im(user): - response = Slack.get_instance().api_call( - 'im.open', - user=user + async def open_im(user_id: str) -> str: + response = await Slack.get_instance().web.api_call( + "im.open", json={"user": user_id} ) + return response["channel"]["id"] - return response['channel']['id'] - - def send_dm(self, user, text): - u = self.users.find(user) - dm_channel = self.open_im(u.id) - - self.send(dm_channel, text) - - def send_dm_scheduled(self, when, user, text): - args = [self, user, text] - Scheduler.get_instance().add_job(MessagingClient.send_dm, trigger='date', args=args, - run_date=when) - - def send_dm_webapi(self, user, text, attachments=None): - u = self.users.find(user) - dm_channel = self.open_im(u.id) + @property + def channels(self) -> SlackResponse: + return run_coro_until_complete(self.get_channels()) - return Slack.get_instance().api_call( - 'chat.postMessage', - channel=dm_channel, - text=text, - attachments=attachments, - as_user=True + @property + def users(self) -> SlackResponse: + return run_coro_until_complete(self.get_users()) + + async def get_channels(self) -> SlackResponse: + return await Slack.get_instance().web.channels_list() + + @alru_cache(maxsize=32) + async def find_channel_by_id(self, channel_id: str) -> Optional[dict]: + for channel in (await self.get_channels())["channels"]: + if channel["id"] == channel_id: + return channel + + return None + + async def get_users(self) -> 
SlackResponse: + return await Slack.get_instance().web.users_list() + + @alru_cache(maxsize=32) + async def find_user_by_id(self, user_id: str) -> Optional[dict]: + for user in (await self.get_users())["members"]: + if user["id"] == user_id: + return user + + return None + + def fmt_mention(self, user: dict) -> str: + return f"<@{user['id']}>" + + def send_scheduled(self, when: datetime, channel_id: str, text: str, **kwargs): + args = [channel_id, text] + Scheduler.get_instance().add_job( + MessagingClient.send, + trigger="date", + args=args, + kwargs=kwargs, + run_date=when, ) - def send_dm_webapi_scheduled(self, when, user, text, attachments=None): - args = [self, user, text] - kwargs = {'attachments': attachments} - - Scheduler.get_instance().add_job(MessagingClient.send_dm_webapi, trigger='data', args=args, - kwargs=kwargs) + async def send_dm(self, user_id: str, text: str, **kwargs) -> SlackResponse: + dm_channel = await self.open_im(user_id) + return await self.send(dm_channel, text=text, **kwargs) + + def send_dm_scheduled(self, when: datetime, user_id: str, text: str, **kwargs): + args = [self, user_id, text] + Scheduler.get_instance().add_job( + MessagingClient.send_dm, + trigger="date", + args=args, + kwargs=kwargs, + run_date=when, + ) diff --git a/machine/storage/__init__.py b/machine/storage/__init__.py index 75e92453..d3de6d00 100644 --- a/machine/storage/__init__.py +++ b/machine/storage/__init__.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import dill from machine.singletons import Storage @@ -14,6 +15,7 @@ class PluginStorage: ..
_Dill: https://pypi.python.org/pypi/dill """ + def __init__(self, fq_plugin_name): self._fq_plugin_name = fq_plugin_name @@ -23,7 +25,7 @@ def _gen_unique_key(self, key): def _namespace_key(self, key, shared): return key if shared else self._gen_unique_key(key) - def set(self, key, value, expires=None, shared=False): + async def set(self, key, value, expires=None, shared=False): """Store or update a value by key :param key: the key under which to store the data @@ -34,9 +36,9 @@ def set(self, key, value, expires=None, shared=False): """ namespaced_key = self._namespace_key(key, shared) pickled_value = dill.dumps(value) - Storage.get_instance().set(namespaced_key, pickled_value, expires) + await Storage.get_instance().set(namespaced_key, pickled_value, expires) - def get(self, key, shared=False): + async def get(self, key, shared=False): """Retrieve data by key :param key: key for the data to retrieve @@ -44,13 +46,13 @@ def get(self, key, shared=False): :return: the data, or ``None`` if the key cannot be found/has expired """ namespaced_key = self._namespace_key(key, shared) - value = Storage.get_instance().get(namespaced_key) + value = await Storage.get_instance().get(namespaced_key) if value: return dill.loads(value) else: return None - def has(self, key, shared=False): + async def has(self, key, shared=False): """Check if the key exists in storage Note: this class implements ``__contains__`` so instead of calling @@ -64,9 +66,9 @@ def has(self, key, shared=False): expired. 
""" namespaced_key = self._namespace_key(key, shared) - return Storage.get_instance().has(namespaced_key) + return await Storage.get_instance().has(namespaced_key) - def delete(self, key, shared=False): + async def delete(self, key, shared=False): """Remove a key and its data from storage :param key: key to remove @@ -74,22 +76,19 @@ def delete(self, key, shared=False): namespace """ namespaced_key = self._namespace_key(key, shared) - Storage.get_instance().delete(namespaced_key) + await Storage.get_instance().delete(namespaced_key) - def get_storage_size(self): + async def get_storage_size(self): """Calculate the total size of the storage :return: the total size of the storage in bytes (integer) """ - return Storage.get_instance().size() + return await Storage.get_instance().size() - def get_storage_size_human(self): + async def get_storage_size_human(self): """Calculate the total size of the storage in human readable format :return: the total size of the storage in a human readable string, rounded to the nearest applicable division. eg. B for Bytes, KiB for Kilobytes, MiB for Megabytes etc. 
""" - return sizeof_fmt(self.get_storage_size()) - - def __contains__(self, key): - return self.has(key, False) + return sizeof_fmt(await self.get_storage_size()) diff --git a/machine/storage/backends/base.py b/machine/storage/backends/base.py index ec66d2d2..2534598e 100644 --- a/machine/storage/backends/base.py +++ b/machine/storage/backends/base.py @@ -11,7 +11,14 @@ class MachineBaseStorage: def __init__(self, settings): self.settings = settings - def get(self, key): + async def connect(self): + """ Used when initializing asynchronous libraries (ie., aioredis) + that need to be awaited to connect or create clients (ex: `aioredis.create_pool`) + """ + + raise NotImplementedError + + async def get(self, key): """Retrieve data by key :param key: key for which to retrieve data @@ -20,7 +27,7 @@ def get(self, key): """ raise NotImplementedError - def set(self, key, value, expires=None): + async def set(self, key, value, expires=None): """Store data by key :param key: the key under which to store the data @@ -30,14 +37,14 @@ def set(self, key, value, expires=None): """ raise NotImplementedError - def delete(self, key): + async def delete(self, key): """Delete data by key :param key: key for which to delete the data """ raise NotImplementedError - def has(self, key): + async def has(self, key): """Check if the key exists :param key: key to check @@ -45,7 +52,7 @@ def has(self, key): """ raise NotImplementedError - def size(self): + async def size(self): """Calculate the total size of the storage :return: total size of storage in bytes (integer) diff --git a/machine/storage/backends/hbase.py b/machine/storage/backends/hbase.py deleted file mode 100644 index a22f44eb..00000000 --- a/machine/storage/backends/hbase.py +++ /dev/null @@ -1,61 +0,0 @@ -from datetime import datetime, timedelta - -from happybase import Connection -from machine.storage.backends.base import MachineBaseStorage - - -def bytes_to_float(byte_arr): - s = byte_arr.decode('utf-8') - return float(s) - - 
-def float_to_bytes(i): - return bytes(str(i), 'utf-8') - - -class HBaseStorage(MachineBaseStorage): - - _VAL = b'values:value' - _EXP = b'values:expires_at' - _COLS = [_VAL, _EXP] - - def __init__(self, settings): - super().__init__(settings) - hbase_host = settings['HBASE_HOST'] - hbase_table = settings['HBASE_TABLE'] - self._connection = Connection(hbase_host) - self._table = self._connection.table(hbase_table) - - def _get_value(self, key): - row = self._table.row(key, self._COLS) - val = row.get(self._VAL) - if val: - exp = row.get(self._EXP) - if not exp: - return val - elif datetime.fromtimestamp(bytes_to_float(exp)) > datetime.utcnow(): - return val - else: - self.delete(key) - return None - return None - - def has(self, key): - val = self._get_value(key) - return bool(val) - - def get(self, key): - return self._get_value(key) - - def set(self, key, value, expires=None): - data = {self._VAL: value} - if expires: - expires_at = datetime.utcnow() + timedelta(seconds=expires) - data[self._EXP] = float_to_bytes(expires_at.timestamp()) - self._table.put(key, data) - - def delete(self, key): - self._table.delete(key) - - def size(self): - return 0 diff --git a/machine/storage/backends/memory.py b/machine/storage/backends/memory.py index 1dc4eec9..b94dd122 100644 --- a/machine/storage/backends/memory.py +++ b/machine/storage/backends/memory.py @@ -9,7 +9,10 @@ def __init__(self, settings): super().__init__(settings) self._storage = {} - def get(self, key): + async def connect(self): + pass + + async def get(self, key): stored = self._storage.get(key, None) if not stored: return None @@ -20,14 +23,14 @@ def get(self, key): else: return stored[0] - def set(self, key, value, expires=None): + async def set(self, key, value, expires=None): if expires: expires_at = datetime.utcnow() + timedelta(seconds=expires) else: expires_at = None self._storage[key] = (value, expires_at) - def has(self, key): + async def has(self, key): stored = self._storage.get(key, None) if not 
stored: return False @@ -38,8 +41,8 @@ def has(self, key): else: return True - def delete(self, key): + async def delete(self, key): del self._storage[key] - def size(self): + async def size(self): return sys.getsizeof(self._storage) # pragma: no cover diff --git a/machine/storage/backends/redis.py b/machine/storage/backends/redis.py index 4e15bb98..7b338b36 100644 --- a/machine/storage/backends/redis.py +++ b/machine/storage/backends/redis.py @@ -1,31 +1,58 @@ -from redis import StrictRedis +# -*- coding: utf-8 -*- + +import aioredis from machine.storage.backends.base import MachineBaseStorage -from machine.utils.redis import gen_config_dict class RedisStorage(MachineBaseStorage): def __init__(self, settings): super().__init__(settings) - self._key_prefix = settings.get('REDIS_KEY_PREFIX', 'SM') - redis_config = gen_config_dict(settings) - self._redis = StrictRedis(**redis_config) + self._redis_url = settings.get("REDIS_URL", "redis://localhost:6379") + self._max_connections = settings.get("REDIS_MAX_CONNECTIONS", 10) + self._key_prefix = settings.get("REDIS_KEY_PREFIX", "SM") + self._redis = None + + async def connect(self): + self._redis = await aioredis.create_redis_pool( + self._redis_url, maxsize=self._max_connections + ) + + def _ensure_connected(self): + if self._redis is None: + raise NotConnectedError() def _prefix(self, key): return "{}:{}".format(self._key_prefix, key) - def has(self, key): - return self._redis.exists(self._prefix(key)) + async def has(self, key): + self._ensure_connected() + return await self._redis.exists(self._prefix(key)) + + async def get(self, key): + self._ensure_connected() + return await self._redis.get(self._prefix(key)) + + async def set(self, key, value, expires=None): + self._ensure_connected() + await self._redis.set(self._prefix(key), value, expire=expires or 0) + + async def delete(self, key): + self._ensure_connected() + await self._redis.delete(self._prefix(key)) + + async def size(self): + self._ensure_connected() + info = await 
self._redis.info("memory") + return info["used_memory"] - def get(self, key): - return self._redis.get(self._prefix(key)) - def set(self, key, value, expires=None): - self._redis.set(self._prefix(key), value, expires) +class NotConnectedError(Exception): + def __init__(self): + super().__init__() + self.message = "RedisStorage backend must be `connect()`ed before using" - def delete(self, key): - self._redis.delete(self._prefix(key)) + def __repr__(self): + return self.message - def size(self): - info = self._redis.info('memory') - return info['used_memory'] + __str__ = __repr__ diff --git a/machine/utils/aio.py b/machine/utils/aio.py new file mode 100644 index 00000000..b3f650ce --- /dev/null +++ b/machine/utils/aio.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +import asyncio +import concurrent +from functools import partial +from typing import Any, Callable, Coroutine, List, Optional, Sequence, Tuple, Type + + +async def join(tasks: Sequence[Coroutine]) -> List[Any]: + """ Execute all of the coroutines, returning a list of responses + or exceptions. + """ + + async def wrapper(future, idx): + try: + output = await future + return idx, output + except Exception as err: + return idx, err + + output = [None] * len(tasks) + tasks = [wrapper(future, idx) for idx, future in enumerate(tasks, start=0)] + + for future in asyncio.as_completed(tasks): + idx, result = await future + output[idx] = result + + return output + + +async def split(tasks: Sequence[Coroutine]) -> Tuple[List[Any], List[Type[Exception]]]: + """ Execute all of the coroutines, returning a list of responses + and a list of exceptions. 
+ """ + + results = await join(tasks) + + return_values = [] + exceptions = [] + + for res in results: + if isinstance(res, Exception): + exceptions.append(res) + else: + return_values.append(res) + + return return_values, exceptions + + +def run_coro_until_complete( + coro: Coroutine, loop: Optional[asyncio.AbstractEventLoop] = None +) -> Any: + if loop is None: + loop = asyncio.get_event_loop() + + task = loop.create_task(coro) + loop.run_until_complete(task) + + return task.result() + + +def run_in_threadpool(func: Callable[..., Any]) -> Callable[..., Any]: + """ Makes any calls to a synchronous function happen in a + `concurrent.futures.ThreadPoolExecutor` so that synchronous calls + are not blocking the event loop. + + The wrapper function will await the `asyncio.Future` provided by + `submit_to_threadpool`, which returns the wrapped return value directly + to the caller. + """ + + async def inner(*args, **kw) -> Any: + fn = partial(func, *args, **kw) + fut = await submit_to_threadpool(fn) + return await fut + + return inner + + +async def submit_to_threadpool( + fn: Callable[..., Any], + loop: Optional[asyncio.AbstractEventLoop] = None, + executor: Optional[concurrent.futures.Executor] = None, +) -> asyncio.Future: + """ Given a curried function, run it inside the given `loop` with `executor`, or, + get the current event loop, create a new `concurrent.futures.ThreadPoolExecutor` + with `max_workers=2`, and run the function there. + + Will not accept any function arguments or keyword arguments. Use `functools.partial` + to curry the input function. + + Returns an `asyncio.Future` object which can be awaited. 
+ """ + + if loop is None: + loop = asyncio.get_event_loop() + + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + + return loop.run_in_executor(executor, fn) + + +def build_executor(*args, **kw) -> concurrent.futures.ThreadPoolExecutor: + return concurrent.futures.ThreadPoolExecutor(*args, **kw) diff --git a/machine/utils/collections.py b/machine/utils/collections.py index f8b17ec9..af7ccb82 100644 --- a/machine/utils/collections.py +++ b/machine/utils/collections.py @@ -1,7 +1,9 @@ -import collections +# -*- coding: utf-8 -*- +from collections.abc import Mapping, MutableMapping -class CaseInsensitiveDict(collections.MutableMapping): + +class CaseInsensitiveDict(MutableMapping): """ A case-insensitive ``dict``-like object. Implements all methods and operations of @@ -53,14 +55,10 @@ def __len__(self): def lower_items(self): """Like iteritems(), but with all lowercase keys.""" - return ( - (lowerkey, keyval[1]) - for (lowerkey, keyval) - in self._store.items() - ) + return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) def __eq__(self, other): - if isinstance(other, collections.Mapping): + if isinstance(other, Mapping): other = CaseInsensitiveDict(other) else: return NotImplemented @@ -72,4 +70,4 @@ def copy(self): return CaseInsensitiveDict(self._store.values()) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, dict(self.items())) + return "%s(%r)" % (self.__class__.__name__, dict(self.items())) diff --git a/machine/utils/log_propagate.py b/machine/utils/log_propagate.py new file mode 100644 index 00000000..fc35dd9e --- /dev/null +++ b/machine/utils/log_propagate.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +import logging + +from loguru import logger + + +class InterceptHandler(logging.Handler): + def emit(self, record): + # Retrieve context where the logging call occurred, this happens to be in the 6th frame upward + logger_opt = logger.opt(depth=6, exception=record.exc_info) + 
logger_opt.log(record.levelno, record.getMessage()) + + +def install(): + logging.basicConfig(handlers=[InterceptHandler()], level=0) diff --git a/machine/utils/pool.py b/machine/utils/pool.py deleted file mode 100644 index 23f59899..00000000 --- a/machine/utils/pool.py +++ /dev/null @@ -1,45 +0,0 @@ -from queue import Queue -from threading import Thread -import logging - -logger = logging.getLogger(__name__) - - -class Worker(Thread): - """ Thread executing tasks from a given tasks queue """ - - def __init__(self, queue): - Thread.__init__(self) - self.queue = queue - self.daemon = True - self.start() - - def run(self): - while True: - func, args, kargs = self.queue.get() - try: - func(*args, **kargs) - except Exception as ex: - # An exception happened in this thread - logger.exception("An error occurred while performing work", exc_info=ex) - finally: - # Mark this task as done, whether an exception happened or not - self.queue.task_done() - - -class ThreadPool: - """ Pool of threads consuming tasks from a queue """ - - def __init__(self, num_threads=10): - self.queue = Queue(num_threads) - for _ in range(num_threads): - Worker(self.queue) - - def add_task(self, func, *args, **kargs): - """ Add a task to the queue """ - self.queue.put((func, args, kargs)) - - def map(self, func, args_list): - """ Add a list of tasks to the queue """ - for args in args_list: - self.add_task(func, args) diff --git a/machine/utils/readonly_proxy.py b/machine/utils/readonly_proxy.py new file mode 100644 index 00000000..1b4387f7 --- /dev/null +++ b/machine/utils/readonly_proxy.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +from typing import Any, Generic, T + + +class ReadonlyProxy(Generic[T]): + __slots__ = "_target" + + def __init__(self, target: T): + self._target = target + + def __getitem__(self, item: Any) -> Any: + _target = getattr(self, "_target") + return _target[item] + + def __getattr__(self, item: str) -> Any: + return getattr(self._target, item) diff --git 
a/machine/utils/redis.py b/machine/utils/redis.py index 095dec91..5e835b0a 100644 --- a/machine/utils/redis.py +++ b/machine/utils/redis.py @@ -1,17 +1,20 @@ +# -*- coding: utf-8 -*- + from urllib.parse import urlparse def gen_config_dict(settings): - url = urlparse(settings['REDIS_URL']) - if hasattr(url, 'path') and getattr(url, 'path'): + url = urlparse(settings["REDIS_URL"]) + if hasattr(url, "path") and getattr(url, "path"): db = url.path[1:] else: db = 0 - max_connections = settings.get('REDIS_MAX_CONNECTIONS', None) + + max_connections = settings.get("REDIS_MAX_CONNECTIONS", None) return { - 'host': url.hostname, - 'port': url.port, - 'db': db, - 'password': url.password, - 'max_connections': max_connections + "host": url.hostname, + "port": url.port, + "db": db, + "password": url.password, + "max_connections": max_connections, } diff --git a/machine/utils/text.py b/machine/utils/text.py deleted file mode 100644 index 11a35d6f..00000000 --- a/machine/utils/text.py +++ /dev/null @@ -1,21 +0,0 @@ -from clint.textui import puts, colored - - -def show_valid(valid_str): - puts(colored.green(f"✓ {valid_str}")) - - -def show_invalid(valid_str): - puts(colored.red(f"✗ {valid_str}")) - - -def warn(warn_string): - puts(colored.yellow(f"Warning: {warn_string}")) - - -def error(err_string): - puts(colored.red(f"ERROR: {err_string}")) - - -def announce(string): - puts(colored.cyan(string)) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1433b94e..ae912d9c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,15 +1,17 @@ -r requirements.txt + +Cython==0.29.13 +Sphinx==2.2.1 check-manifest==0.39 +coverage==4.5.4 +flake8==3.7.9 +flake8-bugbear==19.8.0 pyroma==2.5 +pytest-cov==2.7.1 +pytest-html==1.21.1 +pytest-metadata==1.8.0 pytest-mock==1.11.2 -pytest==5.2.0 +pytest==5.2.2 +sphinx-autobuild==0.7.1 tox==3.13.1 -flake8==3.7.9 twine==2.0.0 -coverage==4.5.4 -pytest-cov==2.7.1 -Sphinx==2.2.1 -sphinx-autobuild==0.7.1 -redis==3.3.11 -Cython==0.29.13 
-happybase==1.2.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 84de577e..9bf495e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ -slackclient==1.3.1 -dill==0.3.1.1 +aiohttp==3.6.2 +aioredis==1.3.0 apscheduler==3.6.1 +async_lru==1.0.2 blinker-alt==1.5 -clint==0.5.1 -bottle==0.12.17 \ No newline at end of file +dill==0.3.1.1 +loguru==0.3.2 +slackclient==2.3.1 diff --git a/setup.cfg b/setup.cfg index ac767f46..88ebb966 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,13 +8,12 @@ exclude = dist, build, tests -max-line-length=100 +max-line-length = 100 +select = C,E,F,W,B,B950 +ignore = E501,W503 [aliases] -test=pytest +test = pytest [tool:pytest] addopts = --cov-config .coveragerc --verbose --cov-report term-missing --cov=machine -filterwarnings= - ignore:invalid escape sequence::bottle - ignore:Using or importing the ABCs::bottle \ No newline at end of file diff --git a/setup.py b/setup.py index 0cef97f0..a1d0524f 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from codecs import open import os import sys @@ -7,14 +8,14 @@ here = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(here, 'README.rst'), 'r', 'utf-8') as f: +with open(os.path.join(here, "README.rst"), "r", "utf-8") as f: long_description = f.read() -with open(os.path.join(here, 'machine', '__about__.py'), 'r', 'utf-8') as f: +with open(os.path.join(here, "machine", "__about__.py"), "r", "utf-8") as f: about = {} exec(f.read(), about) -with open(os.path.join(here, 'requirements.txt'), 'r', 'utf-8') as f: +with open(os.path.join(here, "requirements.txt"), "r", "utf-8") as f: dependencies = f.read().splitlines() @@ -24,13 +25,13 @@ class PublishCommand(Command): Graciously taken from https://github.com/kennethreitz/setup.py """ - description = 'Build and publish the package.' + description = "Build and publish the package." 
user_options = [] @@ -41,10 +42,10 @@ def finalize_options(self): def _remove_builds(self, msg): try: self.status(msg) - rmtree(os.path.join(here, 'dist')) - rmtree(os.path.join(here, 'build')) - rmtree(os.path.join(here, '.egg')) - rmtree(os.path.join(here, 'slack_machine.egg-info')) + rmtree(os.path.join(here, "dist")) + rmtree(os.path.join(here, "build")) + rmtree(os.path.join(here, ".egg")) + rmtree(os.path.join(here, "slack_machine.egg-info")) except FileNotFoundError: pass @@ -55,10 +56,10 @@ def run(self): pass self.status("Building Source and Wheel (universal) distribution…") - os.system('{0} setup.py sdist bdist_wheel'.format(sys.executable)) + os.system("{0} setup.py sdist bdist_wheel".format(sys.executable)) self.status("Uploading the package to PyPi via Twine…") - os.system('twine upload dist/*') + os.system("twine upload dist/*") self._remove_builds("Removing builds…") @@ -74,37 +75,35 @@ def run(self): url=about["__uri__"], author=about["__author__"], author_email=about["__email__"], - setup_requires=['pytest-runner'], - tests_require=['pytest', 'pytest-cov', 'coverage'], + setup_requires=["pytest-runner"], + tests_require=[ + "pytest", + "pytest-cov", + "pytest-html", + "pytest-metadata", + "pytest-mock", + "coverage", + ], install_requires=dependencies, - python_requires='~=3.3', - extras_require={ - 'redis': ['redis', 'hiredis'], - 'hbase': ['Cython==0.29.6', 'happybase'] - }, + python_requires="~=3.7", + extras_require={"redis": ["aioredis"]}, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Natural Language :: English", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Programming 
Language :: Python :: 3 :: Only", "Topic :: Communications :: Chat", "Topic :: Internet", - "Topic :: Office/Business" + "Topic :: Office/Business", ], - keywords='slack bot framework ai', - entry_points={ - 'console_scripts': [ - 'slack-machine = machine.bin.run:main', - ], - }, + keywords="slack bot framework ai", + entry_points={"console_scripts": ["slack-machine = machine.bin.run:main"]}, packages=find_packages(), include_package_data=True, zip_safe=False, - cmdclass={ - 'publish': PublishCommand, - } + cmdclass={"publish": PublishCommand}, ) diff --git a/tests/fake_plugins.py b/tests/fake_plugins.py index f03daa3c..f48da27a 100644 --- a/tests/fake_plugins.py +++ b/tests/fake_plugins.py @@ -1,26 +1,20 @@ +# -*- coding: utf-8 -*- from machine.plugins.base import MachineBasePlugin from machine.plugins.decorators import respond_to, listen_to, process class FakePlugin(MachineBasePlugin): + async def init(self, http_app): + self.x = 42 - @respond_to(r'hello') - def respond_function(self, msg): - pass - - @listen_to(r'hi') - def listen_function(self, msg): + @respond_to(r"hello") + async def respond_function(self, msg): pass - @process('some_event') - def process_function(self, event): + @listen_to(r"hi") + async def listen_function(self, msg): pass - -class FakePlugin2(MachineBasePlugin): - - def init(self): - self.x = 42 - - def catch_all(self, event): + @process("some_event") + async def process_function(self, event): pass diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 00000000..e4afa3ad --- /dev/null +++ b/tests/helpers/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from tests.helpers import expect +from tests.helpers.aio import async_test, coroutine_mock, make_coroutine_mock diff --git a/tests/helpers/aio.py b/tests/helpers/aio.py new file mode 100644 index 00000000..fc60ca2d --- /dev/null +++ b/tests/helpers/aio.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +import asyncio +from functools import wraps 
+from typing import Any, Awaitable, Callable, NewType +from unittest.mock import Mock + +CoroutineFunction = NewType("CoroutineFunction", Callable[..., Awaitable]) + + +def async_test(fn: CoroutineFunction): + """ Wrapper around test methods, to allow them to be run async """ + + @wraps(fn) + def wrapper(*args, **kwargs): + coro = asyncio.coroutine(fn) + future = coro(*args, **kwargs) + loop = asyncio.get_event_loop() + loop.run_until_complete(future) + + return wrapper + + +def coroutine_mock(): + """ Usable as a mock callable for patching async functions. + From https://stackoverflow.com/a/32505333 + """ + + coro = Mock(name="CoroutineResult") + corofunc = Mock(name="CoroutineFunction", side_effect=asyncio.coroutine(coro)) + corofunc.coro = coro + return corofunc + + +def make_coroutine_mock(return_value: Any): + """ Returns an coroutine mock with the given return_value. + The returned item is ready to be `await`ed + """ + + mock = coroutine_mock() + mock.coro.return_value = return_value + return mock() diff --git a/tests/helpers/expect.py b/tests/helpers/expect.py new file mode 100644 index 00000000..2ded8509 --- /dev/null +++ b/tests/helpers/expect.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, List +from unittest.mock import ( + DEFAULT, + _Call, + MagicMock, + MagicMixin, + MagicProxy, + Mock, + _magics, + call, + patch as mock_patch, +) + +import pytest + +_NO_VALUE = object() + + +class Expectation(object): + """ An expectation is a mapping of a set of call arguments + to a series of return values and/or raisables. 
+ """ + + def __init__(self, sig: _Call): + self.sig = sig + self._items = [] + self._always = _NO_VALUE + + def _set_always(self, value: Any) -> Expectation: + + self._items.clear() + self._always = value + return self + + def _append_item(self, value: Any) -> Expectation: + + self._items.append(value) + return self + + def _next_item(self) -> Any: + + if self._always is not _NO_VALUE: + return self._always + + try: + return self._items.pop(0) + except IndexError: + return _NO_VALUE + + def has_unused_items(self) -> bool: + """ If this expectation has calls that are unused, returns true. + If this expectation has an always value set, returns false. + """ + + if self._always is not _NO_VALUE: + return False + + if len(self._items) > 0: + return True + + return False + + def matches(self, *args, **kw) -> bool: + """ If a given set of args and keyword args matches + the represented call signature, returns True + """ + + if self.sig == call(*args, **kw): + return True + + return False + + def returns(self, value: Any, always: bool = False) -> Expectation: + """ When the `ExpectMock` is called with the arguments + given by `self.sig`, `value` is returned. + + If `always` is set to `True`, this `Expectation`'s + items will be replaced by the given value and will + never become exhausted. + """ + + if always: + return self._set_always(value) + + return self._append_item(value) + + def raises(self, value: Any, always: bool = False) -> Expectation: + """ When the `ExpectMock` is called with the arguments + given by `self.sig`, `value` is raised. + + If `always` is set to `True`, this `Expectation`'s + items will be replaced by the given value and will + never become exhausted. 
+ """ + + if always: + return self._set_always(value) + + return self._append_item(value) + + +class ExpectMixin(object): + def __init__(self, ignore_unused_calls: bool = False): + self.__expectations = [] + self._ignore_unused_calls = ignore_unused_calls + + def expect(self, *args, **kw) -> Expectation: + this_call = call(*args, **kw) + expectation = Expectation(this_call) + self.__expectations.append(expectation) + return expectation + + def raise_for_unused_calls(self): + if self._ignore_unused_calls: + return + + for expectation in self.__expectations: + if expectation.has_unused_items(): + raise UnusedCallsError(self, expectation) + + def _expect_side_effect(self, *args, **kw) -> Any: + for ex in self.__expectations: + if ex.matches(*args, **kw): + item = ex._next_item() + if isinstance(item, Exception): + raise item + elif item is _NO_VALUE: + raise NoExpectationForCall(*args, **kw) + else: + return item + else: + raise NoExpectationForCall(*args, **kw) + + +class ExpectMock(ExpectMixin, Mock): + """ `ExpectMock` is an expectation-based mock builder, built on + top of the stdlib `unittest.mock` framework. + + `ExpectMock` aims to provide a `mockito`-esque interface to + expecting call arguments paired with returns and/or raises. 
+ + --- + + With vanilla `unittest.mock`, you would have to use something + like this example to map a call to a return value or exception, + and expect that the call happens before teardown: + + def test_function(mocker): + calls = { + call(expected, call, args): expectedReturn, + call(other, call, args): Exception("message"), + } + + def dispatcher(*args, **kw): + sig = call(*args, **kw) + if sig in calls: + item = calls[sig] + if isinstance(sig, Exception): + raise item + return calls[sig] + + function = mocker.patch("some.module.function") + function.side_effect = dispatcher + + # Assume `caller_of_function` calls the mocked function + caller_of_function(my, arguments) + + [function.assert_has_call(sig) for sig in calls.keys()] + + --- + + With `ExpectMock`, the above exchange becomes far simpler: + + def test_function(expecter): + function = expecter.mock("some.module.function") + function.expect(expected, call, args).returns(expectedReturn) + function.expect(other, call, args).raises(Exception("message")) + + # Assume `caller_of_function` calls the mocked function + caller_of_function(my, arguments) + """ + + def __init__(self, *args, ignore_unused_calls: bool = False, **kw): + """ Initialize an `ExpectMock`. + """ + + Mock.__init__(self, *args, **kw) + ExpectMixin.__init__(self, ignore_unused_calls=ignore_unused_calls) + + self.side_effect = self._expect_side_effect + + +class ExpectMagicMock(ExpectMixin, MagicMock): + """ `ExpectMagicMock` is an extension of `ExpectMock` that has the + same properties as both `ExpectMock` and `MagicMock`, but will + also return new `ExpectMagicMock`s for magics. + """ + + def __init__(self, *args, ignore_unused_calls: bool = False, **kw): + """ Initialize an `ExpectMock`. 
+ """ + + MagicMock.__init__(self, *args, **kw) + ExpectMixin.__init__(self, ignore_unused_calls=ignore_unused_calls) + + self.side_effect = self._expect_side_effect + + +class NoExpectationForCall(Exception): + def __init__(self, *args, **kw): + super().__init__(self) + this_call = call(*args, **kw) + self.message = f"ExpectMock has no expectation for call: {this_call}" + + def __repr__(self): + return self.message + + def __str__(self): + return repr(self) + + +class UnusedCallsError(Exception): + def __init__(self, mock: ExpectMixin, expectation: Expectation): + super().__init__(self) + unused_calls = map( + lambda item: f" {expectation.sig} -> {item}", expectation._items + ) + self.message = f"Mock {repr(mock)} has unused calls:\n{''.join(unused_calls)}" + + def __repr__(self): + return self.message + + def __str__(self): + return repr(self) + + +def patch( + target, + new=DEFAULT, + spec=None, + create=False, + spec_set=None, + autospec=None, + new_callable=None, + **kwargs, +): + if new_callable is not None: + if new is not DEFAULT: + raise ValueError("Cannot use 'new' and 'new_callable' together") + else: + if new is DEFAULT: + new = ExpectMagicMock + + return mock_patch( + target, + new=new, + spec=spec, + create=create, + spec_set=spec_set, + autospec=autospec, + new_callable=new_callable, + **kwargs, + ) + + +class ExpectMockFixture(object): + """ ExpectMockFixture provides a patching interface for test functions + and collects all created patches to ensure they are unstubbed after + the test has completed + + Based upon the `MockFixture` from the excellent `pytest-mock` plugin. 
+ """ + + def __init__(self): + self._patches = [] + self._mocks = [] + + def reset_all(self): + for m in self._mocks: + m.reset_mock() + + def stop_all(self): + for p in reversed(self._patches): + p.stop() + + self._patches.clear() + self._mocks.clear() + + def check_for_unused_mock_calls(self): + for m in self._mocks: + m.raise_for_unused_calls() + + def patch(self, target, *args, **kw) -> ExpectMagicMock: + if "new_callable" not in kw: + kw["new_callable"] = ExpectMagicMock + + p = patch(target, *args, **kw) + self._patches.append(p) + m = p.start() + self._mocks.append(m) + return m + + def mock(self, *args, **kw) -> ExpectMagicMock: + m = ExpectMagicMock(*args, **kw) + self._mocks.append(m) + return m + + MagicMock = mock + ExpectMagicMock = mock + + +@pytest.yield_fixture +def expect(): + fixture = ExpectMockFixture() + yield fixture + try: + fixture.check_for_unused_mock_calls() + finally: + fixture.stop_all() diff --git a/tests/helpers/test_expect.py b/tests/helpers/test_expect.py new file mode 100644 index 00000000..bed0cdbc --- /dev/null +++ b/tests/helpers/test_expect.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- + +import time +from unittest.mock import call + +import pytest + +from tests.helpers.expect import ( + ExpectMagicMock, + ExpectMock, + ExpectMockFixture, + Expectation, + NoExpectationForCall, + UnusedCallsError, + expect, + patch, +) + + +def test_basic_expectation(): + mock = ExpectMock() + mock.expect("a").returns(1).returns(2) + mock.expect("b").returns(3).returns(4) + mock.expect("c").raises(Exception("C")).returns(5) + + assert mock("a") == 1 + assert mock("a") == 2 + with pytest.raises(NoExpectationForCall): + mock("a") + + assert mock("b") == 3 + assert mock("b") == 4 + with pytest.raises(NoExpectationForCall): + mock("b") + + with pytest.raises(Exception): + mock("c") + assert mock("c") == 5 + with pytest.raises(NoExpectationForCall): + mock("c") + + +def test_call_record(): + mock = ExpectMock() + 
mock.expect("a").returns(1).returns(2) + + assert mock("a") == 1 + mock.assert_has_calls([call("a")]) + + assert mock("a") == 2 + mock.assert_has_calls([call("a"), call("a")]) + + +def test_magic_child_instance(): + mock = ExpectMagicMock() + assert isinstance(mock.some_function, ExpectMagicMock) + assert isinstance(mock.__len__, ExpectMagicMock) + + +def test_patch(): + with patch("time.sleep", new_callable=ExpectMock) as sleep: + sleep.expect(1).returns(True) + assert time.sleep(1) == True + + +def test_always(): + mock = ExpectMagicMock() + mock.expect("a").returns(1, always=True) + mock.expect("b").returns(2, always=False) + assert mock("a") == 1 + assert mock("b") == 2 + assert mock("a") == 1 + with pytest.raises(NoExpectationForCall): + mock("b") + + +def test_expect_fixture(expect: ExpectMockFixture): + assert isinstance(expect, ExpectMockFixture) + sleep = expect.patch("time.sleep") + sleep.expect(1).returns(True) + assert time.sleep(1) == True + + expect.stop_all() + assert time.sleep(0.1) == None + + +@pytest.mark.xfail(raises=UnusedCallsError, strict=True) +def test_raises_for_unused_calls(expect: ExpectMockFixture): + sleep = expect.patch("time.sleep") + sleep.expect(0.5).returns(True) + sleep.expect(1).returns(True) + assert time.sleep(0.5) + expect.check_for_unused_mock_calls() + + +def test_ignore_unused_calls(expect: ExpectMockFixture): + sleep = expect.patch("time.sleep", ignore_unused_calls=True) + sleep.expect(0.5).returns(True) + sleep.expect(1).returns(True) + assert time.sleep(0.5) + expect.check_for_unused_mock_calls() diff --git a/tests/local_test_settings.py b/tests/local_test_settings.py index ff1cdd0d..38b6b7a4 100644 --- a/tests/local_test_settings.py +++ b/tests/local_test_settings.py @@ -1,4 +1,5 @@ -SLACK_API_TOKEN = 'xoxo-abc123' +# -*- coding: utf-8 -*- +SLACK_API_TOKEN = "xoxo-abc123" ALIASES = "!,$" -MY_PLUGIN_SETTING = 'foobar' +MY_PLUGIN_SETTING = "foobar" _THIS_SHOULD_NOT_REGISTER = True diff --git a/tests/test_dispatch.py 
b/tests/test_dispatch.py index 159496c8..a11e20bb 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -1,108 +1,98 @@ +# -*- coding: utf-8 -*- + import re +from unittest.mock import MagicMock import pytest from machine.slack import MessagingClient from machine.dispatch import EventDispatcher -from machine.plugins.base import Message +from machine.message import Message from machine.storage.backends.base import MachineBaseStorage -from tests.fake_plugins import FakePlugin, FakePlugin2 + +from tests.fake_plugins import FakePlugin +from tests.helpers import async_test @pytest.fixture def msg_client(mocker): - return mocker.MagicMock(spec=MessagingClient) + return MagicMock(spec=MessagingClient) @pytest.fixture def storage(mocker): - return mocker.MagicMock(spec=MachineBaseStorage) + return MagicMock(spec=MachineBaseStorage) @pytest.fixture def fake_plugin(mocker, msg_client, storage): plugin_instance = FakePlugin({}, msg_client, storage) - mocker.spy(plugin_instance, 'respond_function') - mocker.spy(plugin_instance, 'listen_function') - mocker.spy(plugin_instance, 'process_function') - return plugin_instance - - -@pytest.fixture -def fake_plugin2(mocker, msg_client, storage): - plugin_instance = FakePlugin2({}, msg_client, storage) - mocker.spy(plugin_instance, 'catch_all') + mocker.spy(plugin_instance, "respond_function") + mocker.spy(plugin_instance, "listen_function") + mocker.spy(plugin_instance, "process_function") return plugin_instance @pytest.fixture -def plugin_actions(fake_plugin, fake_plugin2): - respond_fn = getattr(fake_plugin, 'respond_function') - listen_fn = getattr(fake_plugin, 'listen_function') - process_fn = getattr(fake_plugin, 'process_function') - catch_all_fn = getattr(fake_plugin2, 'catch_all') +def plugin_actions(fake_plugin): + respond_fn = getattr(fake_plugin, "respond_function") + listen_fn = getattr(fake_plugin, "listen_function") + process_fn = getattr(fake_plugin, "process_function") plugin_actions = { - 'catch_all': { - 
'TestPlugin2': { - 'class': fake_plugin2, - 'class_name': 'tests.fake_plugins.FakePlugin2', - 'function': catch_all_fn - } - }, - 'listen_to': { - 'TestPlugin.listen_function-hi': { - 'class': fake_plugin, - 'class_name': 'tests.fake_plugins.FakePlugin', - 'function': listen_fn, - 'regex': re.compile('hi', re.IGNORECASE) + "listen_to": { + "TestPlugin.listen_function-hi": { + "class": fake_plugin, + "class_name": "tests.fake_plugins.FakePlugin", + "function": listen_fn, + "regex": re.compile("hi", re.IGNORECASE), } }, - 'respond_to': { - 'TestPlugin.respond_function-hello': { - 'class': fake_plugin, - 'class_name': 'tests.fake_plugins.FakePlugin', - 'function': respond_fn, - 'regex': re.compile('hello', re.IGNORECASE) + "respond_to": { + "TestPlugin.respond_function-hello": { + "class": fake_plugin, + "class_name": "tests.fake_plugins.FakePlugin", + "function": respond_fn, + "regex": re.compile("hello", re.IGNORECASE), } }, - 'process': { - 'some_event': { - 'TestPlugin.process_function': { - 'class': fake_plugin, - 'class_name': 'tests.fake_plugins.FakePlugin', - 'function': process_fn + "process": { + "some_event": { + "TestPlugin.process_function": { + "class": fake_plugin, + "class_name": "tests.fake_plugins.FakePlugin", + "function": process_fn, } } - } + }, } return plugin_actions -@pytest.fixture(params=[None, {"ALIASES": "!"}, {"ALIASES": "!,$"}], ids=["No Alias", "Alias", "Aliases"]) +@pytest.fixture( + params=[None, {"ALIASES": "!"}, {"ALIASES": "!,$"}], + ids=["No Alias", "Alias", "Aliases"], +) def dispatcher(mocker, plugin_actions, request): - mocker.patch('machine.dispatch.ThreadPool', autospec=True) - mocker.patch('machine.singletons.SlackClient', autospec=True) - mocker.patch('machine.singletons.BackgroundScheduler', autospec=True) + mocker.patch("machine.singletons.Slack", autospec=True) + mocker.patch("machine.singletons.Scheduler", autospec=True) + dispatch_instance = EventDispatcher(plugin_actions, request.param) - 
mocker.patch.object(dispatch_instance, '_get_bot_id') - dispatch_instance._get_bot_id.return_value = '123' - mocker.patch.object(dispatch_instance, '_get_bot_name') - dispatch_instance._get_bot_name.return_value = 'superbot' + mocker.patch.object(dispatch_instance, "_get_bot_id") + + dispatch_instance._get_bot_id.return_value = "123" + mocker.patch.object(dispatch_instance, "_get_bot_name") + + dispatch_instance._get_bot_name.return_value = "superbot" dispatch_instance._aliases = request.param + return dispatch_instance -def test_handle_event_process(dispatcher, fake_plugin): - some_event = {'type': 'some_event'} - dispatcher.handle_event(some_event) +@async_test +async def test_handle_event_process(dispatcher, fake_plugin): + await dispatcher.handle_event("some_event", data={}) assert fake_plugin.process_function.call_count == 1 - fake_plugin.process_function.assert_called_once_with(some_event) - - -def test_handle_event_catch_all(dispatcher, fake_plugin2): - any_event = {'type': 'foobar'} - dispatcher.handle_event(any_event) - fake_plugin2.catch_all.assert_called_once_with(any_event) + fake_plugin.process_function.assert_called_once_with({}) def _assert_message(args, text): @@ -115,67 +105,67 @@ def _assert_message(args, text): assert args[0][0].text == text -def test_handle_event_listen_to(dispatcher, fake_plugin, fake_plugin2): - msg_event = {'type': 'message', 'text': 'hi', 'channel': 'C1', 'user': 'user1'} - dispatcher.handle_event(msg_event) - fake_plugin2.catch_all.assert_called_once_with(msg_event) +@async_test +async def test_handle_event_listen_to(dispatcher, fake_plugin): + msg_event = {"text": "hi", "channel": "C1", "user": "user1"} + await dispatcher.handle_event("message", data=msg_event) assert fake_plugin.listen_function.call_count == 1 assert fake_plugin.respond_function.call_count == 0 args = fake_plugin.listen_function.call_args - _assert_message(args, 'hi') + _assert_message(args, "hi") -def test_handle_event_respond_to(dispatcher, fake_plugin, 
fake_plugin2): - msg_event = {'type': 'message', 'text': '<@123> hello', 'channel': 'C1', 'user': 'user1'} - dispatcher.handle_event(msg_event) - fake_plugin2.catch_all.assert_called_once_with(msg_event) +@async_test +async def test_handle_event_respond_to(dispatcher, fake_plugin): + msg_event = {"text": "<@123> hello", "channel": "C1", "user": "user1"} + await dispatcher.handle_event("message", data=msg_event) assert fake_plugin.respond_function.call_count == 1 assert fake_plugin.listen_function.call_count == 0 args = fake_plugin.respond_function.call_args - _assert_message(args, 'hello') + _assert_message(args, "hello") def test_check_bot_mention(dispatcher): - normal_msg_event = {'text': 'hi', 'channel': 'C1'} + normal_msg_event = {"text": "hi", "channel": "C1"} event = dispatcher._check_bot_mention(normal_msg_event) assert event is None - mention_msg_event = {'text': '<@123> hi', 'channel': 'C1'} + mention_msg_event = {"text": "<@123> hi", "channel": "C1"} event = dispatcher._check_bot_mention(mention_msg_event) - assert event == {'text': 'hi', 'channel': 'C1'} + assert event == {"text": "hi", "channel": "C1"} - mention_msg_event_username = {'text': 'superbot: hi', 'channel': 'C1'} + mention_msg_event_username = {"text": "superbot: hi", "channel": "C1"} event = dispatcher._check_bot_mention(mention_msg_event_username) - assert event == {'text': 'hi', 'channel': 'C1'} + assert event == {"text": "hi", "channel": "C1"} - mention_msg_event_group = {'text': '<@123> hi', 'channel': 'G1'} + mention_msg_event_group = {"text": "<@123> hi", "channel": "G1"} event = dispatcher._check_bot_mention(mention_msg_event_group) - assert event == {'text': 'hi', 'channel': 'G1'} + assert event == {"text": "hi", "channel": "G1"} - mention_msg_event_other_user = {'text': '<@456> hi', 'channel': 'C1'} + mention_msg_event_other_user = {"text": "<@456> hi", "channel": "C1"} event = dispatcher._check_bot_mention(mention_msg_event_other_user) assert event is None - mention_msg_event_dm = 
{'text': 'hi', 'channel': 'D1'} + mention_msg_event_dm = {"text": "hi", "channel": "D1"} event = dispatcher._check_bot_mention(mention_msg_event_dm) - assert event == {'text': 'hi', 'channel': 'D1'} + assert event == {"text": "hi", "channel": "D1"} def test_check_bot_mention_alias(dispatcher): - mention_msg_event_no_alias_1 = {'text': '!hi', 'channel': 'C1'} + mention_msg_event_no_alias_1 = {"text": "!hi", "channel": "C1"} event = dispatcher._check_bot_mention(mention_msg_event_no_alias_1) - if dispatcher._aliases and '!' in dispatcher._aliases['ALIASES']: - assert event == {'text': 'hi', 'channel': 'C1'} + if dispatcher._aliases and "!" in dispatcher._aliases["ALIASES"]: + assert event == {"text": "hi", "channel": "C1"} else: assert event is None - mention_msg_event_no_alias_2 = {'text': '$hi', 'channel': 'C1'} + mention_msg_event_no_alias_2 = {"text": "$hi", "channel": "C1"} event = dispatcher._check_bot_mention(mention_msg_event_no_alias_2) - if dispatcher._aliases and '$' in dispatcher._aliases['ALIASES']: - assert event == {'text': 'hi', 'channel': 'C1'} - elif dispatcher._aliases and '!' in dispatcher._aliases['ALIASES']: + if dispatcher._aliases and "$" in dispatcher._aliases["ALIASES"]: + assert event == {"text": "hi", "channel": "C1"} + elif dispatcher._aliases and "!" 
in dispatcher._aliases["ALIASES"]: assert event is None else: assert event is None diff --git a/tests/test_hbase_storage.py b/tests/test_hbase_storage.py deleted file mode 100644 index 7d064276..00000000 --- a/tests/test_hbase_storage.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest -from happybase import Table - -from machine.storage.backends.hbase import bytes_to_float, float_to_bytes, HBaseStorage - -_VAL = b'values:value' -_EXP = b'values:expires_at' -_COLS = [_VAL, _EXP] - - -@pytest.fixture -def table(mocker): - table = mocker.MagicMock(spec=Table) - ConnectionCls = mocker.patch('machine.storage.backends.hbase.Connection', autospec=True) - instance = ConnectionCls.return_value - instance.table.return_value = table - return table - - -@pytest.fixture -def hbase_storage(table): - return HBaseStorage({'HBASE_HOST': 'foo', 'HBASE_TABLE': 'bar'}) - - -def test_float_conversion(): - assert bytes_to_float(float_to_bytes(3.14159265359)) == 3.14159265359 - - -def test_get(table, hbase_storage): - hbase_storage.get('key1') - table.row.assert_called_with('key1', _COLS) - - -def test_has(table, hbase_storage): - hbase_storage.has('key1') - table.row.assert_called_with('key1', _COLS) - - -def test_delete(table, hbase_storage): - hbase_storage.delete('key1') - table.delete.assert_called_with('key1') - - -def test_set(table, hbase_storage): - hbase_storage.set('key1', 'val1') - table.put.assert_called_with('key1', {b'values:value': 'val1'}) diff --git a/tests/test_memory_storage.py b/tests/test_memory_storage.py index 3a7676ba..e4b52cbe 100644 --- a/tests/test_memory_storage.py +++ b/tests/test_memory_storage.py @@ -1,45 +1,57 @@ +# -*- coding: utf-8 -*- from datetime import datetime import pytest from machine.storage.backends.memory import MemoryStorage +from tests.helpers import async_test + @pytest.fixture def memory_storage(): return MemoryStorage({}) -def test_store_retrieve_values(memory_storage): +@async_test +async def test_store_retrieve_values(memory_storage): 
assert memory_storage._storage == {} - memory_storage.set("key1", "value1") + await memory_storage.set("key1", "value1") assert memory_storage._storage == {"key1": ("value1", None)} - assert memory_storage.get("key1") == "value1" + assert (await memory_storage.get("key1")) == "value1" -def test_delete_values(memory_storage): +@async_test +async def test_delete_values(memory_storage): assert memory_storage._storage == {} - memory_storage.set("key1", "value1") - memory_storage.set("key2", "value2") - assert memory_storage._storage == {"key1": ("value1", None), "key2": ("value2", None)} - memory_storage.delete("key2") + await memory_storage.set("key1", "value1") + await memory_storage.set("key2", "value2") + assert memory_storage._storage == { + "key1": ("value1", None), + "key2": ("value2", None), + } + await memory_storage.delete("key2") assert memory_storage._storage == {"key1": ("value1", None)} -def test_expire_values(memory_storage, mocker): +@async_test +async def test_expire_values(memory_storage, mocker): assert memory_storage._storage == {} - mocked_dt = mocker.patch('machine.storage.backends.memory.datetime', autospec=True) + mocked_dt = mocker.patch("machine.storage.backends.memory.datetime", autospec=True) mocked_dt.utcnow.return_value = datetime(2017, 1, 1, 12, 0, 0, 0) - memory_storage.set("key1", "value1", expires=15) - assert memory_storage._storage == {"key1": ("value1", datetime(2017, 1, 1, 12, 0, 15, 0))} - assert memory_storage.get("key1") == "value1" + await memory_storage.set("key1", "value1", expires=15) + assert memory_storage._storage == { + "key1": ("value1", datetime(2017, 1, 1, 12, 0, 15, 0)) + } + assert (await memory_storage.get("key1")) == "value1" mocked_dt.utcnow.return_value = datetime(2017, 1, 1, 12, 0, 20, 0) - assert memory_storage.get("key1") is None + assert (await memory_storage.get("key1")) is None -def test_inclusion(memory_storage): +@async_test +async def test_inclusion(memory_storage): assert memory_storage._storage == {} 
- memory_storage.set("key1", "value1") - assert memory_storage.has("key1") == True - memory_storage.delete("key1") - assert memory_storage.has("key1") == False + await memory_storage.set("key1", "value1") + assert (await memory_storage.has("key1")) == True + await memory_storage.delete("key1") + assert (await memory_storage.has("key1")) == False diff --git a/tests/test_plugin_registration.py b/tests/test_plugin_registration.py index d442b264..892143ef 100644 --- a/tests/test_plugin_registration.py +++ b/tests/test_plugin_registration.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import re import pytest @@ -6,24 +7,24 @@ from machine.utils.collections import CaseInsensitiveDict -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def settings(): settings = CaseInsensitiveDict() - settings['PLUGINS'] = ['tests.fake_plugins'] - settings['SLACK_API_TOKEN'] = 'xoxo-abc123' - settings['STORAGE_BACKEND'] = 'machine.storage.backends.memory.MemoryStorage' + settings["PLUGINS"] = ["tests.fake_plugins"] + settings["SLACK_API_TOKEN"] = "xoxo-abc123" + settings["STORAGE_BACKEND"] = "machine.storage.backends.memory.MemoryStorage" return settings -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def settings_with_required(settings): - settings['setting_1'] = 'foo' + settings["setting_1"] = "foo" return settings -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def required_settings_class(): - @required_settings(['setting_1', 'setting_2']) + @required_settings(["setting_1", "setting_2"]) class C: pass @@ -35,58 +36,68 @@ def test_load_and_register_plugins(settings): actions = machine._plugin_actions # Test general structure of _plugin_actions - assert set(actions.keys()) == {'process', 'listen_to', 'respond_to', 'catch_all'} + assert set(actions.keys()) == {"process", "listen_to", "respond_to"} # Test registration of process actions - assert 'some_event' in actions['process'] - assert 'tests.fake_plugins:FakePlugin.process_function' in 
actions['process']['some_event'] - assert 'class' in actions['process']['some_event'][ - 'tests.fake_plugins:FakePlugin.process_function'] - assert 'function' in actions['process']['some_event'][ - 'tests.fake_plugins:FakePlugin.process_function'] + assert "some_event" in actions["process"] + assert ( + "tests.fake_plugins:FakePlugin.process_function" + in actions["process"]["some_event"] + ) + assert ( + "class" + in actions["process"]["some_event"][ + "tests.fake_plugins:FakePlugin.process_function" + ] + ) + assert ( + "function" + in actions["process"]["some_event"][ + "tests.fake_plugins:FakePlugin.process_function" + ] + ) # Test registration of respond_to actions - respond_to_key = 'tests.fake_plugins:FakePlugin.respond_function-hello' - assert respond_to_key in actions['respond_to'] - assert 'class' in actions['respond_to'][respond_to_key] - assert 'function' in actions['respond_to'][respond_to_key] - assert 'regex' in actions['respond_to'][respond_to_key] - assert actions['respond_to'][respond_to_key]['regex'] == re.compile('hello', re.IGNORECASE) + respond_to_key = "tests.fake_plugins:FakePlugin.respond_function-hello" + assert respond_to_key in actions["respond_to"] + assert "class" in actions["respond_to"][respond_to_key] + assert "function" in actions["respond_to"][respond_to_key] + assert "regex" in actions["respond_to"][respond_to_key] + assert actions["respond_to"][respond_to_key]["regex"] == re.compile( + "hello", re.IGNORECASE + ) # Test registration of listen_to actions - listen_to_key = 'tests.fake_plugins:FakePlugin.listen_function-hi' - assert listen_to_key in actions['listen_to'] - assert 'class' in actions['listen_to'][listen_to_key] - assert 'function' in actions['listen_to'][listen_to_key] - assert 'regex' in actions['listen_to'][listen_to_key] - assert actions['listen_to'][listen_to_key]['regex'] == re.compile('hi', re.IGNORECASE) - - # Test registration of catch_all actions - assert 'tests.fake_plugins:FakePlugin2' in 
actions['catch_all'] - assert 'tests.fake_plugins:FakePlugin' not in actions['catch_all'] - assert 'class' in actions['catch_all']['tests.fake_plugins:FakePlugin2'] - assert 'function' in actions['catch_all']['tests.fake_plugins:FakePlugin2'] + listen_to_key = "tests.fake_plugins:FakePlugin.listen_function-hi" + assert listen_to_key in actions["listen_to"] + assert "class" in actions["listen_to"][listen_to_key] + assert "function" in actions["listen_to"][listen_to_key] + assert "regex" in actions["listen_to"][listen_to_key] + assert actions["listen_to"][listen_to_key]["regex"] == re.compile( + "hi", re.IGNORECASE + ) def test_plugin_storage_fq_plugin_name(settings): machine = Machine(settings=settings) actions = machine._plugin_actions - plugin1_cls = actions['respond_to']['tests.fake_plugins:FakePlugin.respond_function-hello'][ - 'class'] - plugin2_cls = actions['catch_all']['tests.fake_plugins:FakePlugin2']['class'] - assert plugin1_cls.storage._fq_plugin_name == 'tests.fake_plugins:FakePlugin' - assert plugin2_cls.storage._fq_plugin_name == 'tests.fake_plugins:FakePlugin2' + plugin1_cls = actions["respond_to"][ + "tests.fake_plugins:FakePlugin.respond_function-hello" + ]["class"] + assert plugin1_cls.storage._fq_plugin_name == "tests.fake_plugins:FakePlugin" def test_plugin_init(settings): machine = Machine(settings=settings) actions = machine._plugin_actions - plugin_cls = actions['catch_all']['tests.fake_plugins:FakePlugin2']['class'] + plugin_cls = actions["respond_to"][ + "tests.fake_plugins:FakePlugin.respond_function-hello" + ]["class"] assert plugin_cls.x == 42 def test_required_settings(settings_with_required, required_settings_class): machine = Machine(settings=settings_with_required) missing = machine._check_missing_settings(required_settings_class) - assert 'SETTING_1' not in missing - assert 'SETTING_2' in missing + assert "SETTING_1" not in missing + assert "SETTING_2" in missing diff --git a/tests/test_plugin_storage.py 
b/tests/test_plugin_storage.py index 0739b38d..cf0b23c5 100644 --- a/tests/test_plugin_storage.py +++ b/tests/test_plugin_storage.py @@ -1,54 +1,62 @@ +# -*- coding: utf-8 -*- import pytest from machine.storage import PluginStorage from machine.storage.backends.memory import MemoryStorage +from tests.helpers import async_test + @pytest.fixture def storage_backend(mocker): storage = MemoryStorage({}) - backend_get_instance = mocker.patch('machine.storage.Storage.get_instance') + backend_get_instance = mocker.patch("machine.storage.Storage.get_instance") backend_get_instance.return_value = storage return storage @pytest.fixture def plugin_storage(storage_backend): - storage_instance = PluginStorage('tests.fake_plugin.FakePlugin') + storage_instance = PluginStorage("tests.fake_plugin.FakePlugin") return storage_instance -def test_namespacing(plugin_storage, storage_backend): - plugin_storage.set('key1', 'value1') - expected_key = 'tests.fake_plugin.FakePlugin:key1' +@async_test +async def test_namespacing(plugin_storage, storage_backend): + await plugin_storage.set("key1", "value1") + expected_key = "tests.fake_plugin.FakePlugin:key1" assert expected_key in storage_backend._storage -def test_inclusion(plugin_storage, storage_backend): - plugin_storage.set('key1', 'value1') - assert plugin_storage.has('key1') == True - assert 'key1' in plugin_storage +@async_test +async def test_inclusion(plugin_storage, storage_backend): + await plugin_storage.set("key1", "value1") + assert (await plugin_storage.has("key1")) == True + assert await plugin_storage.has("key1") -def test_retrieve(plugin_storage): - plugin_storage.set('key1', 'value1') - retrieved = plugin_storage.get('key1') - assert retrieved == 'value1' +@async_test +async def test_retrieve(plugin_storage): + await plugin_storage.set("key1", "value1") + retrieved = await plugin_storage.get("key1") + assert retrieved == "value1" -def test_shared(plugin_storage, storage_backend): - plugin_storage.set('key1', 'value1', 
shared=True) - assert plugin_storage.has('key1') == False - assert plugin_storage.has('key1', shared=True) == True - expected_key = 'key1' +@async_test +async def test_shared(plugin_storage, storage_backend): + await plugin_storage.set("key1", "value1", shared=True) + assert (await plugin_storage.has("key1")) == False + assert (await plugin_storage.has("key1", shared=True)) == True + expected_key = "key1" assert expected_key in storage_backend._storage -def test_delete(plugin_storage, storage_backend): - plugin_storage.set('key1', 'value1') - assert plugin_storage.has('key1') == True - expected_key = 'tests.fake_plugin.FakePlugin:key1' +@async_test +async def test_delete(plugin_storage, storage_backend): + await plugin_storage.set("key1", "value1") + assert (await plugin_storage.has("key1")) == True + expected_key = "tests.fake_plugin.FakePlugin:key1" assert expected_key in storage_backend._storage - plugin_storage.delete('key1') - assert plugin_storage.has('key1') == False + await plugin_storage.delete("key1") + assert (await plugin_storage.has("key1")) == False assert expected_key not in storage_backend._storage diff --git a/tests/test_redis_storage.py b/tests/test_redis_storage.py index 85c74d8f..8530c48b 100644 --- a/tests/test_redis_storage.py +++ b/tests/test_redis_storage.py @@ -1,47 +1,69 @@ -from unittest.mock import MagicMock +# -*- coding: utf-8 -*- +from unittest import mock + +import aioredis import pytest -from redis import StrictRedis from machine.storage.backends.redis import RedisStorage +from tests.helpers import async_test, make_coroutine_mock +from tests.helpers.expect import ExpectMock, expect + -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def redis_client(): - return MagicMock(spec=StrictRedis) + return ExpectMock(spec=aioredis.Redis) @pytest.fixture -def redis_storage(mocker, redis_client): - mocker.patch('machine.storage.backends.redis.StrictRedis', autospec=True) - settings = {'REDIS_URL': 'redis://nohost:1234'} +def 
redis_storage(expect, redis_client): + create_redis_pool = mock.patch( + "machine.storage.backends.redis.aioredis.create_redis_pool" + ) + create_redis_pool.return_value = redis_client + settings = {"REDIS_URL": "redis://nohost:1234"} storage = RedisStorage(settings) storage._redis = redis_client return storage -def test_set(redis_storage, redis_client): - redis_storage.set('key1', 'value1') - redis_client.set.assert_called_with('SM:key1', 'value1', None) - redis_storage.set('key2', 'value2', 42) - redis_client.set.assert_called_with('SM:key2', 'value2', 42) +@async_test +async def test_set(redis_storage, redis_client): + redis_client.set.expect("SM:key1", "value1", None).returns( + make_coroutine_mock(None) + ) + redis_client.set.expect("SM:key2", "value2", 42).returns(make_coroutine_mock(None)) + + await redis_storage.set("key1", "value1") + await redis_storage.set("key2", "value2", 42) + + +@async_test +async def test_get(redis_storage, redis_client): + redis_client.get.expect("SM:key1").returns(make_coroutine_mock(None)) + + await redis_storage.get("key1") + +@async_test +async def test_has(redis_storage, redis_client): + redis_client.exists.expect("SM:key1").returns(make_coroutine_mock(None)) -def test_get(redis_storage, redis_client): - redis_storage.get('key1') - redis_client.get.assert_called_with('SM:key1') + await redis_storage.has("key1") -def test_has(redis_storage, redis_client): - redis_storage.has('key1') - redis_client.exists.assert_called_with('SM:key1') +@async_test +async def test_delete(redis_storage, redis_client): + redis_client.delete.expect("SM:key1").returns(make_coroutine_mock(None)) + await redis_storage.delete("key1") -def test_delete(redis_storage, redis_client): - redis_storage.delete('key1') - redis_client.delete.assert_called_with('SM:key1') +@async_test +async def test_size(redis_storage, redis_client): + redis_client.info.expect("memory").returns( + make_coroutine_mock({"used_memory": "haha all of it"}) + ) -def 
test_size(redis_storage, redis_client): - redis_storage.size() - redis_client.info.assert_called_with('memory') + await redis_storage.size() diff --git a/tox.ini b/tox.ini index dc57c9ff..7970975b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,17 @@ [tox] -envlist = py36,py37,flake8 +envlist = full,flake8 [testenv] commands = pytest tests deps =-r{toxinidir}/requirements-dev.txt [testenv:py37] -passenv = CI TRAVIS TRAVIS_* -commands = pytest --cov-config .coveragerc --verbose --cov-report term-missing --cov-report xml --cov=machine tests +passenv = CI CIRCLE_* +commands = pytest --cov-config=.coveragerc --verbose --cov-report=term-missing --cov-report=html:./test-reports/coverage/ --junitxml=./test-reports/junit.xml tests + +[testenv:py38] +passenv = CI CIRCLE_* +commands = pytest --cov-config=.coveragerc --verbose --cov-report=term-missing --cov-report=html:./test-reports/coverage/ --junitxml=./test-reports/junit.xml tests [testenv:flake8] deps = flake8