From f3b683b92ca49625f6e95b2c72bf5a0864156562 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 2 Aug 2024 21:04:38 +0000 Subject: [PATCH 1/2] feat: use non-blocking disk read/writes --- google/cloud/sql/connector/connection_info.py | 9 +++++---- google/cloud/sql/connector/connector.py | 4 ++-- google/cloud/sql/connector/utils.py | 16 +++++++++------- requirements.txt | 1 + setup.py | 1 + tests/unit/test_instance.py | 6 +++--- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 06a0b976..7181134d 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -17,9 +17,10 @@ from dataclasses import dataclass import logging import ssl -from tempfile import TemporaryDirectory from typing import Any, Dict, Optional, TYPE_CHECKING +from aiofiles.tempfile import TemporaryDirectory + from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.utils import write_to_file @@ -45,7 +46,7 @@ class ConnectionInfo: expiration: datetime.datetime context: Optional[ssl.SSLContext] = None - def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: + async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: """Constructs a SSL/TLS context for the given connection info. Cache the SSL context to ensure we don't read from disk repeatedly when @@ -83,8 +84,8 @@ def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: # tmpdir and its contents are automatically deleted after the CA cert # and ephemeral cert are loaded into the SSLcontext. The values # need to be written to files in order to be loaded by the SSLContext - with TemporaryDirectory() as tmpdir: - ca_filename, cert_filename, key_filename = write_to_file( + async with TemporaryDirectory() as tmpdir: + ca_filename, cert_filename, key_filename = await write_to_file( tmpdir, self.server_ca_cert, self.client_cert, self.private_key ) context.load_cert_chain(cert_filename, keyfile=key_filename) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 1197470a..20235362 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -365,14 +365,14 @@ async def connect_async( if driver in ASYNC_DRIVERS: return await connector( ip_address, - conn_info.create_ssl_context(enable_iam_auth), + await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) # synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, - conn_info.create_ssl_context(enable_iam_auth), + await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) return await self._loop.run_in_executor(None, connect_partial) diff --git a/google/cloud/sql/connector/utils.py b/google/cloud/sql/connector/utils.py index 175e763f..47a318fb 100755 --- a/google/cloud/sql/connector/utils.py +++ b/google/cloud/sql/connector/utils.py @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import Tuple +import aiofiles from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -57,7 +59,7 @@ async def generate_keys() -> Tuple[bytes, str]: return priv_key, pub_key -def write_to_file( +async def write_to_file( dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes ) -> Tuple[str, str, str]: """ @@ -68,12 +70,12 @@ def write_to_file( cert_filename = f"{dir_path}/cert.pem" key_filename = f"{dir_path}/priv.pem" - with open(ca_filename, "w+") as ca_out: - ca_out.write(serverCaCert) - with open(cert_filename, "w+") as ephemeral_out: - ephemeral_out.write(ephemeralCert) - with open(key_filename, "wb") as priv_out: - priv_out.write(priv_key) + async with aiofiles.open(ca_filename, "w+") as ca_out: + await ca_out.write(serverCaCert) + async with aiofiles.open(cert_filename, "w+") as ephemeral_out: + await ephemeral_out.write(ephemeralCert) + async with aiofiles.open(key_filename, "wb") as priv_out: + await priv_out.write(priv_key) return (ca_filename, cert_filename, key_filename) diff --git a/requirements.txt b/requirements.txt index 04f9ae06..22a5681f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiofiles==24.1.0 aiohttp==3.9.5 cryptography==42.0.8 Requests==2.32.3 diff --git a/setup.py b/setup.py index 22b0094a..d2a3e68e 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ ) release_status = "Development Status :: 5 - Production/Stable" dependencies = [ + "aiofiles", "aiohttp", "cryptography>=42.0.0", "Requests", diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 9b667675..5dcf1f5a 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -367,7 +367,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None await cache._current -def test_ConnectionInfo_caches_sslcontext() -> None: +async def test_ConnectionInfo_caches_sslcontext() -> None: info = ConnectionInfo( "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() ) @@ -375,6 +375,6 @@ def test_ConnectionInfo_caches_sslcontext() -> None: assert info.context is None # cache a 'context' info.context = "context" - # caling create_ssl_context should no-op with an existing 'context' - info.create_ssl_context() + # calling create_ssl_context should no-op with an existing 'context' + await info.create_ssl_context() assert info.context == "context" From 80896ade6386f6c1acb5d88212da1e9b835b84f8 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 2 Aug 2024 21:13:54 +0000 Subject: [PATCH 2/2] chore: update mock --- tests/unit/mocks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 31317981..03ad5a6e 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -19,9 +19,9 @@ import datetime import json import ssl -from tempfile import TemporaryDirectory from typing import Any, Callable, Dict, Literal, Optional, Tuple +from aiofiles.tempfile import TemporaryDirectory from aiohttp import web from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -203,8 +203,8 @@ async def create_ssl_context() -> ssl.SSLContext: # build default ssl.SSLContext context = ssl.create_default_context() # load ssl.SSLContext with certs - with TemporaryDirectory() as tmpdir: - ca_filename, cert_filename, key_filename = write_to_file( + async with TemporaryDirectory() as tmpdir: + ca_filename, cert_filename, key_filename = await write_to_file( tmpdir, server_ca_cert, ephemeral_cert, client_private ) context.load_cert_chain(cert_filename, keyfile=key_filename)