Skip to content

async model stream interface #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
01f7d1b
async models
pgrayy Jun 27, 2025
0fd2671
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jun 30, 2025
ab2c088
lint
pgrayy Jun 30, 2025
3fd243e
tests
pgrayy Jun 30, 2025
d9deb93
agent - asyncio.run stream_async in call
pgrayy Jun 30, 2025
b864b65
tests - agenerator helper
pgrayy Jun 30, 2025
bd2adff
tests - agent - stream async result
pgrayy Jun 30, 2025
7ea90e9
lint
pgrayy Jun 30, 2025
3462e1d
agent - stream async - result
pgrayy Jun 30, 2025
8063c08
typing
pgrayy Jun 30, 2025
f8e58a1
tests - anext
pgrayy Jun 30, 2025
5bb0620
tests - alist
pgrayy Jun 30, 2025
760fcfb
lint
pgrayy Jun 30, 2025
61bb44d
tests - async utilities - scope session
pgrayy Jun 30, 2025
e47567f
tests integ - conftest
pgrayy Jun 30, 2025
90aaa47
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 1, 2025
ccc44d9
lint
pgrayy Jul 1, 2025
e8c7bda
tests - async mock model provider
pgrayy Jul 1, 2025
17b24b3
lint
pgrayy Jul 1, 2025
6f46740
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 2, 2025
1c667d9
async invoke and structured output
pgrayy Jul 2, 2025
21f80cc
thread asyncio run
pgrayy Jul 2, 2025
885f98d
test async threading
pgrayy Jul 2, 2025
773bef1
lint
pgrayy Jul 2, 2025
eb1ffd5
move invoke_async up for clarity
pgrayy Jul 2, 2025
2ce4581
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 3, 2025
3a58ef4
lint
pgrayy Jul 3, 2025
780b13e
tests
pgrayy Jul 3, 2025
5761545
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 76 additions & 31 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""

import asyncio
import json
import logging
import os
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast

from opentelemetry import trace
from pydantic import BaseModel
Expand Down Expand Up @@ -378,33 +379,43 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
callback_handler = kwargs.get("callback_handler", self.callback_handler)

self._start_agent_trace_span(prompt)
def execute() -> AgentResult:
return asyncio.run(self.invoke_async(prompt, **kwargs))

try:
events = self._run_loop(prompt, kwargs)
for event in events:
if "callback" in event:
callback_handler(**event["callback"])
with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

stop_reason, message, metrics, state = event["stop"]
result = AgentResult(stop_reason, message, metrics, state)
async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

self._end_agent_trace_span(response=result)
This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
the conversation history, processes it through the model, executes any tool calls, and returns the final result.

return result
Args:
prompt: The natural language prompt from the user.
**kwargs: Additional parameters to pass through the event loop.

except Exception as e:
self._end_agent_trace_span(error=e)
raise
Returns:
Result object containing:

- stop_reason: Why the event loop stopped (e.g., "end_turn", "max_tokens")
- message: The final message from the model
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
events = self.stream_async(prompt, **kwargs)
async for event in events:
_ = event

return cast(AgentResult, event["result"])

def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
If you don't pass in a prompt, it will use only the conversation history to respond.
If no conversation history exists and no prompt is provided, an error will be raised.

For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
instruct the model to output the structured data.
Expand All @@ -413,25 +424,52 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
prompt: The prompt to use for the agent.

Raises:
ValueError: If no conversation history or prompt is provided.
"""

def execute() -> T:
return asyncio.run(self.structured_output_async(output_model, prompt))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

async def structured_output_async(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
If you don't pass in a prompt, it will use only the conversation history to respond.

For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
instruct the model to output the structured data.

Args:
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
prompt: The prompt to use for the agent.

Raises:
ValueError: If no conversation history or prompt is provided.
"""
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))

try:
messages = self.messages
if not messages and not prompt:
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")

# add the prompt as the last message
if prompt:
messages.append({"role": "user", "content": [{"text": prompt}]})
self.messages.append({"role": "user", "content": [{"text": prompt}]})

# get the structured output from the model
events = self.model.structured_output(output_model, messages)
for event in events:
events = self.model.structured_output(output_model, self.messages)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))

return event["output"]

finally:
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))

Expand Down Expand Up @@ -471,21 +509,22 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:

try:
events = self._run_loop(prompt, kwargs)
for event in events:
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
yield event["callback"]

stop_reason, message, metrics, state = event["stop"]
result = AgentResult(stop_reason, message, metrics, state)
result = AgentResult(*event["stop"])
callback_handler(result=result)
yield {"result": result}

self._end_agent_trace_span(response=result)

except Exception as e:
self._end_agent_trace_span(error=e)
raise

