Skip to content

Commit

Permalink
Scheduler add workgraph subsriber
Browse files Browse the repository at this point in the history
The scheduler will listen to the task from scheduler_queue
  • Loading branch information
superstar54 committed Sep 2, 2024
1 parent 35aab9b commit d267263
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
19 changes: 18 additions & 1 deletion aiida_workgraph/engine/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def load_instance_state(
self._temp = {"awaitables": {}}

self.set_logger(self.node.logger)
self.add_workgraph_subsriber()

if self._awaitables:
# For the "ascyncio.tasks.Task" awaitable, because there are only in-memory,
Expand Down Expand Up @@ -512,7 +513,7 @@ def setup(self) -> None:
# self.ctx._workgraph[pk]["_execution_count"] = {}
# data not to be persisted, because they are not serializable
self._temp = {"awaitables": {}}
# self.launch_workgraph(122305)
self.add_workgraph_subsriber()

def launch_workgraph(self, pk: str) -> None:
"""Launch the workgraph."""
Expand Down Expand Up @@ -1633,6 +1634,22 @@ def message_receive(
# Didn't match any known intents
raise RuntimeError("Unknown intent")

def call_on_receive_workgraph_message(self, _comm, msg):
"""Call on receive workgraph message."""
# self.report(f"Received workgraph message: {msg}")
pk = int(msg)
# To avoid "DbNode is not persistent", we need to schedule the call
self._schedule_rpc(self.launch_workgraph, pk=pk)
return True

def add_workgraph_subsriber(self) -> None:
"""Add workgraph subscriber."""
queue_name = "scheduler_queue"
# self.report(f"Add workgraph subscriber on queue: {queue_name}")
comm = self.runner.communicator._communicator
queue = comm.task_queue(queue_name, prefetch_count=1000)
queue.add_task_subscriber(self.callback)

def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]:
""""""
# expose outputs of the workgraph
Expand Down
11 changes: 11 additions & 0 deletions aiida_workgraph/utils/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ def create_task_action(
controller._communicator.rpc_send(pk, message)


def create_scheduler_action(
pk: int,
):
"""Send workgraph task to scheduler."""

controller = get_manager().get_process_controller()
message = str(pk)
queue = controller._communicator.task_queue("scheduler_queue")
queue.task_send(message)


def get_task_state_info(node, name: str, key: str) -> str:
"""Get task state info from base.extras."""
from aiida.orm.utils.serialize import deserialize_unsafe
Expand Down
21 changes: 15 additions & 6 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def submit(
if self.process.process_state.value.upper() not in ["CREATED"]:
raise ValueError(f"Process {self.process.pk} has already been submitted.")
if to_scheduler:
self.continue_process_in_scheduler()
self.continue_process_in_scheduler(to_scheduler)
else:
self.continue_process()
# as long as we submit the process, it is a new submission, we should set restart_process to None
Expand Down Expand Up @@ -432,15 +432,24 @@ def continue_process(self):
process_controller = get_manager().get_process_controller()
process_controller.continue_process(self.pk)

def continue_process_in_scheduler(self):
"""Ask the scheduler to pick up the process from the database and run it."""
from aiida_workgraph.utils.control import create_task_action
def continue_process_in_scheduler(self, to_scheduler: Union[int, bool]) -> None:
"""Ask the scheduler to pick up the process from the database and run it.
If to_scheduler is an integer, it will be used as the scheduler pk.
Otherwise, it will send the message to the queue, and the scheduler will pick it up.
"""
from aiida_workgraph.utils.control import (
create_task_action,
create_scheduler_action,
)
from aiida_workgraph.engine.scheduler.client import get_scheduler
import kiwipy

try:
scheduler_pk = get_scheduler()
create_task_action(scheduler_pk, [self.pk], action="launch_workgraph")
if isinstance(to_scheduler, int):
scheduler_pk = get_scheduler()
create_task_action(scheduler_pk, [self.pk], action="launch_workgraph")
else:
create_scheduler_action(self.pk)
except ValueError:
print(
"""Scheduler is not running.
Expand Down

0 comments on commit d267263

Please sign in to comment.