Skip to content

Commit

Permalink
chore: better management of asyncio tasks and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xadahiya committed Sep 30, 2024
1 parent 96644ff commit 30497bf
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 52 deletions.
19 changes: 7 additions & 12 deletions snapshotter/utils/aggregation_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,14 @@ async def _processor_task(
).json(),
},
)
current_time = time.time()
commit_task = asyncio.create_task(
self._commit_payload(
task_type=task_type,
project_id=project_id,
epoch=msg_obj,
snapshot=snapshot,
storage_flag=settings.web3storage.upload_aggregates,
_ipfs_writer_client=self._ipfs_writer_client,
),
await self._commit_payload(
task_type=task_type,
project_id=project_id,
epoch=msg_obj,
snapshot=snapshot,
storage_flag=settings.web3storage.upload_aggregates,
_ipfs_writer_client=self._ipfs_writer_client,
)
self._active_tasks.add((current_time, commit_task))
commit_task.add_done_callback(lambda _: self._active_tasks.discard((current_time, commit_task)))
self._logger.debug(
'Updated epoch processing status in aggregation worker for project {} for transition {}',
project_id, SnapshotterStates.SNAPSHOT_BUILD.value,
Expand Down
45 changes: 31 additions & 14 deletions snapshotter/utils/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,12 +443,7 @@ async def _commit_payload(

# Upload to web3 storage
if storage_flag:
current_time = time.time()
task = asyncio.create_task(
self._upload_web3_storage(snapshot_bytes),
)
self._active_tasks.add((current_time, task))
task.add_done_callback(lambda _: self._active_tasks.discard((current_time, task)))
await self._create_tracked_task(self._upload_web3_storage(snapshot_bytes))

async def _rabbitmq_consumer(self, loop):
"""
Expand Down Expand Up @@ -584,6 +579,34 @@ async def _close_stream(self):

self._last_stream_close_time = time.time()

async def _create_tracked_task(self, task):
"""
Creates and tracks an asynchronous task.
This method creates a new task from the given coroutine, adds it to the set of active tasks,
and sets up a callback to remove the task from the set when it's completed.
Args:
task (Coroutine): The coroutine to be executed as a task.
Returns:
None
Note:
This method is used to keep track of all running tasks for potential cleanup or monitoring.
"""
# Get the current timestamp
current_time = time.time()

# Create a new task from the given coroutine
new_task = asyncio.create_task(task)

# Add the task to the set of active tasks, along with its creation time
self._active_tasks.add((current_time, new_task))

# Set up a callback to remove the task from the set when it's done
new_task.add_done_callback(lambda _: self._active_tasks.discard((current_time, new_task)))

async def _send_submission_to_collector(self, snapshot_cid, epoch_id, project_id, slot_id=None, private_key=None):
"""
Sends a snapshot submission to the collector.
Expand Down Expand Up @@ -628,11 +651,8 @@ async def _send_submission_to_collector(self, snapshot_cid, epoch_id, project_id
if (current_time - self._last_stream_close_time) > self._stream_lifetime:
await self._close_stream()

await self.send_message(msg)
await self._create_tracked_task(self.send_message(msg))

except asyncio.TimeoutError:
self._logger.error('Timeout waiting for response, assuming success')
return
except Exception as e:
self._logger.opt(
exception=True,
Expand Down Expand Up @@ -677,12 +697,9 @@ async def send_message(self, msg, simulation=False):
else:
try:
async with self.open_stream() as stream:
await asyncio.wait_for(stream.send_message(msg), timeout=30)
await stream.send_message(msg)
self._logger.debug(f'Sent message: {msg}')
return {'status_code': 200}
except asyncio.TimeoutError:
self._logger.error('Timeout waiting for response, assuming success')
return
except Exception as e:
self._logger.opt(exception=settings.logs.trace_enabled).error(f'Failed to send message: {e}')
raise Exception(f'Failed to send message: {e}')
Expand Down
40 changes: 14 additions & 26 deletions snapshotter/utils/snapshot_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,14 @@ async def _process_single_mode(self, msg_obj: PowerloomSnapshotProcessMessage, t
# Execute Redis pipeline
await p.execute()

# Commit payload asynchronously
current_time = time.time()
task = asyncio.create_task(
self._commit_payload(
task_type=task_type,
_ipfs_writer_client=self._ipfs_writer_client,
project_id=project_id,
epoch=msg_obj,
snapshot=snapshot,
storage_flag=settings.web3storage.upload_snapshots,
),
await self._commit_payload(
task_type=task_type,
project_id=project_id,
epoch=msg_obj,
snapshot=snapshot,
storage_flag=settings.web3storage.upload_snapshots,
_ipfs_writer_client=self._ipfs_writer_client,
)
self._active_tasks.add((current_time, task))
task.add_done_callback(lambda _: self._active_tasks.discard((current_time, task)))

async def _process_bulk_mode(self, msg_obj: PowerloomSnapshotProcessMessage, task_type: str):
"""
Expand Down Expand Up @@ -313,20 +307,14 @@ async def _process_bulk_mode(self, msg_obj: PowerloomSnapshotProcessMessage, tas
)
await p.execute()

# Commit payload asynchronously
current_time = time.time()
task = asyncio.create_task(
self._commit_payload(
task_type=task_type,
_ipfs_writer_client=self._ipfs_writer_client,
project_id=project_id,
epoch=msg_obj,
snapshot=snapshot,
storage_flag=settings.web3storage.upload_snapshots,
),
await self._commit_payload(
task_type=task_type,
project_id=project_id,
epoch=msg_obj,
snapshot=snapshot,
storage_flag=settings.web3storage.upload_snapshots,
_ipfs_writer_client=self._ipfs_writer_client,
)
self._active_tasks.add((current_time, task))
task.add_done_callback(lambda _: self._active_tasks.discard((current_time, task)))

async def _process_task(self, msg_obj: PowerloomSnapshotProcessMessage, task_type: str):
"""
Expand Down

0 comments on commit 30497bf

Please sign in to comment.