diff --git a/zulip/integrations/discord/README.md b/zulip/integrations/discord/README.md new file mode 100644 index 000000000..da9e0c70d --- /dev/null +++ b/zulip/integrations/discord/README.md @@ -0,0 +1,99 @@ +# Discord<->Zulip bridge + +This bridge supports basic text mirroring between Discord guilds and Zulip +streams (within a single Zulip realm). On both the Discord and Zulip +sides, it can either send every message as a single bot with limited +permissions or, given special permissions (webhooks and can_forge_sender +respectively), mirror messages more naturally under each original sender's name. + +This design more naturally fits a personal Zulip realm supporting many Discord +guilds than a single organization that wants to run on both Zulip and Discord. +For that, you might want each Discord channel to match a Zulip stream, and +Discord threads to match Zulip topics. Supporting this mode of operation as +well would be a good future enhancement. + +Existing Discord threads are mirrored as Zulip topics named `channel >> thread`. There is currently no special support for creating threads, media, embeds, reactions, +edits, system messages (joins/leaves), etc. -- all of these would be good +future additions. The bot also does not create channels automatically for new +Zulip topics, which might be a good (optional) addition. + +Configuration lives in a single `bridge.ini` file. A template can be created by running: +``` +bridge.py --write-sample-config=bridge.ini --from-zuliprc=zuliprc +``` + +(The `zuliprc` is optional; see the Zulip section for details on getting it.) + +Configuration consists of a `zulip` section, a `discord` section, and a `streams` section listing the streams and guilds to associate. + +With the exception of the stream<->guild setup, all configuration is global. Some features that might make sense to make configurable by guild: +- Whether to use webhooks to forge Discord senders +- How to create a topic name from a Discord thread + +This is developed on Python 3.8, which is also the minimum supported version (`bridge.py` imports `typing.Literal`, added in 3.8; `asyncio.run` needs 3.7). 
+ +## Zulip + +For Zulip setup, create a bot user (gear -> personal settings -> Bots -> Add a +new bot) and download a `zuliprc`. Then, you can run `bridge.py +--write-sample-config=bridge.ini --from-zuliprc=zuliprc` to create a sample config based +on that `zuliprc`. + +Forging senders is optional -- it will make messages forwarded to Zulip look +more natural, but requires special permissions. To use it: +- Set the `forge_sender` key of the `zulip` section to `true` in your + `bridge.ini` file, and run `./manage.py change_user_role -r discord-mirror + discord-bot@discord-mirror.zulip.org can_forge_sender` (adjust realm and bot + name appropriately) on your Zulip server to grant permissions. +- Add a `RealmDomain` of "users.discord.com": under gear -> Manage organization + -> Organization permissions -> Restrict email domains of new users?, choose + to restrict to a list of domains, and add "users.discord.com". You can then + freely switch back to "Don't allow disposable email addresses" or another + value for that setting, if you wish. +- Note that for technical reasons, forging senders involves pretending to be a + Jabber mirror; as a result, sender names will include " (XMPP)" after them. + (If you prefer to show " (irc)", you can change `jabber_mirror` to + `irc_mirror` in `bridge.py`. It should also be a simple server patch to + support another client name with no suffix or a different suffix.) + +## Discord + +Create a Discord integration (https://discord.com/developers/applications/). + +On the "Bot" tab, add a bot. Copy the token into the `token` key of the +`discord` section of the config file. (Don't confuse the token with the +application ID on the "general information" tab, or the client ID or secret on +the OAuth2 tab -- you should be on the "Bot" tab.) Enable "message content +intent" -- this is a privileged intent, so if you want to use the bot with more +than 100 guilds, you'll need to get your bot reviewed. 
For typical uses with +only a handful of guilds, though, no review is needed. + +On the OAuth2 URL generator page, give it the `bot` scope with the following +permissions: +- Manage webhooks (optional; to set the sender name) +- Read messages / view channels +- Send messages +- Create public threads +- Create private threads +- Send messages in threads +- Read message history + +This will produce a URL with these permissions and some client ID, along the lines of: +https://discord.com/api/oauth2/authorize?client_id=914346072418185256&permissions=378494061568&scope=bot + +Following this link will allow adding the integration to a Discord server you have Manage Server permissions on. + +The manage webhooks permission is optional, but makes messages forwarded to +Discord look more native; if you wish to disable it, set the `use_webhook` key +in the `discord` section to false. + +## Streams + +The `streams` section has no specific keys. Instead, each key is a stream name, +and the corresponding value is a Discord guild ID. To find a guild ID, open +the Discord webapp, and navigate to the guild (server). You should see a URL +like `https://discord.com/channels/<guild_id>/<channel_id>`, where `<guild_id>` is the +guild ID to use. + +Make sure to subscribe your Zulip bot user to each relevant stream, and add +your Discord bot to each relevant guild. 
#!/usr/bin/env python

"""Discord/Zulip bridge.

Mirrors messages between the streams of one Zulip realm and Discord guilds.
Each configured Zulip stream is paired with one Discord guild.  A Discord text
channel maps to the Zulip topic of the same name, and a Discord thread maps to
the topic ``<channel> >> <thread>`` (see ``CHANNEL_THREAD_SEP``).
"""

