Skip to content

Commit

Permalink
chore: refactor credentials to be initialized in Connector (#995)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Jan 31, 2024
1 parent 811f661 commit d4d9b15
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 195 deletions.
24 changes: 19 additions & 5 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import socket
from threading import Thread
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
from typing import Any, Dict, Optional, Type

import google.auth
from google.auth.credentials import Credentials
from google.auth.credentials import with_scopes_if_required

import google.cloud.sql.connector.asyncpg as asyncpg
from google.cloud.sql.connector.exceptions import ConnectorLoopError
Expand All @@ -34,9 +38,6 @@
from google.cloud.sql.connector.utils import format_database_user
from google.cloud.sql.connector.utils import generate_keys

if TYPE_CHECKING:
from google.auth.credentials import Credentials

logger = logging.getLogger(name=__name__)

ASYNC_DRIVERS = ["asyncpg"]
Expand Down Expand Up @@ -109,13 +110,26 @@ def __init__(
)
self._instances: Dict[str, Instance] = {}

# initialize credentials
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
if credentials:
# verfiy custom credentials are proper type
# and atleast base class of google.auth.credentials
if not isinstance(credentials, Credentials):
raise TypeError(
"credentials must be of type google.auth.credentials.Credentials,"
f" got {type(credentials)}"
)
self._credentials = with_scopes_if_required(credentials, scopes=scopes)
# otherwise use application default credentials
else:
self._credentials, _ = google.auth.default(scopes=scopes)
# set default params for connections
self._timeout = timeout
self._enable_iam_auth = enable_iam_auth
self._ip_type = ip_type
self._quota_project = quota_project
self._sqladmin_api_endpoint = sqladmin_api_endpoint
self._credentials = credentials
self._user_agent = user_agent

def connect(
Expand Down
8 changes: 0 additions & 8 deletions google/cloud/sql/connector/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,6 @@ class PlatformNotSupportedError(Exception):
pass


class CredentialsTypeError(Exception):
"""
Raised when credentials parameter is not proper type.
"""

pass


class AutoIAMAuthNotSupported(Exception):
"""
Exception to be raised when Automatic IAM Authentication is not
Expand Down
15 changes: 3 additions & 12 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@

from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import CredentialsTypeError
from google.cloud.sql.connector.exceptions import TLSVersionError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from google.cloud.sql.connector.refresh_utils import _get_ephemeral
from google.cloud.sql.connector.refresh_utils import _get_metadata
from google.cloud.sql.connector.refresh_utils import _is_valid
from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh
from google.cloud.sql.connector.utils import _auth_init
from google.cloud.sql.connector.utils import write_to_file
from google.cloud.sql.connector.version import __version__ as version

Expand Down Expand Up @@ -157,7 +155,6 @@ class Instance:
:type credentials: google.auth.credentials.Credentials
:param credentials
Credentials object used to authenticate connections to Cloud SQL server.
If not specified, Application Default Credentials are used.
:param enable_iam_auth
Enables automatic IAM database authentication for Postgres or MySQL
Expand Down Expand Up @@ -206,7 +203,7 @@ def _client_session(self) -> aiohttp.ClientSession:
self.__client_session = aiohttp.ClientSession(headers=headers)
return self.__client_session

_credentials: Optional[Credentials] = None
_credentials: Credentials
_keys: asyncio.Future

_instance_connection_string: str
Expand All @@ -227,7 +224,7 @@ def __init__(
driver_name: str,
keys: asyncio.Future,
loop: asyncio.AbstractEventLoop,
credentials: Optional[Credentials] = None,
credentials: Credentials,
enable_iam_auth: bool = False,
quota_project: Optional[str] = None,
sqladmin_api_endpoint: str = "https://sqladmin.googleapis.com",
Expand All @@ -250,13 +247,7 @@ def __init__(
self._sqladmin_api_endpoint = sqladmin_api_endpoint
self._loop = loop
self._keys = keys
# validate credentials type
if not isinstance(credentials, Credentials) and credentials is not None:
raise CredentialsTypeError(
"Arg credentials must be type 'google.auth.credentials.Credentials' "
"or None (to use Application Default Credentials)"
)
self._credentials = _auth_init(credentials)
self._credentials = credentials
self._refresh_rate_limiter = AsyncRateLimiter(
max_capacity=2, rate=1 / 30, loop=self._loop
)
Expand Down
25 changes: 1 addition & 24 deletions google/cloud/sql/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Tuple
from typing import Tuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from google.auth import default
from google.auth.credentials import Credentials
from google.auth.credentials import with_scopes_if_required


async def generate_keys() -> Tuple[bytes, str]:
Expand Down Expand Up @@ -104,23 +101,3 @@ def format_database_user(database_version: str, user: str) -> str:
return user.split("@")[0]

return user


def _auth_init(credentials: Optional[Credentials]) -> Credentials:
"""Creates google.auth credentials object with scopes required to make
calls to the Cloud SQL Admin APIs.
:type credentials: google.auth.credentials.Credentials
:param credentials
Credentials object used to authenticate connections to Cloud SQL server.
If not specified, Application Default Credentials are used.
"""
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
# if Credentials object is passed in, use for authentication
if isinstance(credentials, Credentials):
credentials = with_scopes_if_required(credentials, scopes=scopes)
# otherwise use application default credentials
else:
credentials, _ = default(scopes=scopes)

return credentials
138 changes: 69 additions & 69 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from google.auth.credentials import Credentials
from google.auth.credentials import with_scopes_if_required
from google.oauth2 import service_account
from mock import patch
import pytest # noqa F401 Needed to run the tests
from unit.mocks import FakeCSQLInstance # type: ignore

Expand Down Expand Up @@ -146,78 +145,79 @@ async def instance(
keys = asyncio.create_task(generate_keys())
_, client_key = await keys

with patch("google.cloud.sql.connector.utils.default") as mock_auth:
mock_auth.return_value = fake_credentials, None
# mock Cloud SQL Admin API calls
with aioresponses() as mocked:
mocked.get(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}/connectSettings",
status=200,
body=mock_instance.connect_settings(),
repeat=True,
)
mocked.post(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}:generateEphemeralCert",
status=200,
body=mock_instance.generate_ephemeral(client_key),
repeat=True,
)

instance = Instance(
f"{mock_instance.project}:{mock_instance.region}:{mock_instance.name}",
"pg8000",
keys,
loop,
)

yield instance
await instance.close()
# mock Cloud SQL Admin API calls
with aioresponses() as mocked:
mocked.get(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}/connectSettings",
status=200,
body=mock_instance.connect_settings(),
repeat=True,
)
mocked.post(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}:generateEphemeralCert",
status=200,
body=mock_instance.generate_ephemeral(client_key),
repeat=True,
)

instance = Instance(
f"{mock_instance.project}:{mock_instance.region}:{mock_instance.name}",
"pg8000",
keys,
loop,
fake_credentials,
)

yield instance
await instance.close()


@pytest.fixture
async def connector(fake_credentials: Credentials) -> AsyncGenerator[Connector, None]:
instance_connection_name = "my-project:my-region:my-instance"
project, region, instance_name = instance_connection_name.split(":")
# initialize connector
connector = Connector()
with patch("google.cloud.sql.connector.utils.default") as mock_auth:
mock_auth.return_value = fake_credentials, None
# mock Cloud SQL Admin API calls
mock_instance = FakeCSQLInstance(project, region, instance_name)

async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]:
"""
Helper method to await keys of Connector in tests prior to
initializing an Instance object.
"""
return await future

# converting asyncio.Future into concurrent.Future
# await keys in background thread so that .result() is set
# required because keys are needed for mocks, but are not awaited
# in the code until Instance() is initialized
_, client_key = asyncio.run_coroutine_threadsafe(
wait_for_keys(connector._keys), connector._loop
).result()
with aioresponses() as mocked:
mocked.get(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings",
status=200,
body=mock_instance.connect_settings(),
repeat=True,
)
mocked.post(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}:generateEphemeralCert",
status=200,
body=mock_instance.generate_ephemeral(client_key),
repeat=True,
)
# initialize Instance using mocked API calls
instance = Instance(
instance_connection_name, "pg8000", connector._keys, connector._loop
)

connector._instances[instance_connection_name] = instance

yield connector
connector.close()
connector = Connector(credentials=fake_credentials)
# mock Cloud SQL Admin API calls
mock_instance = FakeCSQLInstance(project, region, instance_name)

async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]:
"""
Helper method to await keys of Connector in tests prior to
initializing an Instance object.
"""
return await future

# converting asyncio.Future into concurrent.Future
# await keys in background thread so that .result() is set
# required because keys are needed for mocks, but are not awaited
# in the code until Instance() is initialized
_, client_key = asyncio.run_coroutine_threadsafe(
wait_for_keys(connector._keys), connector._loop
).result()
with aioresponses() as mocked:
mocked.get(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings",
status=200,
body=mock_instance.connect_settings(),
repeat=True,
)
mocked.post(
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}:generateEphemeralCert",
status=200,
body=mock_instance.generate_ephemeral(client_key),
repeat=True,
)
# initialize Instance using mocked API calls
instance = Instance(
instance_connection_name,
"pg8000",
connector._keys,
connector._loop,
fake_credentials,
)

connector._instances[instance_connection_name] = instance

yield connector
connector.close()
Loading

0 comments on commit d4d9b15

Please sign in to comment.