From f8de0f11b1d48a920ce384713bb6ef2857309f7e Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 10 Jun 2024 13:14:33 +0000 Subject: [PATCH] chore: add lock usage --- google/cloud/sql/connector/instance.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 7e6ed5a7..2490ce61 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -104,11 +104,13 @@ async def force_refresh(self) -> None: """ # if next refresh is not already in progress, cancel it and schedule new one immediately if not self._lock.locked(): - self._next.cancel() - self._next = self._schedule_refresh(0) + async with self._lock: + self._next.cancel() + self._next = self._schedule_refresh(0) # block all sequential connection attempts on the next refresh result if current is invalid - if not await _is_valid(self._current): - self._current = self._next + async with self._lock: + if not await _is_valid(self._current): + self._current = self._next async def _perform_refresh(self) -> ConnectionInfo: """Retrieves instance metadata and ephemeral certificate from the @@ -122,8 +124,8 @@ async def _perform_refresh(self) -> ConnectionInfo: """ async with self._lock: logger.debug( - f"['{self._instance_connection_string}']: Connection info refresh " - "operation started" + f"['{self._instance_connection_string}']: Connection info " + "refresh operation started" ) try: @@ -241,8 +243,9 @@ async def close(self) -> None: f"['{self._instance_connection_string}']: Canceling connection info " "refresh operation tasks" ) - self._current.cancel() - self._next.cancel() - # gracefully wait for tasks to cancel - tasks = asyncio.gather(self._current, self._next, return_exceptions=True) + async with self._lock: + self._current.cancel() + self._next.cancel() + # gracefully wait for tasks to cancel + tasks = asyncio.gather(self._current, self._next, return_exceptions=True) await asyncio.wait_for(tasks, timeout=2.0)