import argparse
import asyncio
import configparser
import logging
import os
import re
import sys
import traceback
from typing import Any, Dict, Literal, Optional, Tuple, Union, cast

import discord

import zulip
import zulip.asynch

LOG_FORMAT = "%(asctime)s %(levelname)7s - %(name)20s - %(message)s"
logger = logging.getLogger(__name__)

# Separator between the Discord channel name and thread name in a Zulip topic.
CHANNEL_THREAD_SEP = " >> "

# (section, key) pairs that must be present for the bridge to start.
_REQUIRED_CONFIG_KEYS = (
    ("zulip", "email"),
    ("zulip", "api_key"),
    ("zulip", "site"),
    ("discord", "token"),
)


class Bridge_ConfigException(Exception):
    """Raised when the bridge configuration cannot be read, written or used."""


class Bridge_ZulipFatalException(Exception):
    """Raised for Zulip-side errors that should stop the bridge for good."""


# Zulip uses Django to validate emails, which provides an EmailValidator class
# https://github.com/django/django/blob/main/django/core/validators.py#L158
# This regex is copied from there
EMAIL_USER_RE = re.compile(
    r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"  # dot-atom
    r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)',  # quoted-string
    re.IGNORECASE,
)


class GuildConfig:
    """Per-guild state: the paired Zulip stream and a cache of webhooks."""

    def __init__(self, stream: str) -> None:
        self.stream = stream
        # Discord channel ID -> webhook this bridge posts through.
        self.webhooks: Dict[int, discord.Webhook] = {}