def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given prompt and parameters."""
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))

Expand All @@ -499,13 +538,15 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str,
self.messages.append(new_message)

# Execute the event loop cycle with retry logic for context limits
yield from self._execute_event_loop_cycle(kwargs)
events = self._execute_event_loop_cycle(kwargs)
async for event in events:
yield event

finally:
self.conversation_manager.apply_management(self)
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))

def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
async def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
Expand All @@ -520,7 +561,7 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st

try:
# Execute the main event loop cycle
yield from event_loop_cycle(
events = event_loop_cycle(
model=self.model,
system_prompt=self.system_prompt,
messages=self.messages, # will be modified by event_loop_cycle
Expand All @@ -531,11 +572,15 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st
event_loop_parent_span=self.trace_span,
kwargs=kwargs,
)
async for event in events:
yield event

except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
self.conversation_manager.reduce_context(self, e=e)
yield from self._execute_event_loop_cycle(kwargs)
events = self._execute_event_loop_cycle(kwargs)
async for event in events:
yield event

def _record_tool_execution(
self,
Expand All @@ -560,7 +605,7 @@ def _record_tool_execution(
messages: The message history to append to.
"""
# Create user message describing the tool call
user_msg_content: List[ContentBlock] = [
user_msg_content: list[ContentBlock] = [
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
]

Expand Down
34 changes: 22 additions & 12 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Generator, Optional
from typing import Any, AsyncGenerator, Optional

from opentelemetry import trace

Expand All @@ -35,7 +35,7 @@
MAX_DELAY = 240 # 4 minutes


def event_loop_cycle(
async def event_loop_cycle(
model: Model,
system_prompt: Optional[str],
messages: Messages,
Expand All @@ -45,7 +45,7 @@ def event_loop_cycle(
event_loop_metrics: EventLoopMetrics,
event_loop_parent_span: Optional[trace.Span],
kwargs: dict[str, Any],
) -> Generator[dict[str, Any], None, None]:
) -> AsyncGenerator[dict[str, Any], None]:
"""Execute a single cycle of the event loop.

This core function processes a single conversation turn, handling model inference, tool execution, and error
Expand Down Expand Up @@ -132,7 +132,7 @@ def event_loop_cycle(
try:
# TODO: To maintain backward compatibility, we need to combine the stream event with kwargs before yielding
# to the callback handler. This will be revisited when migrating to strongly typed events.
for event in stream_messages(model, system_prompt, messages, tool_config):
async for event in stream_messages(model, system_prompt, messages, tool_config):
if "callback" in event:
yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}}

Expand Down Expand Up @@ -202,7 +202,7 @@ def event_loop_cycle(
)

# Handle tool execution
yield from _handle_tool_execution(
events = _handle_tool_execution(
stop_reason,
message,
model,
Expand All @@ -218,6 +218,9 @@ def event_loop_cycle(
cycle_start_time,
kwargs,
)
async for event in events:
yield event

return

# End the cycle and return results
Expand Down Expand Up @@ -250,7 +253,7 @@ def event_loop_cycle(
yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])}


def recurse_event_loop(
async def recurse_event_loop(
model: Model,
system_prompt: Optional[str],
messages: Messages,
Expand All @@ -260,7 +263,7 @@ def recurse_event_loop(
event_loop_metrics: EventLoopMetrics,
event_loop_parent_span: Optional[trace.Span],
kwargs: dict[str, Any],
) -> Generator[dict[str, Any], None, None]:
) -> AsyncGenerator[dict[str, Any], None]:
"""Make a recursive call to event_loop_cycle with the current state.

This function is used when the event loop needs to continue processing after tool execution.
Expand Down Expand Up @@ -292,7 +295,8 @@ def recurse_event_loop(
cycle_trace.add_child(recursive_trace)

yield {"callback": {"start": True}}
yield from event_loop_cycle(

events = event_loop_cycle(
model=model,
system_prompt=system_prompt,
messages=messages,
Expand All @@ -303,11 +307,13 @@ def recurse_event_loop(
event_loop_parent_span=event_loop_parent_span,
kwargs=kwargs,
)
async for event in events:
yield event

recursive_trace.end()


def _handle_tool_execution(
async def _handle_tool_execution(
stop_reason: StopReason,
message: Message,
model: Model,
Expand All @@ -322,7 +328,7 @@ def _handle_tool_execution(
cycle_span: Any,
cycle_start_time: float,
kwargs: dict[str, Any],
) -> Generator[dict[str, Any], None, None]:
) -> AsyncGenerator[dict[str, Any], None]:
tool_uses: list[ToolUse] = []
tool_results: list[ToolResult] = []
invalid_tool_use_ids: list[str] = []
Expand Down Expand Up @@ -369,7 +375,7 @@ def _handle_tool_execution(
kwargs=kwargs,
)

yield from run_tools(
tool_events = run_tools(
handler=tool_handler_process,
tool_uses=tool_uses,
event_loop_metrics=event_loop_metrics,
Expand All @@ -379,6 +385,8 @@ def _handle_tool_execution(
parent_span=cycle_span,
thread_pool=thread_pool,
)
for tool_event in tool_events:
yield tool_event

# Store parent cycle ID for the next cycle
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]
Expand All @@ -400,7 +408,7 @@ def _handle_tool_execution(
yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])}
return

yield from recurse_event_loop(
events = recurse_event_loop(
model=model,
system_prompt=system_prompt,
messages=messages,
Expand All @@ -411,3 +419,5 @@ def _handle_tool_execution(
event_loop_parent_span=event_loop_parent_span,
kwargs=kwargs,
)
async for event in events:
yield event
Loading