diff --git a/tests/updates_and_signals/atomic_message_handlers_test.py b/tests/updates_and_signals/atomic_message_handlers_test.py
index 6268c5ad..d13b0c7e 100644
--- a/tests/updates_and_signals/atomic_message_handlers_test.py
+++ b/tests/updates_and_signals/atomic_message_handlers_test.py
@@ -29,13 +29,14 @@ async def test_atomic_message_handlers(client: Client):
             ClusterManagerWorkflow.run,
             ClusterManagerInput(),
             id=f"ClusterManagerWorkflow-{uuid.uuid4()}",
-            task_queue=task_queue
+            task_queue=task_queue,
         )
         await do_cluster_lifecycle(cluster_manager_handle, delay_seconds=1)
         result = await cluster_manager_handle.result()
         assert result.max_assigned_nodes == 12
         assert result.num_currently_assigned_nodes == 0

+
 async def test_update_failure(client: Client):
     task_queue = f"tq-{uuid.uuid4()}"
     async with Worker(
@@ -54,14 +55,16 @@ async def test_update_failure(client: Client):
         await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)

         await cluster_manager_handle.execute_update(
-            ClusterManagerWorkflow.allocate_n_nodes_to_job,
-            ClusterManagerAllocateNNodesToJobInput(num_nodes=24, task_name=f"big-task")
+            ClusterManagerWorkflow.allocate_n_nodes_to_job,
+            ClusterManagerAllocateNNodesToJobInput(num_nodes=24, task_name=f"big-task"),
         )
         try:
             # Try to allocate too many nodes
             await cluster_manager_handle.execute_update(
-                ClusterManagerWorkflow.allocate_n_nodes_to_job,
-                ClusterManagerAllocateNNodesToJobInput(num_nodes=3, task_name=f"little-task")
+                ClusterManagerWorkflow.allocate_n_nodes_to_job,
+                ClusterManagerAllocateNNodesToJobInput(
+                    num_nodes=3, task_name=f"little-task"
+                ),
             )
         except WorkflowUpdateFailedError as e:
             assert e.cause.message == "Cannot allocate 3 nodes; have only 1 available"
@@ -69,4 +72,3 @@ async def test_update_failure(client: Client):
         await cluster_manager_handle.signal(ClusterManagerWorkflow.shutdown_cluster)
         result = await cluster_manager_handle.result()
         assert result.num_currently_assigned_nodes == 24
-
diff --git a/updates_and_signals/atomic_message_handlers/starter.py b/updates_and_signals/atomic_message_handlers/starter.py
index c88c8d96..6097012f 100644
--- a/updates_and_signals/atomic_message_handlers/starter.py
+++ b/updates_and_signals/atomic_message_handlers/starter.py
@@ -16,15 +16,17 @@


 async def do_cluster_lifecycle(wf: WorkflowHandle, delay_seconds: Optional[int] = None):
-
+
     await wf.signal(ClusterManagerWorkflow.start_cluster)

     allocation_updates = []
     for i in range(6):
         allocation_updates.append(
             wf.execute_update(
-                ClusterManagerWorkflow.allocate_n_nodes_to_job,
-                ClusterManagerAllocateNNodesToJobInput(num_nodes=2, task_name=f"task-{i}")
+                ClusterManagerWorkflow.allocate_n_nodes_to_job,
+                ClusterManagerAllocateNNodesToJobInput(
+                    num_nodes=2, task_name=f"task-{i}"
+                ),
             )
         )
     await asyncio.gather(*allocation_updates)
@@ -36,8 +38,8 @@ async def do_cluster_lifecycle(wf: WorkflowHandle, delay_seconds: Optional[int]
     for i in range(6):
         deletion_updates.append(
             wf.execute_update(
-                ClusterManagerWorkflow.delete_job,
-                ClusterManagerDeleteJobInput(task_name=f"task-{i}")
+                ClusterManagerWorkflow.delete_job,
+                ClusterManagerDeleteJobInput(task_name=f"task-{i}"),
             )
         )
     await asyncio.gather(*deletion_updates)
diff --git a/updates_and_signals/atomic_message_handlers/workflow.py b/updates_and_signals/atomic_message_handlers/workflow.py
index b8c736d4..54829abf 100644
--- a/updates_and_signals/atomic_message_handlers/workflow.py
+++ b/updates_and_signals/atomic_message_handlers/workflow.py
@@ -42,15 +42,18 @@ class ClusterManagerResult:
     max_assigned_nodes: int
     num_currently_assigned_nodes: int

+
 @dataclass(kw_only=True)
 class ClusterManagerAllocateNNodesToJobInput:
     num_nodes: int
     task_name: str

+
 @dataclass(kw_only=True)
 class ClusterManagerDeleteJobInput:
     task_name: str

+
 # ClusterManagerWorkflow keeps track of the allocations of a cluster of nodes.
 # Via signals, the cluster can be started and shutdown.
 # Via updates, clients can also assign jobs to nodes and delete jobs.
@@ -76,8 +79,7 @@ async def shutdown_cluster(self):

     @workflow.update
     async def allocate_n_nodes_to_job(
-        self,
-        input: ClusterManagerAllocateNNodesToJobInput
+        self, input: ClusterManagerAllocateNNodesToJobInput
     ) -> List[str]:
         await workflow.wait_condition(lambda: self.state.cluster_started)
         if self.state.cluster_shutdown:
@@ -97,7 +99,7 @@ async def allocate_n_nodes_to_job(
                 raise ApplicationError(
                     f"Cannot allocate {input.num_nodes} nodes; have only {len(unassigned_nodes)} available"
                 )
-            assigned_nodes = unassigned_nodes[:input.num_nodes]
+            assigned_nodes = unassigned_nodes[: input.num_nodes]
             # This await would be dangerous without nodes_lock because it yields control and allows interleaving.
             await self._allocate_nodes_to_job(assigned_nodes, input.task_name)
             self.state.max_assigned_nodes = max(
@@ -122,12 +124,12 @@ async def delete_job(self, input: ClusterManagerDeleteJobInput):
             # If you want the client to receive a failure, either add an update validator and throw the
             # exception from there, or raise an ApplicationError. Other exceptions in the main handler
             # will cause the workflow to keep retrying and get it stuck.
-            raise ApplicationError(
-                "Cannot delete a job: Cluster is already shut down"
-            )
+            raise ApplicationError("Cannot delete a job: Cluster is already shut down")
         async with self.nodes_lock:
-            nodes_to_free = [k for k, v in self.state.nodes.items() if v == input.task_name]
+            nodes_to_free = [
+                k for k, v in self.state.nodes.items() if v == input.task_name
+            ]
             # This await would be dangerous without nodes_lock because it yields control and allows interleaving.
             await self._deallocate_nodes_for_job(nodes_to_free, input.task_name)
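Note on the pattern in workflow.py above: each update handler awaits while holding self.nodes_lock, so the read-modify-write of self.state.nodes cannot interleave with another handler. Below is a minimal stand-alone sketch of that idea, assuming the temporalio Python SDK; the workflow and handler names (LockedCounterWorkflow, add, finish) are illustrative only and not part of this patch.

    import asyncio

    from temporalio import workflow


    @workflow.defn
    class LockedCounterWorkflow:
        """Illustrative sketch: serialize update handlers with an asyncio.Lock,
        mirroring how ClusterManagerWorkflow guards self.state.nodes."""

        def __init__(self) -> None:
            self.count = 0
            self.done = False
            # asyncio.Lock is deterministic inside a workflow, so it is safe to hold
            # across awaits to keep a handler's read-modify-write atomic.
            self.lock = asyncio.Lock()

        @workflow.run
        async def run(self) -> int:
            await workflow.wait_condition(lambda: self.done)
            return self.count

        @workflow.update
        async def add(self, n: int) -> int:
            # Without the lock, the await below would yield control and let another
            # handler interleave between reading and writing self.count.
            async with self.lock:
                current = self.count
                await asyncio.sleep(1)  # stand-in for an activity call
                self.count = current + n
                return self.count

        @workflow.signal
        def finish(self) -> None:
            self.done = True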