class Bridge:
    """Zulip/Discord bridge."""

    # Result of resolving a Zulip message to a Discord destination:
    #   channel -- the text channel, or None when the message cannot be routed
    #   thread  -- the thread object, False when the topic names no thread, or
    #              None when it names a thread that does not exist
    ChanThread = Tuple[
        Optional[discord.TextChannel],
        Union[discord.Thread, Literal[False], None],
    ]

    def __init__(
        self,
        discord_config: Dict[str, Union[str, bool]],
        zulip_client: zulip.asynch.AsyncClient,
        zulip_config: Dict[str, Union[str, bool]],
        streams: Dict[str, int],
    ) -> None:
        """Store configuration; no network activity happens here.

        Args:
            discord_config: the ``discord`` config section (``token``, ``use_webhook``).
            zulip_client: async Zulip client used to send and receive messages.
            zulip_config: the ``zulip`` config section.
            streams: mapping of Zulip stream name to Discord guild ID.
        """
        self._discord_config = discord_config
        # The Discord client doesn't have any useful configuration (the token
        # is supplied to the login() method), but does embed the event loop.
        # The intent is that Bridge can be created before starting the event
        # loop, so while the Zulip client (which includes config) is passed in,
        # the Discord client is created when needed.
        self._discord_client: Optional[discord.Client] = None
        self._zulip_client = zulip_client
        # configparser lower-cases option names, so stream names coming from
        # the config file are lower-case while Zulip reports the real case.
        # Key the lookup table by lower-case name and lower-case on lookup.
        self._streams = {name.lower(): guild_id for name, guild_id in streams.items()}
        self._guilds = self._build_guild_config(streams)
        self._discord_domain = zulip_config.get("discord_domain", "users.discord.com")
        self._default_stream = zulip_config.get("default_stream")
        self._default_topic = zulip_config.get("default_topic", "(no topic)")
        self._forge_zulip = zulip_config.get("forge_sender", False)

    @staticmethod
    def _build_guild_config(streams: Dict[str, int]) -> Dict[int, GuildConfig]:
        """Invert ``streams`` into a guild ID -> GuildConfig mapping."""
        return {guild_id: GuildConfig(stream) for stream, guild_id in streams.items()}

    def get_zulip_stream_and_topic_from_discord(
        self, message: discord.Message
    ) -> Tuple[Optional[str], str]:
        """Return the Zulip (stream, topic) for a Discord message.

        Returns ``(None, "")`` after logging a warning when the message cannot
        be routed: it is not from a guild, the guild is not configured, or the
        channel is neither a text channel nor a thread with a parent.
        """
        if not message.guild:
            logger.warning("Message not from a guild: %s", message)
            return None, ""
        guild_config = self._guilds.get(message.guild.id)
        if guild_config is None:
            # The bot may be a member of guilds that are not in the config.
            logger.warning("Message from unconfigured guild %s: %s", message.guild, message)
            return None, ""

        channel = message.channel
        if isinstance(channel, discord.TextChannel):
            topic = channel.name
        elif isinstance(channel, discord.Thread):
            parent = channel.parent
            if not parent:
                logger.warning("Thread message has no parent: %s", message)
                return None, ""
            topic = parent.name + CHANNEL_THREAD_SEP + channel.name
        else:
            logger.warning("Message not text or thread: %s", message)
            return None, ""
        return guild_config.stream, topic

    def get_zulip_sender_from_discord(self, sender: discord.abc.User) -> str:
        """Build the forged Zulip sender email for a Discord user.

        The display name is used as the local part when it is a valid email
        user part; otherwise the account name is used instead.
        """
        if EMAIL_USER_RE.match(sender.display_name):
            user_part = sender.display_name
        else:
            logger.debug(
                "Sender %s display_name='%s' invalid user part, using name='%s'",
                sender,
                sender.display_name,
                sender.name,
            )
            user_part = sender.name
        return user_part + "@" + self._discord_domain

    def _is_own_webhook_message(self, message: discord.Message) -> bool:
        """Return True when ``message`` was posted through one of our cached webhooks."""
        if not message.webhook_id or not message.guild:
            return False
        guild_config = self._guilds.get(message.guild.id)
        if guild_config is None:
            return False
        webhook = guild_config.webhooks.get(message.channel.id)
        return bool(webhook and webhook.id == message.webhook_id)

    async def on_discord_message(self, message: discord.Message) -> None:
        """Forward one Discord message to Zulip, skipping our own messages."""
        logger.info('Discord message from %s: "%s" %s', message.author, message.content, message)

        # Avoid mirroring our own messages
        assert self._discord_client
        if message.author == self._discord_client.user:
            logger.info("Ignoring message from self: %s", message)
            return
        if self._is_own_webhook_message(message):
            # There is a small race here: if we restarted after sending a
            # zulip->Discord message, we might process it post-restart
            # before adding the webhook to the webhook cache. That's
            # probably an acceptable risk
            logger.info("Message sent through our webhook, ignoring: %s", message)
            return

        # Send it to Zulip
        stream, topic = self.get_zulip_stream_and_topic_from_discord(message)
        if not stream:
            # Above already did logging
            return
        if self._forge_zulip:
            # TODO: consider including embeds, etc.
            out_msg = dict(
                forged="yes",
                sender=self.get_zulip_sender_from_discord(cast(discord.abc.User, message.author)),
                type="stream",
                subject=topic,
                to=stream,
                content=message.content,
            )
        else:
            # Without forging, the author can only be shown inside the body.
            out_msg = dict(
                type="stream",
                subject=topic,
                to=stream,
                content="***%s***: %s" % (message.author, message.content),
            )
        logger.info("About to send: %s", out_msg)
        response = await self._zulip_client.send_message(out_msg)
        logger.info("Zulip send message response: %s", response)

    # TODO:
    # https://docs.pycord.dev/en/master/api.html#discord.on_reaction_add

    def get_channel_from_zulip(self, message: Dict[str, Any]) -> ChanThread:
        """Resolve a Zulip stream message to a Discord (channel, thread) pair.

        Returns ``(None, None)`` after logging a warning when the stream is not
        configured, the guild is not visible to the bot, or the topic does not
        name a text channel.  See ``ChanThread`` for the thread element.
        """
        stream = message["display_recipient"]
        topic = message["subject"]
        guild_id = self._streams.get(stream.lower())
        if guild_id is None:
            logger.warning("Couldn't find guild for stream %s: message %s", stream, message)
            return None, None
        assert self._discord_client
        guild = self._discord_client.get_guild(guild_id)
        if not guild:
            logger.warning(
                "Guild ID %s not found, for stream %s: message %s, guilds=%s",
                guild_id,
                stream,
                message,
                self._discord_client.guilds,
            )
            return None, None

        channel_name, _sep, thread_name = topic.partition(CHANNEL_THREAD_SEP)
        channel = discord.utils.get(guild.channels, name=channel_name)
        if not isinstance(channel, discord.TextChannel):
            logger.warning(
                "Channel %s not found as text channel in guild %s: channel %s, message %s",
                topic,
                guild,
                channel,
                message,
            )
            return None, None

        thread: Union[Optional[discord.Thread], Literal[False]] = False
        if thread_name:
            thread = discord.utils.get(channel.threads, name=thread_name)
            if not thread:
                # The caller then falls back to posting in the channel itself.
                logger.warning(
                    "Thread %s (chan=%s, thread=%s) not found in guild %s: message %s",
                    topic,
                    channel_name,
                    thread_name,
                    guild,
                    message,
                )
        return channel, thread

    async def get_webhook_for_discord_channel(
        self, channel: discord.TextChannel
    ) -> Optional[discord.Webhook]:
        """Return the ``zulip_mirror`` webhook for ``channel``, creating it if needed.

        Returns None when webhooks are disabled by the ``use_webhook`` setting.
        """
        if not self._discord_config["use_webhook"]:
            # Discord webhooks are disabled
            return None

        # Relevant docs:
        # https://docs.pycord.dev/en/master/api.html#discord.TextChannel.create_webhook
        # https://docs.pycord.dev/en/master/api.html#discord.TextChannel.webhooks
        # https://discordpy.readthedocs.io/en/latest/api.html#discord.utils.get

        guild_config = self._guilds[channel.guild.id]

        # Check the cache
        webhook = guild_config.webhooks.get(channel.id)
        if webhook:
            return webhook

        # See if we created one previously
        assert self._discord_client
        webhooks = await channel.webhooks()
        webhook = discord.utils.get(webhooks, name="zulip_mirror", user=self._discord_client.user)
        logger.info("Checked channel %s for existing webhooks, got %s", channel, webhook)
        if not webhook:
            # Create a new one
            reason = "custom sender for zulip mirror"
            webhook = await channel.create_webhook(name="zulip_mirror", reason=reason)
        # Cache and return
        guild_config.webhooks[channel.id] = webhook
        return webhook

    async def on_zulip_message(self, message: Dict[str, Any]) -> None:
        """Forward one Zulip stream message to Discord, skipping our own messages."""
        logger.info("Zulip message: %s", message)
        if message["type"] != "stream":
            # ignore personals
            return
        sender = cast(str, message["sender_full_name"])
        content = cast(str, message["content"])

        # Check if this was a message we might have sent
        # Note that there is some server side filtering for clients named
        # "mirror", but that might change, so do it ourselves too. See also
        # https://chat.zulip.org/#narrow/stream/127-integrations/topic/suppressed.20own.20messages/near/1287622
        if message["client"] == self._zulip_client.sync_client.client_name:
            logger.info("Ignoring message %s from mirroring client %s", message, message["client"])
            return

        # Send to Discord
        channel, thread = self.get_channel_from_zulip(message)
        if not channel:
            # get_channel_from_zulip will have logged a warning
            return
        # TODO: consider including embeds, etc.
        try:
            webhook = await self.get_webhook_for_discord_channel(channel)
            if webhook:
                thread_ = cast(discord.abc.Snowflake, thread or discord.utils.MISSING)
                await webhook.send(username=sender, content=content, thread=thread_)
            else:
                await (thread or channel).send(content="%s: %s" % (sender, content))
        except discord.HTTPException:
            # One undeliverable message (too long, missing permission, ...)
            # should not take the whole bridge down.
            logger.exception("Failed to forward Zulip message to Discord: %s", message)

    async def run_tasks(self) -> None:
        """Log in to Discord and mirror in both directions until one side stops.

        Both clients are closed on the way out so that the caller can build a
        fresh Bridge and retry.
        """
        logger.info("Starting tasks...")
        logger.info("Connecting to discord...")
        assert not self._discord_client
        intents = discord.Intents(messages=True, guilds=True)
        # The README asks for the privileged "message content" intent to be
        # enabled for the bot; request it when this library version knows it.
        # NOTE(review): assumes the intent is enabled in the developer portal.
        if hasattr(discord.Intents, "message_content"):
            intents.message_content = True
        client = discord.Client(intents=intents)
        self._discord_client = client

        tasks = []
        try:
            await client.login(cast(str, self._discord_config["token"]))

            logger.info("Creating message handler on Zulip client")
            zulip_await = self._zulip_client.call_on_each_message(self.on_zulip_message)

            logger.info("Creating message handler on Discord client")

            @client.event
            async def on_message(message: discord.Message) -> None:
                """Discord message-handling callback"""
                await self.on_discord_message(message)

            discord_await = client.connect()

            tasks = [asyncio.ensure_future(zulip_await), asyncio.ensure_future(discord_await)]
            logger.info("awaitables=%s", tasks)
            await asyncio.gather(*tasks)
            logger.info("run_tasks finished...")
        finally:
            # Stop whichever side is still running before closing connections.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            await client.close()
            session = self._zulip_client.session
            if session is not None:
                await session.close()
            self._discord_client = None


