Skip to content

Commit 680f17a

Browse files
authored
fix(multiagent): raise ValueError for unsupported Graph and Swarm agent features (strands-agents#472)
1 parent ea4e878 commit 680f17a

File tree

6 files changed

+163
-5
lines changed

6 files changed

+163
-5
lines changed

src/strands/hooks/registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,20 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent:
204204

205205
return event
206206

207+
def has_callbacks(self) -> bool:
208+
"""Check if the registry has any registered callbacks.
209+
210+
Returns:
211+
True if there are any registered callbacks, False otherwise.
212+
213+
Example:
214+
```python
215+
if registry.has_callbacks():
216+
print("Registry has callbacks registered")
217+
```
218+
"""
219+
return bool(self._registered_callbacks)
220+
207221
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
208222
"""Get callbacks registered for the given event in the appropriate order.
209223

src/strands/multiagent/graph.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,32 @@ def __eq__(self, other: Any) -> bool:
129129
return self.node_id == other.node_id
130130

131131

132+
def _validate_node_executor(
133+
executor: Agent | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None
134+
) -> None:
135+
"""Validate a node executor for graph compatibility.
136+
137+
Args:
138+
executor: The executor to validate
139+
existing_nodes: Optional dict of existing nodes to check for duplicates
140+
"""
141+
# Check for duplicate node instances
142+
if existing_nodes:
143+
seen_instances = {id(node.executor) for node in existing_nodes.values()}
144+
if id(executor) in seen_instances:
145+
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
146+
147+
# Validate Agent-specific constraints
148+
if isinstance(executor, Agent):
149+
# Check for session persistence
150+
if executor._session_manager is not None:
151+
raise ValueError("Session persistence is not supported for Graph agents yet.")
152+
153+
# Check for callbacks
154+
if executor.hooks.has_callbacks():
155+
raise ValueError("Agent callbacks are not supported for Graph agents yet.")
156+
157+
132158
class GraphBuilder:
133159
"""Builder pattern for constructing graphs."""
134160

@@ -140,10 +166,7 @@ def __init__(self) -> None:
140166

141167
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
142168
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
143-
# Check for duplicate node instances
144-
seen_instances = {id(node.executor) for node in self.nodes.values()}
145-
if id(executor) in seen_instances:
146-
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
169+
_validate_node_executor(executor, self.nodes)
147170

148171
# Auto-generate node_id if not provided
149172
if node_id is None:
@@ -304,6 +327,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
304327
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
305328
seen_instances.add(id(node.executor))
306329

330+
# Validate Agent-specific constraints for each node
331+
_validate_node_executor(node.executor)
332+
307333
async def _execute_graph(self) -> None:
308334
"""Unified execution flow with conditional routing."""
309335
ready_nodes = list(self.entry_points)

src/strands/multiagent/swarm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,14 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
314314
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
315315
seen_instances.add(id(node))
316316

317+
# Check for session persistence
318+
if node._session_manager is not None:
319+
raise ValueError("Session persistence is not supported for Swarm agents yet.")
320+
321+
# Check for callbacks
322+
if node.hooks.has_callbacks():
323+
raise ValueError("Agent callbacks are not supported for Swarm agents yet.")
324+
317325
def _inject_swarm_tools(self) -> None:
318326
"""Add swarm coordination tools to each agent."""
319327
# Create tool functions with proper closures

tests/strands/experimental/hooks/test_hook_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,18 @@ def callback2(_event):
150150
hook_registry.invoke_callbacks(test_after_event)
151151

152152
assert call_order == ["callback2", "callback1"] # Reverse order
153+
154+
155+
def test_has_callbacks(hook_registry, test_event):
156+
"""Test that has_callbacks returns correct boolean values."""
157+
# Empty registry should return False
158+
assert not hook_registry.has_callbacks()
159+
160+
# Registry with callbacks should return True
161+
callback = Mock()
162+
hook_registry.add_callback(TestEvent, callback)
163+
assert hook_registry.has_callbacks()
164+
165+
# Test with multiple event types
166+
hook_registry.add_callback(TestAfterEvent, Mock())
167+
assert hook_registry.has_callbacks()

tests/strands/multiagent/test_graph.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@
33
import pytest
44

55
from strands.agent import Agent, AgentResult
6+
from strands.hooks import AgentInitializedEvent
7+
from strands.hooks.registry import HookProvider, HookRegistry
68
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult
7-
from strands.multiagent.graph import GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status
9+
from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status
10+
from strands.session.session_manager import SessionManager
811

912

1013
def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None):
1114
"""Create a mock Agent with specified properties."""
1215
agent = Mock(spec=Agent)
1316
agent.name = name
1417
agent.id = agent_id or f"{name}_id"
18+
agent._session_manager = None
19+
agent.hooks = HookRegistry()
1520

1621
if metrics is None:
1722
metrics = Mock(
@@ -261,6 +266,10 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span)
261266
failing_agent.id = "fail_node"
262267
failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure"))
263268

269+
# Add required attributes for validation
270+
failing_agent._session_manager = None
271+
failing_agent.hooks = HookRegistry()
272+
264273
async def mock_invoke_failure(*args, **kwargs):
265274
raise Exception("Simulated failure")
266275

@@ -489,3 +498,51 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag
489498

490499
mock_strands_tracer.start_multiagent_span.assert_called()
491500
mock_use_span.assert_called_once()
501+
502+
503+
def test_graph_validate_unsupported_features():
504+
"""Test Graph validation for session persistence and callbacks."""
505+
# Test with normal agent (should work)
506+
normal_agent = create_mock_agent("normal_agent")
507+
normal_agent._session_manager = None
508+
normal_agent.hooks = HookRegistry()
509+
510+
builder = GraphBuilder()
511+
builder.add_node(normal_agent)
512+
graph = builder.build()
513+
assert len(graph.nodes) == 1
514+
515+
# Test with session manager (should fail in GraphBuilder.add_node)
516+
mock_session_manager = Mock(spec=SessionManager)
517+
agent_with_session = create_mock_agent("agent_with_session")
518+
agent_with_session._session_manager = mock_session_manager
519+
agent_with_session.hooks = HookRegistry()
520+
521+
builder = GraphBuilder()
522+
with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"):
523+
builder.add_node(agent_with_session)
524+
525+
# Test with callbacks (should fail in GraphBuilder.add_node)
526+
class TestHookProvider(HookProvider):
527+
def register_hooks(self, registry, **kwargs):
528+
registry.add_callback(AgentInitializedEvent, lambda e: None)
529+
530+
agent_with_hooks = create_mock_agent("agent_with_hooks")
531+
agent_with_hooks._session_manager = None
532+
agent_with_hooks.hooks = HookRegistry()
533+
agent_with_hooks.hooks.add_hook(TestHookProvider())
534+
535+
builder = GraphBuilder()
536+
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
537+
builder.add_node(agent_with_hooks)
538+
539+
# Test validation in Graph constructor (when nodes are passed directly)
540+
# Test with session manager in Graph constructor
541+
node_with_session = GraphNode("node_with_session", agent_with_session)
542+
with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"):
543+
Graph(nodes={"node_with_session": node_with_session}, edges=set(), entry_points=set())
544+
545+
# Test with callbacks in Graph constructor
546+
node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks)
547+
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
548+
Graph(nodes={"node_with_hooks": node_with_hooks}, edges=set(), entry_points=set())

tests/strands/multiagent/test_swarm.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
from strands.agent import Agent, AgentResult
88
from strands.agent.state import AgentState
9+
from strands.hooks import AgentInitializedEvent
10+
from strands.hooks.registry import HookProvider, HookRegistry
911
from strands.multiagent.base import Status
1012
from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState
13+
from strands.session.session_manager import SessionManager
1114
from strands.types.content import ContentBlock
1215

1316

@@ -27,6 +30,8 @@ def create_mock_agent(
2730
agent._complete_after = complete_after_calls
2831
agent._swarm_ref = None # Will be set by the swarm
2932
agent._should_fail = should_fail
33+
agent._session_manager = None
34+
agent.hooks = HookRegistry()
3035

3136
if metrics is None:
3237
metrics = Mock(
@@ -450,3 +455,36 @@ def test_swarm_metrics_handling():
450455

451456
result = no_metrics_swarm("Test no metrics")
452457
assert result.status == Status.COMPLETED
458+
459+
460+
def test_swarm_validate_unsupported_features():
461+
"""Test Swarm validation for session persistence and callbacks."""
462+
# Test with normal agent (should work)
463+
normal_agent = create_mock_agent("normal_agent")
464+
normal_agent._session_manager = None
465+
normal_agent.hooks = HookRegistry()
466+
467+
swarm = Swarm([normal_agent])
468+
assert len(swarm.nodes) == 1
469+
470+
# Test with session manager (should fail)
471+
mock_session_manager = Mock(spec=SessionManager)
472+
agent_with_session = create_mock_agent("agent_with_session")
473+
agent_with_session._session_manager = mock_session_manager
474+
agent_with_session.hooks = HookRegistry()
475+
476+
with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"):
477+
Swarm([agent_with_session])
478+
479+
# Test with callbacks (should fail)
480+
class TestHookProvider(HookProvider):
481+
def register_hooks(self, registry, **kwargs):
482+
registry.add_callback(AgentInitializedEvent, lambda e: None)
483+
484+
agent_with_hooks = create_mock_agent("agent_with_hooks")
485+
agent_with_hooks._session_manager = None
486+
agent_with_hooks.hooks = HookRegistry()
487+
agent_with_hooks.hooks.add_hook(TestHookProvider())
488+
489+
with pytest.raises(ValueError, match="Agent callbacks are not supported for Swarm agents yet"):
490+
Swarm([agent_with_hooks])

0 commit comments

Comments
 (0)