
Commit b77e33c

Realtime: enable a playback tracker
1 parent 00412a1 commit b77e33c

7 files changed: +368 -59 lines changed

src/agents/realtime/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,8 @@
     RealtimeModel,
     RealtimeModelConfig,
     RealtimeModelListener,
+    RealtimePlaybackState,
+    RealtimePlaybackTracker,
 )
 from .model_events import (
     RealtimeConnectionStatus,
@@ -139,6 +141,8 @@
     "RealtimeModel",
     "RealtimeModelConfig",
     "RealtimeModelListener",
+    "RealtimePlaybackTracker",
+    "RealtimePlaybackState",
     # Model Events
     "RealtimeConnectionStatus",
     "RealtimeModelAudioDoneEvent",

src/agents/realtime/_default_tracker.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+from dataclasses import dataclass
+from datetime import datetime
+
+from ._util import calculate_audio_length_ms
+from .config import RealtimeAudioFormat
+
+
+@dataclass
+class ModelAudioState:
+    initial_received_time: datetime
+    audio_length_ms: float
+
+
+class ModelAudioTracker:
+    def __init__(self) -> None:
+        # (item_id, item_content_index) -> ModelAudioState
+        self._states: dict[tuple[str, int], ModelAudioState] = {}
+        self._last_audio_item: tuple[str, int] | None = None
+
+    def set_audio_format(self, format: RealtimeAudioFormat) -> None:
+        """Called when the model wants to set the audio format."""
+        self._format = format
+
+    def on_audio_delta(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
+        """Called when an audio delta is received from the model."""
+        ms = calculate_audio_length_ms(self._format, bytes)
+        new_key = (item_id, item_content_index)
+
+        self._last_audio_item = new_key
+        if new_key not in self._states:
+            self._states[new_key] = ModelAudioState(datetime.now(), ms)
+        else:
+            self._states[new_key].audio_length_ms += ms
+
+    def on_interrupted(self) -> None:
+        """Called when the audio playback has been interrupted."""
+        self._last_audio_item = None
+
+    def get_state(self, item_id: str, item_content_index: int) -> ModelAudioState | None:
+        """Called when the model wants to get the current playback state."""
+        return self._states.get((item_id, item_content_index))
+
+    def get_last_audio_item(self) -> tuple[str, int] | None:
+        """Called when the model wants to get the last audio item ID and content index."""
+        return self._last_audio_item
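
ModelAudioTracker is the default bookkeeping for model output: state is keyed by (item_id, item_content_index), the first delta for a key records its wall-clock arrival time, and later deltas only extend audio_length_ms. A minimal sketch of that behavior (assuming "g711_ulaw" is an accepted RealtimeAudioFormat value; at 8 kHz G.711, 800 one-byte samples are 100 ms):

    from agents.realtime._default_tracker import ModelAudioTracker

    tracker = ModelAudioTracker()
    tracker.set_audio_format("g711_ulaw")  # must be set before the first delta
    tracker.on_audio_delta("item_1", 0, b"\x00" * 800)  # 100 ms of audio
    tracker.on_audio_delta("item_1", 0, b"\x00" * 400)  # +50 ms on the same key

    state = tracker.get_state("item_1", 0)
    assert state is not None and state.audio_length_ms == 150.0
    assert tracker.get_last_audio_item() == ("item_1", 0)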

src/agents/realtime/_util.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .config import RealtimeAudioFormat
+
+
+def calculate_audio_length_ms(format: RealtimeAudioFormat | None, bytes: bytes) -> float:
+    if format and format.startswith("g711"):
+        return (len(bytes) / 8000) * 1000
+    return (len(bytes) / 24 / 2) * 1000
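
This helper centralizes the duration math that was previously inlined in openai_realtime.py (see the removed _calculate_audio_length_ms further down): G.711 carries 8000 one-byte samples per second, so the byte count divided by 8000 gives seconds; any other (or unset) format is treated as the default 24 kHz, 2-bytes-per-sample PCM16. A quick check of the G.711 branch, assuming "g711_ulaw" as the format string:

    from agents.realtime._util import calculate_audio_length_ms

    # 8 one-byte G.711 samples per millisecond, so 1600 bytes is 200 ms.
    assert calculate_audio_length_ms("g711_ulaw", b"\x00" * 1600) == 200.0
    # With format=None the PCM16 branch is used instead.
    assert calculate_audio_length_ms(None, b"") == 0.0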

src/agents/realtime/model.py

Lines changed: 94 additions & 0 deletions
@@ -6,13 +6,95 @@
 from typing_extensions import NotRequired, TypedDict
 
 from ..util._types import MaybeAwaitable
+from ._util import calculate_audio_length_ms
 from .config import (
+    RealtimeAudioFormat,
     RealtimeSessionModelSettings,
 )
 from .model_events import RealtimeModelEvent
 from .model_inputs import RealtimeModelSendEvent
 
 
+class RealtimePlaybackState(TypedDict):
+    current_item_id: str | None
+    """The item ID of the current item being played."""
+
+    current_item_content_index: int | None
+    """The index of the current item content being played."""
+
+    elapsed_ms: float | None
+    """The number of milliseconds of audio that have been played."""
+
+
+class RealtimePlaybackTracker:
+    """If you have custom playback logic or expect that audio is played with delays or at different
+    speeds, create an instance of RealtimePlaybackTracker and pass it to the session. You are
+    responsible for tracking the audio playback progress and calling `on_play_bytes` or
+    `on_play_ms` when the user has played some audio."""
+
+    def __init__(self) -> None:
+        self._format: RealtimeAudioFormat | None = None
+        # (item_id, item_content_index)
+        self._current_item: tuple[str, int] | None = None
+        self._elapsed_ms: float | None = None
+
+    def on_play_bytes(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
+        """Called by you when you have played some audio.
+
+        Args:
+            item_id: The item ID of the audio being played.
+            item_content_index: The index of the audio content in `item.content`
+            bytes: The audio bytes that have been fully played.
+        """
+        ms = calculate_audio_length_ms(self._format, bytes)
+        self.on_play_ms(item_id, item_content_index, ms)
+
+    def on_play_ms(self, item_id: str, item_content_index: int, ms: float) -> None:
+        """Called by you when you have played some audio.
+
+        Args:
+            item_id: The item ID of the audio being played.
+            item_content_index: The index of the audio content in `item.content`
+            ms: The number of milliseconds of audio that have been played.
+        """
+        if self._current_item != (item_id, item_content_index):
+            self._current_item = (item_id, item_content_index)
+            self._elapsed_ms = ms
+        else:
+            assert self._elapsed_ms is not None
+            self._elapsed_ms += ms
+
+    def on_interrupted(self) -> None:
+        """Called by the model when the audio playback has been interrupted."""
+        self._current_item = None
+        self._elapsed_ms = None
+
+    def set_audio_format(self, format: RealtimeAudioFormat) -> None:
+        """Will be called by the model to set the audio format.
+
+        Args:
+            format: The audio format to use.
+        """
+        self._format = format
+
+    def get_state(self) -> RealtimePlaybackState:
+        """Will be called by the model to get the current playback state."""
+        if self._current_item is None:
+            return {
+                "current_item_id": None,
+                "current_item_content_index": None,
+                "elapsed_ms": None,
+            }
+        assert self._elapsed_ms is not None
+
+        item_id, item_content_index = self._current_item
+        return {
+            "current_item_id": item_id,
+            "current_item_content_index": item_content_index,
+            "elapsed_ms": self._elapsed_ms,
+        }
+
+
 class RealtimeModelListener(abc.ABC):
     """A listener for realtime transport events."""
 
@@ -39,6 +121,18 @@ class RealtimeModelConfig(TypedDict):
     initial_model_settings: NotRequired[RealtimeSessionModelSettings]
     """The initial model settings to use when connecting."""
 
+    playback_tracker: NotRequired[RealtimePlaybackTracker]
+    """The playback tracker to use when tracking audio playback progress. If not set, the model will
+    use a default implementation that assumes audio is played immediately, at realtime speed.
+
+    A playback tracker is useful for interruptions. The model generates audio much faster than
+    realtime playback speed. So if there's an interruption, it's useful for the model to know how
+    much of the audio has been played by the user. In low-latency scenarios, it's fine to assume
+    that audio is played back immediately at realtime speed. But in scenarios like phone calls or
+    other remote interactions, you can set a playback tracker that lets the model know when audio
+    is played to the user.
+    """
+
 
 class RealtimeModel(abc.ABC):
     """Interface for connecting to a realtime model and sending/receiving events."""

src/agents/realtime/openai_realtime.py

Lines changed: 55 additions & 27 deletions
@@ -57,6 +57,7 @@
 from websockets.asyncio.client import ClientConnection
 
 from agents.handoffs import Handoff
+from agents.realtime._default_tracker import ModelAudioTracker
 from agents.tool import FunctionTool, Tool
 from agents.util._types import MaybeAwaitable
 
@@ -72,6 +73,8 @@
     RealtimeModel,
    RealtimeModelConfig,
     RealtimeModelListener,
+    RealtimePlaybackState,
+    RealtimePlaybackTracker,
 )
 from .model_events import (
     RealtimeModelAudioDoneEvent,
@@ -133,11 +136,10 @@ def __init__(self) -> None:
         self._websocket_task: asyncio.Task[None] | None = None
         self._listeners: list[RealtimeModelListener] = []
         self._current_item_id: str | None = None
-        self._audio_start_time: datetime | None = None
-        self._audio_length_ms: float = 0.0
+        self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
         self._ongoing_response: bool = False
-        self._current_audio_content_index: int | None = None
         self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
+        self._playback_tracker: RealtimePlaybackTracker | None = None
 
     async def connect(self, options: RealtimeModelConfig) -> None:
         """Establish a connection to the model and keep it alive."""
@@ -146,6 +148,8 @@ async def connect(self, options: RealtimeModelConfig) -> None:
 
         model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {})
 
+        self._playback_tracker = options.get("playback_tracker", RealtimePlaybackTracker())
+
         self.model = model_settings.get("model_name", self.model)
         api_key = await get_api_key(options.get("api_key"))
 
@@ -294,47 +298,75 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
         if event.start_response:
             await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
 
+    def _get_playback_state(self) -> RealtimePlaybackState:
+        if self._playback_tracker:
+            return self._playback_tracker.get_state()
+
+        if last_audio_item_id := self._audio_state_tracker.get_last_audio_item():
+            item_id, item_content_index = last_audio_item_id
+            audio_state = self._audio_state_tracker.get_state(item_id, item_content_index)
+            if audio_state:
+                elapsed_ms = (
+                    datetime.now() - audio_state.initial_received_time
+                ).total_seconds() * 1000
+                return {
+                    "current_item_id": item_id,
+                    "current_item_content_index": item_content_index,
+                    "elapsed_ms": elapsed_ms,
+                }
+
+        return {
+            "current_item_id": None,
+            "current_item_content_index": None,
+            "elapsed_ms": None,
+        }
+
     async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
-        if not self._current_item_id or not self._audio_start_time:
+        playback_state = self._get_playback_state()
+        current_item_id = playback_state.get("current_item_id")
+        current_item_content_index = playback_state.get("current_item_content_index")
+        elapsed_ms = playback_state.get("elapsed_ms")
+        if current_item_id is None or elapsed_ms is None:
+            logger.info(
+                "Skipping interrupt. "
+                f"Item id: {current_item_id}, "
+                f"elapsed ms: {elapsed_ms}, "
+                f"content index: {current_item_content_index}"
+            )
             return
 
-        await self._cancel_response()
-
-        elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
-        if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
+        current_item_content_index = current_item_content_index or 0
+        if elapsed_ms > 0:
             await self._emit_event(
                 RealtimeModelAudioInterruptedEvent(
-                    item_id=self._current_item_id,
-                    content_index=self._current_audio_content_index or 0,
+                    item_id=current_item_id,
+                    content_index=current_item_content_index,
                 )
             )
             converted = _ConversionHelper.convert_interrupt(
-                self._current_item_id,
-                self._current_audio_content_index or 0,
-                int(elapsed_time_ms),
+                current_item_id,
+                current_item_content_index,
+                int(elapsed_ms),
            )
             await self._send_raw_message(converted)
+            await self._cancel_response()
 
-        self._current_item_id = None
-        self._audio_start_time = None
-        self._audio_length_ms = 0.0
-        self._current_audio_content_index = None
+        self._audio_state_tracker.on_interrupted()
+        if self._playback_tracker:
+            self._playback_tracker.on_interrupted()
 
     async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
         """Send a session update to the model."""
         await self._update_session_config(event.session_settings)
 
     async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
         """Handle audio delta events and update audio tracking state."""
-        self._current_audio_content_index = parsed.content_index
         self._current_item_id = parsed.item_id
-        if self._audio_start_time is None:
-            self._audio_start_time = datetime.now()
-            self._audio_length_ms = 0.0
 
         audio_bytes = base64.b64decode(parsed.delta)
-        # Calculate audio length in ms using 24KHz pcm16le
-        self._audio_length_ms += self._calculate_audio_length_ms(audio_bytes)
+
+        self._audio_state_tracker.on_audio_delta(parsed.item_id, parsed.content_index, audio_bytes)
+
         await self._emit_event(
             RealtimeModelAudioEvent(
                 data=audio_bytes,
@@ -344,10 +376,6 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
             )
         )
 
-    def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
-        """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
-        return len(audio_bytes) / 24 / 2
-
     async def _handle_output_item(self, item: ConversationItem) -> None:
         """Handle response output item events (function calls and messages)."""
         if item.type == "function_call" and item.status == "completed":
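
Putting the pieces together: connect() stores the tracker from the config (falling back to a fresh RealtimePlaybackTracker), _get_playback_state prefers it over the wall-clock default, and _send_interrupt uses the reported elapsed_ms when telling the server how much audio the user actually heard. A sketch of the caller side, where play_to_phone_line stands in for a hypothetical downstream audio path (only RealtimePlaybackTracker and the playback_tracker config key come from this commit):

    from agents.realtime import RealtimeModelConfig, RealtimePlaybackTracker

    playback_tracker = RealtimePlaybackTracker()

    options: RealtimeModelConfig = {
        "playback_tracker": playback_tracker,
        # plus "api_key", "initial_model_settings", etc. as needed
    }

    def on_model_audio(item_id: str, content_index: int, audio: bytes) -> None:
        play_to_phone_line(audio)  # hypothetical: your own (possibly delayed) playback path
        # Report only audio that was actually played, so an interrupt truncates at the right point.
        playback_tracker.on_play_bytes(item_id, content_index, audio)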