def generate_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the bridge script."""
    description = """
    Script to bridge between Zulip and Discord.
    """

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "-c", "--config", required=False, help="Path to the config file for the bridge."
    )
    parser.add_argument(
        "--write-sample-config",
        metavar="PATH",
        dest="sample_config",
        help="Generate a configuration template at the specified location.",
    )
    parser.add_argument(
        "--from-zuliprc",
        metavar="ZULIPRC",
        dest="zuliprc",
        help="Optional path to zuliprc file for bot, when using --write-sample-config",
    )
    return parser


def read_configuration(config_file: str) -> Dict[str, Dict[str, Union[str, bool]]]:
    """Read and validate the bridge config file.

    Returns:
        A dict with ``discord``, ``zulip`` and ``streams`` sections.  The
        ``zulip.forge_sender`` and ``discord.use_webhook`` values are booleans;
        every other value is the raw string from the file.  Option names
        (including stream names) are lower-cased by configparser.

    Raises:
        Bridge_ConfigException: the file is missing or unparsable, the set of
            sections is wrong, a boolean is malformed, or a required key is
            absent.
    """
    config = configparser.ConfigParser()

    try:
        files_read = config.read(config_file)
    except configparser.Error as exception:
        raise Bridge_ConfigException(str(exception)) from exception
    if not files_read:
        # ConfigParser.read() silently skips files it cannot open.
        raise Bridge_ConfigException("Could not read config file '{}'.".format(config_file))

    if set(config.sections()) != {"discord", "zulip", "streams"}:
        raise Bridge_ConfigException(
            "Please ensure the configuration has discord, zulip, and streams sections."
        )

    parsed: Dict[str, Dict[str, Union[str, bool]]] = {
        section: dict(config[section]) for section in config.sections()
    }
    try:
        parsed["zulip"]["forge_sender"] = config.getboolean(
            "zulip", "forge_sender", fallback=False
        )
        parsed["discord"]["use_webhook"] = config.getboolean(
            "discord", "use_webhook", fallback=True
        )
    except ValueError as exception:
        raise Bridge_ConfigException(str(exception)) from exception

    missing = [
        "{}.{}".format(section, key)
        for section, key in _REQUIRED_CONFIG_KEYS
        if key not in parsed[section]
    ]
    if missing:
        raise Bridge_ConfigException(
            "Missing required configuration keys: {}".format(", ".join(missing))
        )
    return parsed


def write_sample_config(target_path: str, zuliprc: Optional[str]) -> None:
    """Write a template config to ``target_path``, never overwriting a file.

    When ``zuliprc`` is given, the Zulip email, site and API key are copied
    from its ``api`` section into the template.

    Raises:
        Bridge_ConfigException: the target exists, or the zuliprc is missing,
            unparsable, or lacks the expected ``api`` keys.
    """
    exists_message = "Path '{}' exists; not overwriting existing file.".format(target_path)
    if os.path.exists(target_path):
        raise Bridge_ConfigException(exists_message)

    sample_dict = dict(
        zulip=dict(
            email="discord-bot@chat.zulip.org",
            api_key="aPiKeY",
            site="https://chat.zulip.org",
            forge_sender="false",
        ),
        discord=dict(
            token="bot_token",
            use_webhook="true",
        ),
        streams={
            "test here": "guild ID",
        },
    )

    if zuliprc is not None:
        if not os.path.exists(zuliprc):
            raise Bridge_ConfigException("Zuliprc file '{}' does not exist.".format(zuliprc))

        zuliprc_config = configparser.ConfigParser()
        try:
            zuliprc_config.read(zuliprc)
        except configparser.Error as exception:
            raise Bridge_ConfigException(str(exception)) from exception

        try:
            api_section = zuliprc_config["api"]
            sample_dict["zulip"]["email"] = api_section["email"]
            sample_dict["zulip"]["site"] = api_section["site"]
            sample_dict["zulip"]["api_key"] = api_section["key"]
        except KeyError as exception:
            raise Bridge_ConfigException(
                "Zuliprc file '{}' is missing {} in its [api] section.".format(zuliprc, exception)
            ) from exception

    sample = configparser.ConfigParser()
    sample.read_dict(sample_dict)
    try:
        # Mode "x" fails if the file appeared since the check above, so an
        # existing file is never clobbered.
        with open(target_path, "x") as target:
            sample.write(target)
    except FileExistsError as exception:
        raise Bridge_ConfigException(exists_message) from exception


def main() -> None:
    """Command-line entry point: write a sample config, or run the bridge."""
    # signal.signal(signal.SIGINT, die)
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

    parser = generate_parser()
    options = parser.parse_args()

    if options.sample_config:
        try:
            write_sample_config(options.sample_config, options.zuliprc)
        except Bridge_ConfigException as exception:
            print("Could not write sample config: {}".format(exception))
            sys.exit(1)
        if options.zuliprc is None:
            print("Wrote sample configuration to '{}'".format(options.sample_config))
        else:
            print(
                "Wrote sample configuration to '{}' using zuliprc file '{}'".format(
                    options.sample_config, options.zuliprc
                )
            )
        sys.exit(0)
    elif not options.config:
        print("Options required: -c or --config to run, OR --write-sample-config.")
        parser.print_usage()
        sys.exit(1)

    try:
        config = read_configuration(options.config)
    except Bridge_ConfigException as exception:
        print("Could not parse config file: {}".format(exception))
        sys.exit(1)

    # Get config for each client
    discord_config = config["discord"]
    zulip_config = config["zulip"]
    try:
        stream_config = {
            stream: int(cast(str, guild)) for stream, guild in config["streams"].items()
        }
    except ValueError as exception:
        print("Could not parse config file: invalid guild ID in streams section: {}".format(exception))
        sys.exit(1)
    logger.info("zulip_config=%s", zulip_config)

    # Initiate clients
    backoff = zulip.asynch.RandomExponentialBackoff(timeout_success_equivalent=300)
    while backoff.keep_going():
        print("Starting mirroring bot")
        try:
            if zulip_config.get("forge_sender"):
                # An odd "security"(?) measure is that only certain clients
                # names can forge messages, even if they have the "super-admin"
                # permission
                client = "jabber_mirror"
            else:
                client = "discord_mirror"
            zulip_sync_client = zulip.Client(
                email=cast(str, zulip_config["email"]),
                api_key=cast(str, zulip_config["api_key"]),
                site=cast(str, zulip_config["site"]),
                client=client,
                verbose=True,
            )

            zulip_async_client = zulip.asynch.AsyncClient(zulip_sync_client)

            bridge = Bridge(discord_config, zulip_async_client, zulip_config, stream_config)

            logger.info("About to run_tasks")
            asyncio.run(bridge.run_tasks())
            logger.info("Finished run()")

            break

        except Bridge_ZulipFatalException as exception:
            sys.exit("Zulip bridge error: {}".format(exception))
        except zulip.ZulipError as exception:
            sys.exit("Zulip error: {}".format(exception))
        except Exception:
            traceback.print_exc()
            # RandomExponentialBackoff.fail() is a coroutine function: merely
            # calling it neither counts the failure nor sleeps, which would
            # turn this loop into a tight infinite retry.  Drive it to
            # completion (a plain synchronous fail() is handled too).
            pending = backoff.fail()
            if asyncio.iscoroutine(pending):
                asyncio.run(pending)


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# The patch this module was taken from carried two further hunks on the same
# source lines.  They are not part of bridge.py and are recorded here so that
# nothing is lost:
#   * zulip/integrations/discord/requirements.txt (new file, one line):
#         py-cord>=2.0.0b1
#   * zulip/zulip/__init__.py (hunk @@ -747,8 +747,8 @@): the comment above the
#     long-polling loop was reworded to "Once a request has received an
#     answer, pass it to the callback before making a new long-polling
#     request."
"""Asynchronous counterpart of the blocking ``zulip.Client``, built on aiohttp.

