Commit 9ba977b

Cluster manager: use update to wait for cluster to start (#153)
* Use workflow.init, refactor
* Start cluster automatically; use update to wait until started
* Don't require node IDs to parse as ints
1 parent c03ad45 commit 9ba977b

6 files changed (+83, -43 lines)

message_passing/safe_message_handlers/README.md

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 This sample shows off important techniques for handling signals and updates, aka messages. In particular, it illustrates how message handlers can interleave or not be completed before the workflow completes, and how you can manage that.
 
 * Here, using workflow.wait_condition, signal and update handlers will only operate when the workflow is within a certain state--between cluster_started and cluster_shutdown.
-* You can run start_workflow with an initializer signal that you want to run before anything else other than the workflow's constructor. This pattern is known as "signal-with-start."
 * Message handlers can block and their actions can be interleaved with one another and with the main workflow. This can easily cause bugs, so you can use a lock to protect shared state from interleaved access.
 * An "Entity" workflow, i.e. a long-lived workflow, periodically "continues as new". It must do this to prevent its history from growing too large, and it passes its state to the next workflow. You can check `workflow.info().is_continue_as_new_suggested()` to see when it's time.
 * Most people want their message handlers to finish before the workflow run completes or continues as new. Use `await workflow.wait_condition(lambda: workflow.all_handlers_finished())` to achieve this.
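
The last two README bullets describe the entity-workflow pattern this sample relies on. A minimal, hedged sketch of that pattern (the `EntityWorkflow` name, `record_event` signal, and `events_seen` field are invented here for illustration and are not part of the sample):

```python
from temporalio import workflow


@workflow.defn
class EntityWorkflow:
    def __init__(self) -> None:
        self.events_seen = 0

    @workflow.signal
    async def record_event(self) -> None:
        # Handlers like this may still be in flight when history grows large.
        self.events_seen += 1

    @workflow.run
    async def run(self, events_seen: int = 0) -> None:
        self.events_seen = events_seen
        # Wait until the server suggests rolling the history over.
        await workflow.wait_condition(
            lambda: workflow.info().is_continue_as_new_suggested()
        )
        # Let in-flight signal/update handlers finish before continuing as new.
        await workflow.wait_condition(lambda: workflow.all_handlers_finished())
        workflow.continue_as_new(self.events_seen)
```

Draining handlers with `all_handlers_finished()` before `continue_as_new` keeps a signal or update from being dropped mid-flight when the run rolls over, which is exactly what the workflow in this commit does before it continues as new.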

message_passing/safe_message_handlers/activities.py

Lines changed: 11 additions & 1 deletion
@@ -11,6 +11,16 @@ class AssignNodesToJobInput:
     job_name: str
 
 
+@dataclass
+class ClusterState:
+    node_ids: List[str]
+
+
+@activity.defn
+async def start_cluster() -> ClusterState:
+    return ClusterState(node_ids=[f"node-{i}" for i in range(25)])
+
+
 @activity.defn
 async def assign_nodes_to_job(input: AssignNodesToJobInput) -> None:
     print(f"Assigning nodes {input.nodes} to job {input.job_name}")
@@ -37,7 +47,7 @@ class FindBadNodesInput:
 @activity.defn
 async def find_bad_nodes(input: FindBadNodesInput) -> Set[str]:
     await asyncio.sleep(0.1)
-    bad_nodes = set([n for n in input.nodes_to_check if int(n) % 5 == 0])
+    bad_nodes = set([id for id in input.nodes_to_check if hash(id) % 5 == 0])
     if bad_nodes:
         print(f"Found bad nodes: {bad_nodes}")
     else:
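
Since `find_bad_nodes` now works on arbitrary string IDs, the new `start_cluster` activity can return names like `"node-0"` directly. A hedged sketch of exercising it in isolation with the SDK's `ActivityEnvironment` test helper (this script is illustrative and not part of the commit):

```python
import asyncio

from temporalio.testing import ActivityEnvironment

from message_passing.safe_message_handlers.activities import start_cluster


async def main() -> None:
    env = ActivityEnvironment()
    state = await env.run(start_cluster)
    # Node IDs are plain strings such as "node-0"; they no longer need to parse as ints.
    print(state.node_ids[:3])


if __name__ == "__main__":
    asyncio.run(main())
```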

message_passing/safe_message_handlers/starter.py

Lines changed: 4 additions & 2 deletions
@@ -16,8 +16,10 @@
 
 
 async def do_cluster_lifecycle(wf: WorkflowHandle, delay_seconds: Optional[int] = None):
-
-    await wf.signal(ClusterManagerWorkflow.start_cluster)
+    cluster_status = await wf.execute_update(
+        ClusterManagerWorkflow.wait_until_cluster_started
+    )
+    print(f"Cluster started with {len(cluster_status.nodes)} nodes")
 
     print("Assigning jobs to nodes...")
     allocation_updates = []
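
Client code elsewhere can follow the same shape: start the workflow, then block on the update until the workflow's own `start_cluster` activity has finished. A rough sketch of that flow, reusing the sample's task queue name (`start_and_wait` is a made-up helper, not part of the commit):

```python
import uuid

from temporalio.client import Client

from message_passing.safe_message_handlers.workflow import (
    ClusterManagerInput,
    ClusterManagerWorkflow,
)


async def start_and_wait(client: Client):
    handle = await client.start_workflow(
        ClusterManagerWorkflow.run,
        ClusterManagerInput(),
        id=f"ClusterManagerWorkflow-{uuid.uuid4()}",
        task_queue="safe-message-handlers-task-queue",
    )
    # The update only returns once the workflow has marked the cluster as started.
    state = await handle.execute_update(
        ClusterManagerWorkflow.wait_until_cluster_started
    )
    print(f"Cluster started with {len(state.nodes)} nodes")
    return handle
```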

message_passing/safe_message_handlers/worker.py

Lines changed: 7 additions & 1 deletion
@@ -8,6 +8,7 @@
     ClusterManagerWorkflow,
     assign_nodes_to_job,
     find_bad_nodes,
+    start_cluster,
     unassign_nodes_for_job,
 )
 
@@ -21,7 +22,12 @@ async def main():
         client,
         task_queue="safe-message-handlers-task-queue",
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=[
+            assign_nodes_to_job,
+            unassign_nodes_for_job,
+            find_bad_nodes,
+            start_cluster,
+        ],
     ):
         logging.info("ClusterManagerWorkflow worker started, ctrl+c to exit")
         await interrupt_event.wait()

message_passing/safe_message_handlers/workflow.py

Lines changed: 40 additions & 32 deletions
@@ -14,6 +14,7 @@
     UnassignNodesForJobInput,
     assign_nodes_to_job,
     find_bad_nodes,
+    start_cluster,
     unassign_nodes_for_job,
 )
 
@@ -65,18 +66,27 @@ class ClusterManagerAssignNodesToJobResult:
 # These updates must run atomically.
 @workflow.defn
 class ClusterManagerWorkflow:
-    def __init__(self) -> None:
-        self.state = ClusterManagerState()
+    @workflow.init
+    def __init__(self, input: ClusterManagerInput) -> None:
+        if input.state:
+            self.state = input.state
+        else:
+            self.state = ClusterManagerState()
+
+        if input.test_continue_as_new:
+            self.max_history_length: Optional[int] = 120
+            self.sleep_interval_seconds = 1
+        else:
+            self.max_history_length = None
+            self.sleep_interval_seconds = 600
+
         # Protects workflow state from interleaved access
         self.nodes_lock = asyncio.Lock()
-        self.max_history_length: Optional[int] = None
-        self.sleep_interval_seconds: int = 600
 
-    @workflow.signal
-    async def start_cluster(self) -> None:
-        self.state.cluster_started = True
-        self.state.nodes = {str(k): None for k in range(25)}
-        workflow.logger.info("Cluster started")
+    @workflow.update
+    async def wait_until_cluster_started(self) -> ClusterManagerState:
+        await workflow.wait_condition(lambda: self.state.cluster_started)
+        return self.state
 
     @workflow.signal
     async def shutdown_cluster(self) -> None:
@@ -135,7 +145,7 @@ async def _assign_nodes_to_job(
         self.state.jobs_assigned.add(job_name)
 
     # Even though it returns nothing, this is an update because the client may want to track it, for example
-    # to wait for nodes to be unassignd before reassigning them.
+    # to wait for nodes to be unassigned before reassigning them.
     @workflow.update
     async def delete_job(self, input: ClusterManagerDeleteJobInput) -> None:
         await workflow.wait_condition(lambda: self.state.cluster_started)
@@ -202,30 +212,15 @@ async def perform_health_checks(self) -> None:
                 f"Health check failed with error {type(e).__name__}:{e}"
             )
 
-    # The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and
-    # continue-as-new.
-    def init(self, input: ClusterManagerInput) -> None:
-        if input.state:
-            self.state = input.state
-        if input.test_continue_as_new:
-            self.max_history_length = 120
-            self.sleep_interval_seconds = 1
-
-    def should_continue_as_new(self) -> bool:
-        if workflow.info().is_continue_as_new_suggested():
-            return True
-        # This is just for ease-of-testing. In production, we trust temporal to tell us when to continue as new.
-        if (
-            self.max_history_length
-            and workflow.info().get_current_history_length() > self.max_history_length
-        ):
-            return True
-        return False
-
     @workflow.run
     async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
-        self.init(input)
-        await workflow.wait_condition(lambda: self.state.cluster_started)
+        cluster_state = await workflow.execute_activity(
+            start_cluster, schedule_to_close_timeout=timedelta(seconds=10)
+        )
+        self.state.nodes = {k: None for k in cluster_state.node_ids}
+        self.state.cluster_started = True
+        workflow.logger.info("Cluster started")
+
         # Perform health checks at intervals.
         while True:
             await self.perform_health_checks()
@@ -239,6 +234,8 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
                 pass
             if self.state.cluster_shutdown:
                 break
+            # The cluster manager is a long-running "entity" workflow so we need to periodically checkpoint its state and
+            # continue-as-new.
            if self.should_continue_as_new():
                 # We don't want to leave any job assignment or deletion handlers half-finished when we continue as new.
                 await workflow.wait_condition(lambda: workflow.all_handlers_finished())
@@ -255,3 +252,14 @@ async def run(self, input: ClusterManagerInput) -> ClusterManagerResult:
             len(self.get_assigned_nodes()),
             len(self.get_bad_nodes()),
         )
+
+    def should_continue_as_new(self) -> bool:
+        if workflow.info().is_continue_as_new_suggested():
+            return True
+        # This is just for ease-of-testing. In production, we trust temporal to tell us when to continue as new.
+        if (
+            self.max_history_length
+            and workflow.info().get_current_history_length() > self.max_history_length
+        ):
+            return True
+        return False
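
Stripped of the cluster-specific logic, the pattern the workflow now follows is: use `@workflow.init` so `__init__` receives the same input as `run`, kick off the startup activity unconditionally at the top of `run`, and expose an update that simply waits for the started flag. A minimal, self-contained sketch with invented names (`StartsAutomatically`, `provision`, and `StartInput` are illustrations, not part of the sample):

```python
from dataclasses import dataclass
from datetime import timedelta

from temporalio import activity, workflow


@dataclass
class StartInput:
    greeting: str = "hello"


@activity.defn
async def provision() -> str:
    return "ready"


@workflow.defn
class StartsAutomatically:
    @workflow.init
    def __init__(self, input: StartInput) -> None:
        self.status: str = "starting"

    @workflow.update
    async def wait_until_started(self) -> str:
        # Returns only after run() has flipped the status below.
        await workflow.wait_condition(lambda: self.status != "starting")
        return self.status

    @workflow.run
    async def run(self, input: StartInput) -> str:
        # Start automatically instead of waiting for a client signal.
        self.status = await workflow.execute_activity(
            provision, schedule_to_close_timeout=timedelta(seconds=10)
        )
        return f"{input.greeting}: {self.status}"
```

Because the update handler only returns once `run` has changed the state, callers get the same "wait until started" behavior the old `start_cluster` signal provided, but with a result they can act on.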

tests/message_passing/safe_message_handlers/workflow_test.py

Lines changed: 21 additions & 6 deletions
@@ -1,5 +1,6 @@
 import asyncio
 import uuid
+from typing import Callable, Sequence
 
 import pytest
 from temporalio.client import Client, WorkflowUpdateFailedError
@@ -10,6 +11,7 @@
 from message_passing.safe_message_handlers.activities import (
     assign_nodes_to_job,
     find_bad_nodes,
+    start_cluster,
     unassign_nodes_for_job,
 )
 from message_passing.safe_message_handlers.workflow import (
@@ -19,6 +21,13 @@
     ClusterManagerWorkflow,
 )
 
+ACTIVITIES: Sequence[Callable] = [
+    assign_nodes_to_job,
+    unassign_nodes_for_job,
+    find_bad_nodes,
+    start_cluster,
+]
+
 
 async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment):
     if env.supports_time_skipping:
@@ -30,15 +39,17 @@ async def test_safe_message_handlers(client: Client, env: WorkflowEnvironment):
         client,
         task_queue=task_queue,
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=ACTIVITIES,
     ):
         cluster_manager_handle = await client.start_workflow(
             ClusterManagerWorkflow.run,
             ClusterManagerInput(),
             id=f"ClusterManagerWorkflow-{uuid.uuid4()}",
             task_queue=task_queue,
         )
-        await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)
+        await cluster_manager_handle.execute_update(
+            ClusterManagerWorkflow.wait_until_cluster_started
+        )
 
         allocation_updates = []
         for i in range(6):
@@ -82,7 +93,7 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
         client,
         task_queue=task_queue,
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=ACTIVITIES,
     ):
         cluster_manager_handle = await client.start_workflow(
             ClusterManagerWorkflow.run,
@@ -91,7 +102,9 @@ async def test_update_idempotency(client: Client, env: WorkflowEnvironment):
             task_queue=task_queue,
         )
 
-        await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)
+        await cluster_manager_handle.execute_update(
+            ClusterManagerWorkflow.wait_until_cluster_started
+        )
 
         result_1 = await cluster_manager_handle.execute_update(
             ClusterManagerWorkflow.assign_nodes_to_job,
@@ -121,7 +134,7 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
         client,
         task_queue=task_queue,
         workflows=[ClusterManagerWorkflow],
-        activities=[assign_nodes_to_job, unassign_nodes_for_job, find_bad_nodes],
+        activities=ACTIVITIES,
     ):
         cluster_manager_handle = await client.start_workflow(
             ClusterManagerWorkflow.run,
@@ -130,7 +143,9 @@ async def test_update_failure(client: Client, env: WorkflowEnvironment):
             task_queue=task_queue,
        )
 
-        await cluster_manager_handle.signal(ClusterManagerWorkflow.start_cluster)
+        await cluster_manager_handle.execute_update(
+            ClusterManagerWorkflow.wait_until_cluster_started
+        )
 
         await cluster_manager_handle.execute_update(
             ClusterManagerWorkflow.assign_nodes_to_job,

0 commit comments