Skip to content

Commit

Permalink
feat: add support for a custom user agent (#986)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom authored Jan 26, 2024
1 parent 48a9ccb commit 82da410
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
3 changes: 3 additions & 0 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
loop: Optional[asyncio.AbstractEventLoop] = None,
quota_project: Optional[str] = None,
sqladmin_api_endpoint: str = "https://sqladmin.googleapis.com",
user_agent: Optional[str] = None,
) -> None:
# if event loop is given, use for background tasks
if loop:
Expand All @@ -115,6 +116,7 @@ def __init__(
self._quota_project = quota_project
self._sqladmin_api_endpoint = sqladmin_api_endpoint
self._credentials = credentials
self._user_agent = user_agent

def connect(
self, instance_connection_string: str, driver: str, **kwargs: Any
Expand Down Expand Up @@ -211,6 +213,7 @@ async def connect_async(
enable_iam_auth,
self._quota_project,
self._sqladmin_api_endpoint,
user_agent=self._user_agent,
)
self._instances[instance_connection_string] = instance

Expand Down
14 changes: 13 additions & 1 deletion google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str:
)


def _format_user_agent(version: str, driver: str, custom: Optional[str]) -> str:
agent = f"{APPLICATION_NAME}/{version}+{driver}"
if custom:
agent = f"{agent} {custom}"
return agent


class Instance:
"""A class to manage the details of the connection to a Cloud SQL
instance, including refreshing the credentials.
Expand Down Expand Up @@ -224,6 +231,7 @@ def __init__(
enable_iam_auth: bool = False,
quota_project: Optional[str] = None,
sqladmin_api_endpoint: str = "https://sqladmin.googleapis.com",
user_agent: Optional[str] = None,
) -> None:
# validate and parse instance connection name
self._project, self._region, self._instance = _parse_instance_connection_name(
Expand All @@ -233,7 +241,11 @@ def __init__(

self._enable_iam_auth = enable_iam_auth

self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
self._user_agent_string = _format_user_agent(
version,
driver_name,
user_agent,
)
self._quota_project = quota_project
self._sqladmin_api_endpoint = sqladmin_api_endpoint
self._loop = loop
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from google.cloud.sql.connector.instance import IPTypes
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from google.cloud.sql.connector.utils import generate_keys
from google.cloud.sql.connector.version import __version__ as version


@pytest.fixture
Expand Down Expand Up @@ -84,7 +85,13 @@ async def test_Instance_init(
)
with patch("google.cloud.sql.connector.utils.default") as mock_auth:
mock_auth.return_value = fake_credentials, None
instance = Instance(connect_string, "pymysql", keys, event_loop)
instance = Instance(
connect_string,
"pymysql",
keys,
event_loop,
user_agent="custom/v1.0.0",
)
project_result = instance._project
region_result = instance._region
instance_result = instance._instance
Expand All @@ -93,6 +100,10 @@ async def test_Instance_init(
and region_result == "test-region"
and instance_result == "test-instance"
)
assert (
instance._user_agent_string
== f"cloud-sql-python-connector/{version}+pymysql custom/v1.0.0"
)
# cleanup instance
await instance.close()

Expand Down

0 comments on commit 82da410

Please sign in to comment.