``AsyncClient`` wraps an already-configured ``zulip.Client`` and reuses its
credentials, base URL and user agent, but performs the HTTP requests with
aiohttp so that they can be awaited.
"""

import asyncio
import json
import logging
import os
import random
import sys
import traceback
import urllib.parse
from typing import (
    IO,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

import aiohttp

import zulip

logger = logging.getLogger(__name__)

API_VERSTRING = "v1/"


class RandomExponentialBackoff(zulip.CountingBackoff):
    """Counting backoff whose ``fail()`` sleeps for a randomized, growing delay."""

    async def fail(self) -> None:
        """Record a failure, then sleep before the next retry.

        This is a coroutine: it must be awaited (or run with ``asyncio.run``),
        otherwise neither the failure count nor the delay takes effect.
        """
        super().fail()
        # Exponential growth with ratio sqrt(2); compute random delay
        # between x and 2x where x is growing exponentially
        delay_scale = int(2 ** (self.number_of_retries / 2.0 - 1)) + 1
        delay = min(delay_scale + random.randint(1, delay_scale), self.delay_cap)
        logger.warning("Sleeping for %ss [max %s] before retrying.", delay, delay_scale * 2)
        await asyncio.sleep(delay)


class AsyncClient:
    """Awaitable subset of the Zulip API, configured from a ``zulip.Client``."""

    def __init__(self, client: zulip.Client):
        self.sync_client = client
        self.session: Optional[aiohttp.ClientSession] = None
        self.retry_on_errors = client.retry_on_errors
        self.verbose = client.verbose
        # Set once any request has reached the server; do_api_query() reads it
        # to tell "server never reachable" apart from a transient outage, so it
        # must exist before the first request.
        self.has_connected = False

    def ensure_session(self) -> None:
        """Create the aiohttp session on first use, or after it was closed."""
        if self.session is not None and not self.session.closed:
            return

        if self.sync_client.client_cert is not None:
            # TODO: Support client certificates and overriding TLS
            # verification (the sync client's client_cert / client_cert_key
            # settings are not applied to the aiohttp session yet).
            logger.warning("Client certificate is configured but not used by AsyncClient")

        self.session = aiohttp.ClientSession(
            auth=aiohttp.BasicAuth(self.sync_client.email, self.sync_client.api_key),
            headers={"User-agent": self.sync_client.get_user_agent()},
        )

    async def close(self) -> None:
        """Close the underlying aiohttp session, if one was created."""
        if self.session is not None:
            await self.session.close()
            self.session = None

    async def do_api_query(
        self,
        orig_request: Mapping[str, Any],
        url: str,
        method: str = "POST",
        longpolling: bool = False,
        files: Optional[List[IO[Any]]] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Send one API request and return the decoded JSON response.

        Args:
            orig_request: request parameters; non-string values are JSON-encoded.
            url: endpoint path relative to the client's base URL.
            method: HTTP method; parameters go in the query string for GET and
                in the body otherwise.
            longpolling: use a 90 second timeout and retry silently on timeout.
            files: open files to upload as multipart form fields.
            timeout: timeout in seconds for non-long-polling requests
                (default 15).

        Returns:
            The server's JSON object, or a dict whose ``result`` is
            ``connection-error``, ``unexpected-error`` or ``http-error`` when
            no JSON response could be obtained.

        Raises:
            zulip.UnrecoverableNetworkError: on a non-timeout SSL error, or
                when the server could not be reached and never has been.
        """
        if files is None:
            files = []

        if longpolling:
            # When long-polling, set timeout to 90 sec as a balance
            # between a low traffic rate and a still reasonable latency
            # time in case of a connection failure.
            request_timeout = 90.0
        else:
            # Otherwise, 15s should be plenty of time.
            request_timeout = 15.0 if not timeout else timeout

        request = {
            key: val if isinstance(val, str) else json.dumps(val)
            for key, val in orig_request.items()
        }
        req_files = [(f.name, f) for f in files]

        self.ensure_session()
        assert self.session is not None

        query_state = {
            "had_error_retry": False,
            "request": request,
            "failures": 0,
        }  # type: Dict[str, Any]

        async def error_retry(error_string: str) -> bool:
            """Sleep and return True if the request should be retried."""
            if not self.retry_on_errors or query_state["failures"] >= 10:
                return False
            if self.verbose:
                if not query_state["had_error_retry"]:
                    sys.stdout.write(
                        "zulip API(%s): connection error%s -- retrying."
                        % (
                            url.split(API_VERSTRING, 2)[0],
                            error_string,
                        )
                    )
                    query_state["had_error_retry"] = True
                else:
                    sys.stdout.write(".")
                sys.stdout.flush()
            query_state["request"]["dont_block"] = json.dumps(True)
            await asyncio.sleep(1)
            query_state["failures"] += 1
            return True

        def end_error_retry(succeeded: bool) -> None:
            """Finish the "retrying..." progress line, if one was started."""
            if query_state["had_error_retry"] and self.verbose:
                if succeeded:
                    print("Success!")
                else:
                    print("Failed!")

        while True:
            try:
                kwarg = "params" if method == "GET" else "data"
                kwargs = {kwarg: query_state["request"]}  # type: Dict[str, Any]

                if req_files:
                    # aiohttp has no requests-style ``files=`` argument;
                    # uploads must be sent as a multipart form.  The form is
                    # rebuilt per attempt because a FormData is single-use.
                    # NOTE: the file objects are not rewound between attempts.
                    form = aiohttp.FormData()
                    if method != "GET":
                        for key, val in query_state["request"].items():
                            form.add_field(key, val)
                    for name, fobj in req_files:
                        form.add_field(
                            str(name), fobj, filename=os.path.basename(str(name))
                        )
                    kwargs["data"] = form

                # Actually make the request!
                res = await self.session.request(
                    method,
                    urllib.parse.urljoin(self.sync_client.base_url, url),
                    timeout=aiohttp.ClientTimeout(total=request_timeout),
                    **kwargs,
                )

                self.has_connected = True

                # On 50x errors, try again after a short sleep
                if res.status // 100 == 5:
                    if await error_retry(f" (server {res.status})"):
                        # Hand the connection back before retrying.
                        res.release()
                        continue
                # Otherwise fall through and process the error normally
            except (
                asyncio.TimeoutError,
                aiohttp.ServerTimeoutError,
                aiohttp.ClientSSLError,
            ) as e:
                # ClientTimeout(total=...) expiring raises asyncio.TimeoutError;
                # timeouts may also show up as a ServerTimeoutError or a
                # ClientSSLError.  We want the later exception handlers to
                # deal with any non-timeout other SSLErrors
                if (
                    isinstance(e, aiohttp.ClientSSLError)
                    and str(e) != "The read operation timed out"
                ):
                    raise zulip.UnrecoverableNetworkError("SSL Error") from e
                if longpolling:
                    # When longpolling, we expect the timeout to fire,
                    # and the correct response is to just retry
                    continue
                end_error_retry(False)
                return {
                    "msg": f"Connection error:\n{traceback.format_exc()}",
                    "result": "connection-error",
                }
            except (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError):
                if not self.has_connected:
                    # If we have never successfully connected to the server, don't
                    # go into retry logic, because the most likely scenario here is
                    # that somebody just hasn't started their server, or they passed
                    # in an invalid site.
                    raise zulip.UnrecoverableNetworkError(
                        "cannot connect to server " + self.sync_client.base_url
                    )

                if await error_retry(""):
                    continue
                end_error_retry(False)
                return {
                    "msg": f"Connection error:\n{traceback.format_exc()}",
                    "result": "connection-error",
                }
            except Exception:
                # We'll split this out into more cases as we encounter new bugs.
                return {
                    "msg": f"Unexpected error:\n{traceback.format_exc()}",
                    "result": "unexpected-error",
                }

            status_code = -1
            try:
                async with res:
                    status_code = res.status
                    json_result = await res.json()
            except Exception:
                # Body was not valid JSON (or could not be read); reported
                # below as an http-error together with the status code.
                json_result = None

            if json_result is not None:
                end_error_retry(True)
                return json_result
            end_error_retry(False)
            return {
                "msg": "Unexpected error from the server",
                "result": "http-error",
                "status_code": status_code,
            }

    async def call_endpoint(
        self,
        url: Optional[str] = None,
        method: str = "POST",
        request: Optional[Dict[str, Any]] = None,
        longpolling: bool = False,
        files: Optional[List[IO[Any]]] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Call a versioned API endpoint, dropping parameters whose value is None."""
        marshalled_request = {
            k: v for k, v in (request or {}).items() if v is not None
        }
        versioned_url = API_VERSTRING + (url if url is not None else "")
        return await self.do_api_query(
            marshalled_request,
            versioned_url,
            method=method,
            longpolling=longpolling,
            files=files,
            timeout=timeout,
        )

    async def event_iter(
        self,
        event_types: Optional[List[str]] = None,
        narrow: Optional[List[List[str]]] = None,
        **kwargs: object,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield events from the server forever, re-registering when needed."""
        if narrow is None:
            narrow = []

        async def do_register() -> Tuple[str, int]:
            """Register an event queue, retrying every second until it works."""
            while True:
                if event_types is None:
                    res = await self.register(None, None, **kwargs)
                else:
                    res = await self.register(event_types, narrow, **kwargs)
                if "error" in res["result"]:
                    if self.verbose:
                        print("Server returned error:\n{}".format(res["msg"]))
                    await asyncio.sleep(1)
                else:
                    return (res["queue_id"], res["last_event_id"])

        queue_id = None
        # Make long-polling requests with `get_events`. Once a request
        # has received an answer, pass it to the callback before making
        # a new long-polling request.
        while True:
            if queue_id is None:
                (queue_id, last_event_id) = await do_register()

            res = await self.get_events(queue_id=queue_id, last_event_id=last_event_id)
            if "error" in res["result"]:
                if res["result"] == "http-error":
                    if self.verbose:
                        print("HTTP error fetching events -- probably a server restart")
                elif res["result"] == "connection-error":
                    if self.verbose:
                        print(
                            "Connection error fetching events -- probably server is temporarily down?"
                        )
                else:
                    if self.verbose:
                        print("Server returned error:\n{}".format(res["msg"]))
                    # Eventually, we'll only want the
                    # BAD_EVENT_QUEUE_ID check, but we check for the
                    # old string to support legacy Zulip servers. We
                    # should remove that legacy check in 2019.
                    if res.get("code") == "BAD_EVENT_QUEUE_ID" or res["msg"].startswith(
                        "Bad event queue id:"
                    ):
                        # Our event queue went away, probably because
                        # we were asleep or the server restarted
                        # abnormally. We may have missed some
                        # events while the network was down or
                        # something, but there's not really anything
                        # we can do about it other than resuming
                        # getting new ones.
                        #
                        # Reset queue_id to register a new event queue.
                        queue_id = None
                # Add a pause here to cover against potential bugs in this library
                # causing a DoS attack against a server when getting errors.
                # TODO: Make this back off exponentially.
                await asyncio.sleep(1)
                continue

            for event in res["events"]:
                last_event_id = max(last_event_id, int(event["id"]))
                yield event

    async def call_on_each_event(
        self,
        callback: Callable[[Dict[str, Any]], Awaitable[None]],
        event_types: Optional[List[str]] = None,
        narrow: Optional[List[List[str]]] = None,
        **kwargs: object,
    ) -> None:
        """Await ``callback(event)`` for every event; does not return normally."""
        async for event in self.event_iter(event_types, narrow, **kwargs):
            await callback(event)

    async def call_on_each_message(
        self, callback: Callable[[Dict[str, Any]], Awaitable[None]], **kwargs: object
    ) -> None:
        """Await ``callback(message)`` for every message event."""

        async def event_callback(event: Dict[str, Any]) -> None:
            if event["type"] == "message":
                await callback(event["message"])

        await self.call_on_each_event(event_callback, ["message"], None, **kwargs)

    async def get_messages(self, message_filters: Dict[str, Any]) -> Dict[str, Any]:
        """
        See examples/get-messages for example usage
        """
        return await self.call_endpoint(url="messages", method="GET", request=message_filters)

    async def get_events(self, **request: Any) -> Dict[str, Any]:
        """
        See the register() method for example usage.
        """
        return await self.call_endpoint(
            url="events",
            method="GET",
            longpolling=True,
            request=request,
        )

    async def register(
        self,
        event_types: Optional[Iterable[str]] = None,
        narrow: Optional[List[List[str]]] = None,
        **kwargs: object,
    ) -> Dict[str, Any]:
        """
        Example usage:

        >>> await client.register(['message'])
        {u'msg': u'', u'max_message_id': 112, u'last_event_id': -1, u'result': u'success', u'queue_id': u'1482093786:2'}
        >>> await client.get_events(queue_id='1482093786:2', last_event_id=0)
        {...}
        """

        if narrow is None:
            narrow = []

        request = dict(event_types=event_types, narrow=narrow, **kwargs)

        return await self.call_endpoint(
            url="register",
            request=request,
        )

    async def deregister(self, queue_id: str, timeout: Optional[float] = None) -> Dict[str, Any]:
        """
        Example usage:

        >>> await client.register(['message'])
        {u'msg': u'', u'max_message_id': 113, u'last_event_id': -1, u'result': u'success', u'queue_id': u'1482093786:3'}
        >>> await client.deregister('1482093786:3')
        {u'msg': u'', u'result': u'success'}
        """
        request = dict(queue_id=queue_id)

        return await self.call_endpoint(
            url="events",
            method="DELETE",
            request=request,
            timeout=timeout,
        )

    async def send_message(self, message_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        See examples/send-message for example usage.
        """
        return await self.call_endpoint(
            url="messages",
            request=message_data,
        )


# ---------------------------------------------------------------------------
# The first source line of this module also held the tail of the patch's
# zulip/zulip/__init__.py hunk (unchanged context, not part of asynch.py):
#         while True:
#             if queue_id is None:
#                 (queue_id, last_event_id) = do_register()