Remove health check and "bad" node concept
A concern was raised that our samples should not demonstrate any form of
polling from within an entity workflow loop, even if the poll frequency is
low. Instead, we should point to long-running activity patterns.
dandavison committed Jul 24, 2024
1 parent 35b476d commit c0e27a6
Showing 4 changed files with 10 additions and 67 deletions.
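As context for the change: the long-running activity pattern referenced in the commit message keeps the poll loop inside an activity, which heartbeats while it runs, so the workflow awaits a single activity call instead of sleeping and re-polling in its own event loop. Below is a minimal sketch of that pattern with the Temporal Python SDK; the names (HealthCheckInput, poll_node_health, watch_for_bad_nodes, HealthWatchWorkflow) and the timeout values are illustrative assumptions, not code from this repository.

import asyncio
from dataclasses import dataclass
from datetime import timedelta
from typing import Set

from temporalio import activity, workflow
from temporalio.common import RetryPolicy


@dataclass
class HealthCheckInput:
    nodes_to_check: Set[str]


async def poll_node_health(nodes: Set[str]) -> Set[str]:
    # Placeholder for a real probe (HTTP call, cloud API, etc.); returns the unhealthy nodes.
    return set()


@activity.defn
async def watch_for_bad_nodes(input: HealthCheckInput) -> Set[str]:
    # The poll loop lives inside the activity rather than in the workflow's event loop.
    while True:
        bad_nodes = await poll_node_health(input.nodes_to_check)
        if bad_nodes:
            return bad_nodes
        # Heartbeat so the server can detect a dead worker via heartbeat_timeout.
        activity.heartbeat()
        await asyncio.sleep(30)


@workflow.defn
class HealthWatchWorkflow:
    @workflow.run
    async def run(self, input: HealthCheckInput) -> Set[str]:
        # The workflow awaits one long-running activity instead of polling itself.
        return await workflow.execute_activity(
            watch_for_bad_nodes,
            input,
            start_to_close_timeout=timedelta(days=1),
            heartbeat_timeout=timedelta(minutes=2),
            retry_policy=RetryPolicy(maximum_interval=timedelta(minutes=1)),
        )

If the worker running the activity crashes, heartbeat_timeout bounds how long the server takes to notice, and the retry policy restarts the poll loop on another worker.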
12 changes: 5 additions & 7 deletions tests/updates_and_signals/safe_message_handlers/workflow_test.py
@@ -9,7 +9,6 @@

 from updates_and_signals.safe_message_handlers.activities import (
     assign_nodes_to_job,
-    find_bad_nodes,
     unassign_nodes_for_job,
 )
 from updates_and_signals.safe_message_handlers.workflow import (
@@ -30,7 +29,7 @@ async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment):
         client,
         task_queue=task_queue,
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=[assign_nodes_to_job, unassign_nodes_for_job],
     ):
         cluster_manager_handle = await client.start_workflow(
             ClusterManagerWorkflow.run,
@@ -82,7 +81,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
         client,
         task_queue=task_queue,
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=[assign_nodes_to_job, unassign_nodes_for_job],
     ):
         cluster_manager_handle = await client.start_workflow(
             ClusterManagerWorkflow.run,
@@ -106,8 +105,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
                 total_num_nodes=5, job_name="jobby-job"
             ),
         )
-        # the second call should not assign more nodes (it may return fewer if the health check finds bad nodes
-        # in between the two signals.)
+        # the second call should not assign more nodes
         assert result_1.nodes_assigned >= result_2.nodes_assigned


@@ -121,7 +119,7 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
         client,
         task_queue=task_queue,
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=[assign_nodes_to_job, unassign_nodes_for_job],
     ):
         cluster_manager_handle = await client.start_workflow(
             ClusterManagerWorkflow.run,
@@ -152,4 +150,4 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
         finally:
             await cluster_manager_handle.signal(ClusterManagerWorkflow.shutdown_cluster)
             result = await cluster_manager_handle.result()
-            assert result.num_currently_assigned_nodes + result.num_bad_nodes == 24
+            assert result.num_currently_assigned_nodes == 24
16 changes: 0 additions & 16 deletions updates_and_signals/safe_message_handlers/activities.py
@@ -27,19 +27,3 @@ class UnassignNodesForJobInput:
 async def unassign_nodes_for_job(input: UnassignNodesForJobInput) -> None:
     print(f"Deallocating nodes {input.nodes} from job {input.job_name}")
     await asyncio.sleep(0.1)
-
-
-@dataclass
-class FindBadNodesInput:
-    nodes_to_check: Set[str]
-
-
-@activity.defn
-async def find_bad_nodes(input: FindBadNodesInput) -> Set[str]:
-    await asyncio.sleep(0.1)
-    bad_nodes = set([n for n in input.nodes_to_check if int(n) % 5 == 0])
-    if bad_nodes:
-        print(f"Found bad nodes: {bad_nodes}")
-    else:
-        print("No new bad nodes found.")
-    return bad_nodes
3 changes: 1 addition & 2 deletions updates_and_signals/safe_message_handlers/worker.py
@@ -7,7 +7,6 @@
 from updates_and_signals.safe_message_handlers.workflow import (
     ClusterManagerWorkflow,
     assign_nodes_to_job,
-    find_bad_nodes,
     unassign_nodes_for_job,
 )

@@ -22,7 +21,7 @@ async def main():
         client,
         task_queue="safe-message-handlers-task-queue",
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=[assign_nodes_to_job, unassign_nodes_for_job],
     ):
         # Wait until interrupted
         logging.info("ClusterManagerWorkflow worker started, ctrl+c to exit")
46 changes: 4 additions & 42 deletions updates_and_signals/safe_message_handlers/workflow.py
@@ -5,15 +5,12 @@
 from typing import Dict, List, Optional, Set

 from temporalio import workflow
-from temporalio.common import RetryPolicy
 from temporalio.exceptions import ApplicationError

 from updates_and_signals.safe_message_handlers.activities import (
     AssignNodesToJobInput,
-    FindBadNodesInput,
     UnassignNodesForJobInput,
     assign_nodes_to_job,
-    find_bad_nodes,
     unassign_nodes_for_job,
 )

@@ -37,7 +34,6 @@ class ClusterManagerInput:
 @dataclass
 class ClusterManagerResult:
     num_currently_assigned_nodes: int
-    num_bad_nodes: int


 # Be in the habit of storing message inputs and outputs in serializable structures.
@@ -116,7 +112,7 @@ async def assign_nodes_to_job(
                 )
             nodes_to_assign = unassigned_nodes[: input.total_num_nodes]
             # This await would be dangerous without nodes_lock because it yields control and allows interleaving
-            # with delete_job and perform_health_checks, which both touch self.state.nodes.
+            # with delete_job, which touches self.state.nodes.
             await self._assign_nodes_to_job(nodes_to_assign, input.job_name)
             return ClusterManagerAssignNodesToJobResult(
                 nodes_assigned=self.get_assigned_nodes(job_name=input.job_name)
@@ -150,7 +146,7 @@ async def delete_job(self, input: ClusterManagerDeleteJobInput) -> None:
                 k for k, v in self.state.nodes.items() if v == input.job_name
             ]
             # This await would be dangerous without nodes_lock because it yields control and allows interleaving
-            # with assign_nodes_to_job and perform_health_checks, which all touch self.state.nodes.
+            # with assign_nodes_to_job, which touches self.state.nodes.
             await self._unassign_nodes_for_job(nodes_to_unassign, input.job_name)

     async def _unassign_nodes_for_job(
@@ -167,40 +163,11 @@ async def _unassign_nodes_for_job(
     def get_unassigned_nodes(self) -> List[str]:
         return [k for k, v in self.state.nodes.items() if v is None]

-    def get_bad_nodes(self) -> Set[str]:
-        return set([k for k, v in self.state.nodes.items() if v == "BAD!"])
-
     def get_assigned_nodes(self, *, job_name: Optional[str] = None) -> Set[str]:
         if job_name:
             return set([k for k, v in self.state.nodes.items() if v == job_name])
         else:
-            return set(
-                [
-                    k
-                    for k, v in self.state.nodes.items()
-                    if v is not None and v != "BAD!"
-                ]
-            )
-
-    async def perform_health_checks(self) -> None:
-        async with self.nodes_lock:
-            assigned_nodes = self.get_assigned_nodes()
-            try:
-                # This await would be dangerous without nodes_lock because it yields control and allows interleaving
-                # with assign_nodes_to_job and delete_job, which both touch self.state.nodes.
-                bad_nodes = await workflow.execute_activity(
-                    find_bad_nodes,
-                    FindBadNodesInput(nodes_to_check=assigned_nodes),
-                    start_to_close_timeout=timedelta(seconds=10),
-                    # This health check is optional, and our lock would block the whole workflow if we let it retry forever.
-                    retry_policy=RetryPolicy(maximum_attempts=1),
-                )
-                for node in bad_nodes:
-                    self.state.nodes[node] = "BAD!"
-            except Exception as e:
-                workflow.logger.warn(
-                    f"Health check failed with error {type(e).__name__}:{e}"
-                )
+            return set([k for k, v in self.state.nodes.items() if v is not None])

     # The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and
     # continue-as-new.
@@ -229,9 +196,7 @@ def should_continue_as_new(self) -> bool:
     async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
         self.init(input)
         await workflow.wait_condition(lambda: self.state.cluster_started)
-        # Perform health checks at intervals.
         while True:
-            await self.perform_health_checks()
             try:
                 await workflow.wait_condition(
                     lambda: self.state.cluster_shutdown
@@ -250,7 +215,4 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
                         test_continue_as_new=input.test_continue_as_new,
                     )
                 )
-        return ClusterManagerResult(
-            len(self.get_assigned_nodes()),
-            len(self.get_bad_nodes()),
-        )
+        return ClusterManagerResult(len(self.get_assigned_nodes()))
