Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pluggable auth plugins #945

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ target/
tests/fixtures/my.cnf

.pytest_cache

venv
5 changes: 5 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changes
-------

next (unreleased)
^^^^^^^^^^^^^^^^^

* Add pluggable authentication plugins.

0.2.0 (2023-06-11)
^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions aiomysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
InternalError,
NotSupportedError, ProgrammingError, MySQLError)

from .auth import AuthPlugin
from .connection import Connection, connect
from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor
from .pool import create_pool, Pool
Expand All @@ -55,6 +56,7 @@
'escape_sequence',
'escape_string',

"AuthPlugin",
'Connection',
'Pool',
'connect',
Expand Down
168 changes: 168 additions & 0 deletions aiomysql/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from pymysql import OperationalError
from pymysql.connections import _auth

from .log import logger

if TYPE_CHECKING:
from aiomysql.connection import Connection


@dataclass
class AuthInfo:
password: str
secure: bool
conn: "Connection"


class AuthPlugin:
"""
Abstract base class for authentication plugins.
"""

name = ""

async def auth(self, auth_info, data):
"""
Async generator for authentication process.

Subclasses should extend this method.

Many authentication plugins require back-and-forth exchanges
with the server. These client/server IO - including constructing
the MySQL protocol packets - is handled by the Connection.
All this generator needs to do is receive and send plugin specific data.

Example:
```
class EchoPlugin(AuthPlugin):
async def auth(self, auth_info, data):
data_from_server = data
while True:
data_to_server = data_from_server
data_from_server = yield data_to_server
```

:param auth_info: Various metadata from the current connection,
including a reference to the connection itself.
:param data: Arbitrary data sent by the server.
This can be, for example, a salt, but it's really up to the
plugin protocol to choose.
"""
yield b""

async def start(
self, auth_info, data
):
state = self.auth(auth_info, data)
data = await state.__anext__()
return data, state


class MysqlNativePassword(AuthPlugin):
name = "mysql_native_password"

async def auth(self, auth_info, data):
yield _auth.scramble_native_password(auth_info.password.encode('latin1'), data)


class CachingSha2Password(AuthPlugin):
name = "caching_sha2_password"

async def auth(self, auth_info, data):
salt = data
if auth_info.password:
data = yield _auth.scramble_caching_sha2(
auth_info.password.encode('latin1'), data
)
else:
data = yield b""

# magic numbers:
# 2 - request public key
# 3 - fast auth succeeded
# 4 - need full auth

n = data[0]

if n == 3:
logger.debug("caching sha2: succeeded by fast path.")
yield None
return

if n != 4:
raise OperationalError("caching sha2: Unknown "
"result for fast auth: {}".format(n))

logger.debug("caching sha2: Trying full auth...")

if auth_info.secure:
logger.debug("caching sha2: Sending plain "
"password via secure connection")
yield auth_info.password.encode('latin1') + b'\0'
return

if not auth_info.conn.server_public_key:
auth_info.conn.server_public_key = yield b'\x02'
logger.debug(auth_info.conn.server_public_key.decode('ascii'))

yield _auth.sha2_rsa_encrypt(
auth_info.password.encode('latin1'), salt,
auth_info.conn.server_public_key
)


class Sha256Password(AuthPlugin):
name = "sha256_password"

async def auth(self, auth_info, data):
if auth_info.secure:
logger.debug("sha256: Sending plain password")
yield auth_info.password.encode('latin1') + b'\0'
return

salt = data

if auth_info.password:
data = yield b'\1' # request public key
auth_info.conn.server_public_key = data
logger.debug(
"Received public key:\n%s",
auth_info.conn.server_public_key.decode('ascii')
)
yield _auth.sha2_rsa_encrypt(
auth_info.password.encode('latin1'), salt,
auth_info.conn.server_public_key.server_public_key
)

else:
yield b'\0' # empty password


class MysqlClearPassword(AuthPlugin):
name = "mysql_clear_password"

async def auth(self, auth_info, data):
yield auth_info.password.encode('latin1') + b'\0'


class MysqlOldPassword(AuthPlugin):
name = "mysql_old_password"

async def auth(self, auth_info, data):
yield _auth.scramble_old_password(
auth_info.password.encode('latin1'),
data,
) + b'\0'


def get_plugins():
return [
MysqlNativePassword(),
CachingSha2Password(),
Sha256Password(),
MysqlClearPassword(),
MysqlOldPassword()
]
Loading