diff --git a/.gitignore b/.gitignore index 94d5bb2e..27392e3f 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,5 @@ target/ tests/fixtures/my.cnf .pytest_cache + +venv \ No newline at end of file diff --git a/CHANGES.txt b/CHANGES.txt index e7fe2231..fc658ec1 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,6 +1,11 @@ Changes ------- +next (unreleased) +^^^^^^^^^^^^^^^^^ + +* Add pluggable authentication plugins. + 0.2.0 (2023-06-11) ^^^^^^^^^^^^^^^^^^ diff --git a/aiomysql/__init__.py b/aiomysql/__init__.py index a367fcd2..c06b67fe 100644 --- a/aiomysql/__init__.py +++ b/aiomysql/__init__.py @@ -29,6 +29,7 @@ InternalError, NotSupportedError, ProgrammingError, MySQLError) +from .auth import AuthPlugin from .connection import Connection, connect from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor from .pool import create_pool, Pool @@ -55,6 +56,7 @@ 'escape_sequence', 'escape_string', + "AuthPlugin", 'Connection', 'Pool', 'connect', diff --git a/aiomysql/auth.py b/aiomysql/auth.py new file mode 100644 index 00000000..332c3fc4 --- /dev/null +++ b/aiomysql/auth.py @@ -0,0 +1,168 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pymysql import OperationalError +from pymysql.connections import _auth + +from .log import logger + +if TYPE_CHECKING: + from aiomysql.connection import Connection + + +@dataclass +class AuthInfo: + password: str + secure: bool + conn: "Connection" + + +class AuthPlugin: + """ + Abstract base class for authentication plugins. + """ + + name = "" + + async def auth(self, auth_info, data): + """ + Async generator for authentication process. + + Subclasses should extend this method. + + Many authentication plugins require back-and-forth exchanges + with the server. These client/server IO - including constructing + the MySQL protocol packets - is handled by the Connection. + All this generator needs to do is receive and send plugin specific data. + + Example: + ``` + class EchoPlugin(AuthPlugin): + async def auth(self, auth_info, data): + data_from_server = data + while True: + data_to_server = data_from_server + data_from_server = yield data_to_server + ``` + + :param auth_info: Various metadata from the current connection, + including a reference to the connection itself. + :param data: Arbitrary data sent by the server. + This can be, for example, a salt, but it's really up to the + plugin protocol to choose. + """ + yield b"" + + async def start( + self, auth_info, data + ): + state = self.auth(auth_info, data) + data = await state.__anext__() + return data, state + + +class MysqlNativePassword(AuthPlugin): + name = "mysql_native_password" + + async def auth(self, auth_info, data): + yield _auth.scramble_native_password(auth_info.password.encode('latin1'), data) + + +class CachingSha2Password(AuthPlugin): + name = "caching_sha2_password" + + async def auth(self, auth_info, data): + salt = data + if auth_info.password: + data = yield _auth.scramble_caching_sha2( + auth_info.password.encode('latin1'), data + ) + else: + data = yield b"" + + # magic numbers: + # 2 - request public key + # 3 - fast auth succeeded + # 4 - need full auth + + n = data[0] + + if n == 3: + logger.debug("caching sha2: succeeded by fast path.") + yield None + return + + if n != 4: + raise OperationalError("caching sha2: Unknown " + "result for fast auth: {}".format(n)) + + logger.debug("caching sha2: Trying full auth...") + + if auth_info.secure: + logger.debug("caching sha2: Sending plain " + "password via secure connection") + yield auth_info.password.encode('latin1') + b'\0' + return + + if not auth_info.conn.server_public_key: + auth_info.conn.server_public_key = yield b'\x02' + logger.debug(auth_info.conn.server_public_key.decode('ascii')) + + yield _auth.sha2_rsa_encrypt( + auth_info.password.encode('latin1'), salt, + auth_info.conn.server_public_key + ) + + +class Sha256Password(AuthPlugin): + name = "sha256_password" + + async def auth(self, auth_info, data): + if auth_info.secure: + logger.debug("sha256: Sending plain password") + yield auth_info.password.encode('latin1') + b'\0' + return + + salt = data + + if auth_info.password: + data = yield b'\1' # request public key + auth_info.conn.server_public_key = data + logger.debug( + "Received public key:\n%s", + auth_info.conn.server_public_key.decode('ascii') + ) + yield _auth.sha2_rsa_encrypt( + auth_info.password.encode('latin1'), salt, + auth_info.conn.server_public_key.server_public_key + ) + + else: + yield b'\0' # empty password + + +class MysqlClearPassword(AuthPlugin): + name = "mysql_clear_password" + + async def auth(self, auth_info, data): + yield auth_info.password.encode('latin1') + b'\0' + + +class MysqlOldPassword(AuthPlugin): + name = "mysql_old_password" + + async def auth(self, auth_info, data): + yield _auth.scramble_old_password( + auth_info.password.encode('latin1'), + data, + ) + b'\0' + + +def get_plugins(): + return [ + MysqlNativePassword(), + CachingSha2Password(), + Sha256Password(), + MysqlClearPassword(), + MysqlOldPassword() + ] diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 3520dfcc..56dfe431 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -26,7 +26,6 @@ ProgrammingError) from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET -from pymysql.connections import _auth from pymysql.connections import MysqlPacket from pymysql.connections import FieldDescriptorPacket @@ -35,9 +34,9 @@ from pymysql.connections import LoadLocalPacketWrapper # from aiomysql.utils import _convert_to_str +from .auth import get_plugins, MysqlNativePassword, AuthInfo from .cursors import Cursor from .utils import _pack_int24, _lenenc_int, _ConnectionContextManager, _ContextManager -from .log import logger try: DEFAULT_USER = getpass.getuser() @@ -53,7 +52,7 @@ def connect(host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + program_name='', server_public_key=None, auth_plugins=None): """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, @@ -66,7 +65,8 @@ def connect(host="localhost", user=None, password="", read_default_group=read_default_group, autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, - auth_plugin=auth_plugin, program_name=program_name) + auth_plugin=auth_plugin, program_name=program_name, + server_public_key=server_public_key, auth_plugins=auth_plugins) return _ConnectionContextManager(coro) @@ -142,7 +142,7 @@ def __init__(self, host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + program_name='', server_public_key=None, auth_plugins=None): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -185,6 +185,8 @@ def __init__(self, host="localhost", user=None, password="", :param server_public_key: SHA256 authentication plugin public key value. :param loop: asyncio loop + :param auth_plugins: List of additional aiomysql.AuthPlugin instances. + These can be referenced by name using `auth_plugin`. """ self._loop = loop or asyncio.get_event_loop() @@ -220,6 +222,14 @@ def __init__(self, host="localhost", user=None, password="", self.server_public_key = server_public_key self.salt = None + self.auth_plugins = {p.name: p for p in get_plugins()} + + # Use mysql_native_password as the default plugin + self.auth_plugins[""] = MysqlNativePassword() + + if auth_plugins: + self.auth_plugins.update({p.name: p for p in auth_plugins}) + from . import __version__ self._connect_attrs = { '_client_name': 'aiomysql', @@ -780,32 +790,12 @@ async def _request_authentication(self): data = data_init + _user + b'\0' - authresp = b'' - auth_plugin = self._client_auth_plugin if not self._client_auth_plugin: # Contains the auth plugin from handshake auth_plugin = self._server_auth_plugin - if auth_plugin in ('', 'mysql_native_password'): - authresp = _auth.scramble_native_password( - self._password.encode('latin1'), self.salt) - elif auth_plugin == 'caching_sha2_password': - if self._password: - authresp = _auth.scramble_caching_sha2( - self._password.encode('latin1'), self.salt - ) - # Else: empty password - elif auth_plugin == 'sha256_password': - if self._ssl_context and self.server_capabilities & CLIENT.SSL: - authresp = self._password.encode('latin1') + b'\0' - elif self._password: - authresp = b'\1' # request public key - else: - authresp = b'\0' # empty password - - elif auth_plugin in ('', 'mysql_clear_password'): - authresp = self._password.encode('latin1') + b'\0' + authresp, authstate = await self._start_auth_plugin(auth_plugin, self.salt) if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: data += _lenenc_int(len(authresp)) + authresp @@ -829,8 +819,6 @@ async def _request_authentication(self): name = name.encode('ascii') data += name + b'\0' - self._auth_plugin_used = auth_plugin - # Sends the server a few pieces of client info if self.server_capabilities & CLIENT.CONNECT_ATTRS: connect_attrs = b'' @@ -843,192 +831,47 @@ async def _request_authentication(self): self.write_packet(data) auth_packet = await self._read_packet() - # if authentication method isn't accepted the first byte - # will have the octet 254 - if auth_packet.is_auth_switch_request(): - # https://dev.mysql.com/doc/internals/en/ - # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - auth_packet.read_uint8() # 0xfe packet identifier - plugin_name = auth_packet.read_string() - if (self.server_capabilities & CLIENT.PLUGIN_AUTH and - plugin_name is not None): - await self._process_auth(plugin_name, auth_packet) - else: - # send legacy handshake - data = _auth.scramble_old_password( - self._password.encode('latin1'), - auth_packet.read_all()) + b'\0' - self.write_packet(data) - await self._read_packet() - elif auth_packet.is_extra_auth_data(): - if auth_plugin == "caching_sha2_password": - await self.caching_sha2_password_auth(auth_packet) - elif auth_plugin == "sha256_password": - await self.sha256_password_auth(auth_packet) - else: - raise OperationalError("Received extra packet " - "for auth method %r", auth_plugin) - - async def _process_auth(self, plugin_name, auth_packet): - # These auth plugins do their own packet handling - if plugin_name == b"caching_sha2_password": - await self.caching_sha2_password_auth(auth_packet) - self._auth_plugin_used = plugin_name.decode() - elif plugin_name == b"sha256_password": - await self.sha256_password_auth(auth_packet) - self._auth_plugin_used = plugin_name.decode() - else: - - if plugin_name == b"mysql_native_password": - # https://dev.mysql.com/doc/internals/en/ - # secure-password-authentication.html#packet-Authentication:: - # Native41 - data = _auth.scramble_native_password( - self._password.encode('latin1'), - auth_packet.read_all()) - elif plugin_name == b"mysql_old_password": + while auth_packet.is_auth_switch_request() or auth_packet.is_extra_auth_data(): + # if authentication method isn't accepted the first byte + # will have the octet 254 + if auth_packet.is_auth_switch_request(): # https://dev.mysql.com/doc/internals/en/ - # old-password-authentication.html - data = _auth.scramble_old_password( - self._password.encode('latin1'), + # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest + auth_packet.read_uint8() # 0xfe packet identifier + plugin_name = auth_packet.read_string() + if (self.server_capabilities & CLIENT.PLUGIN_AUTH and + plugin_name is not None): + plugin_name = plugin_name.decode("latin1") + else: + plugin_name = "mysql_old_password" + authresp, authstate = await self._start_auth_plugin( + plugin_name, auth_packet.read_all() - ) + b'\0' - elif plugin_name == b"mysql_clear_password": - # https://dev.mysql.com/doc/internals/en/ - # clear-text-authentication.html - data = self._password.encode('latin1') + b'\0' - else: - raise OperationalError( - 2059, "Authentication plugin '{}'" - " not configured".format(plugin_name) - ) - - self.write_packet(data) - pkt = await self._read_packet() - pkt.check_error() - - self._auth_plugin_used = plugin_name.decode() - - return pkt - - async def caching_sha2_password_auth(self, pkt): - # No password fast path - if not self._password: - self.write_packet(b'') - pkt = await self._read_packet() - pkt.check_error() - return pkt - - if pkt.is_auth_switch_request(): - # Try from fast auth - logger.debug("caching sha2: Trying fast path") - self.salt = pkt.read_all() - scrambled = _auth.scramble_caching_sha2( - self._password.encode('latin1'), self.salt - ) - - self.write_packet(scrambled) - pkt = await self._read_packet() - pkt.check_error() - - # else: fast auth is tried in initial handshake - - if not pkt.is_extra_auth_data(): - raise OperationalError( - "caching sha2: Unknown packet " - "for fast auth: {}".format(pkt._data[:1]) - ) - - # magic numbers: - # 2 - request public key - # 3 - fast auth succeeded - # 4 - need full auth - - pkt.advance(1) - n = pkt.read_uint8() - - if n == 3: - logger.debug("caching sha2: succeeded by fast path.") - pkt = await self._read_packet() - pkt.check_error() # pkt must be OK packet - return pkt - - if n != 4: - raise OperationalError("caching sha2: Unknown " - "result for fast auth: {}".format(n)) - - logger.debug("caching sha2: Trying full auth...") - - if self._secure: - logger.debug("caching sha2: Sending plain " - "password via secure connection") - self.write_packet(self._password.encode('latin1') + b'\0') - pkt = await self._read_packet() - pkt.check_error() - return pkt - - if not self.server_public_key: - self.write_packet(b'\x02') - pkt = await self._read_packet() # Request public key - pkt.check_error() - - if not pkt.is_extra_auth_data(): - raise OperationalError( - "caching sha2: Unknown packet " - "for public key: {}".format(pkt._data[:1]) ) - - self.server_public_key = pkt._data[1:] - logger.debug(self.server_public_key.decode('ascii')) - - data = _auth.sha2_rsa_encrypt( - self._password.encode('latin1'), self.salt, - self.server_public_key + else: + auth_packet.read_uint8() # 0x01 packet identifier + try: + authresp = await authstate.asend(auth_packet.read_all()) + except StopAsyncIteration: + raise OperationalError("Received extra packet " + "for auth method %r", self._auth_plugin_used) + + if authresp is not None: + self.write_packet(authresp) + auth_packet = await self._read_packet() + + async def _start_auth_plugin(self, plugin_name, data): + plugin = self.auth_plugins.get(plugin_name) + if not plugin: + raise OperationalError(f"Unknown auth_plugin: {plugin_name}") + auth_info = AuthInfo( + password=self._password, + secure=self._secure, + conn=self, ) - self.write_packet(data) - pkt = await self._read_packet() - pkt.check_error() - - async def sha256_password_auth(self, pkt): - if self._secure: - logger.debug("sha256: Sending plain password") - data = self._password.encode('latin1') + b'\0' - self.write_packet(data) - pkt = await self._read_packet() - pkt.check_error() - return pkt - - if pkt.is_auth_switch_request(): - self.salt = pkt.read_all() - if not self.server_public_key and self._password: - # Request server public key - logger.debug("sha256: Requesting server public key") - self.write_packet(b'\1') - pkt = await self._read_packet() - pkt.check_error() - - if pkt.is_extra_auth_data(): - self.server_public_key = pkt._data[1:] - logger.debug( - "Received public key:\n%s", - self.server_public_key.decode('ascii') - ) - - if self._password: - if not self.server_public_key: - raise OperationalError("Couldn't receive server's public key") - - data = _auth.sha2_rsa_encrypt( - self._password.encode('latin1'), self.salt, - self.server_public_key - ) - else: - data = b'' - - self.write_packet(data) - pkt = await self._read_packet() - pkt.check_error() - return pkt + authresp, authstate = await plugin.start(auth_info, data) + self._auth_plugin_used = plugin.name + return authresp, authstate # _mysql support def thread_id(self):