57
57
from websockets .asyncio .client import ClientConnection
58
58
59
59
from agents .handoffs import Handoff
60
+ from agents .realtime ._default_tracker import ModelAudioTracker
60
61
from agents .tool import FunctionTool , Tool
61
62
from agents .util ._types import MaybeAwaitable
62
63
72
73
RealtimeModel ,
73
74
RealtimeModelConfig ,
74
75
RealtimeModelListener ,
76
+ RealtimePlaybackState ,
77
+ RealtimePlaybackTracker ,
75
78
)
76
79
from .model_events import (
77
80
RealtimeModelAudioDoneEvent ,
@@ -133,11 +136,10 @@ def __init__(self) -> None:
133
136
self ._websocket_task : asyncio .Task [None ] | None = None
134
137
self ._listeners : list [RealtimeModelListener ] = []
135
138
self ._current_item_id : str | None = None
136
- self ._audio_start_time : datetime | None = None
137
- self ._audio_length_ms : float = 0.0
139
+ self ._audio_state_tracker : ModelAudioTracker = ModelAudioTracker ()
138
140
self ._ongoing_response : bool = False
139
- self ._current_audio_content_index : int | None = None
140
141
self ._tracing_config : RealtimeModelTracingConfig | Literal ["auto" ] | None = None
142
+ self ._playback_tracker : RealtimePlaybackTracker | None = None
141
143
142
144
async def connect (self , options : RealtimeModelConfig ) -> None :
143
145
"""Establish a connection to the model and keep it alive."""
@@ -146,6 +148,8 @@ async def connect(self, options: RealtimeModelConfig) -> None:
146
148
147
149
model_settings : RealtimeSessionModelSettings = options .get ("initial_model_settings" , {})
148
150
151
+ self ._playback_tracker = options .get ("playback_tracker" , RealtimePlaybackTracker ())
152
+
149
153
self .model = model_settings .get ("model_name" , self .model )
150
154
api_key = await get_api_key (options .get ("api_key" ))
151
155
@@ -294,47 +298,75 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
294
298
if event .start_response :
295
299
await self ._send_raw_message (OpenAIResponseCreateEvent (type = "response.create" ))
296
300
301
+ def _get_playback_state (self ) -> RealtimePlaybackState :
302
+ if self ._playback_tracker :
303
+ return self ._playback_tracker .get_state ()
304
+
305
+ if last_audio_item_id := self ._audio_state_tracker .get_last_audio_item ():
306
+ item_id , item_content_index = last_audio_item_id
307
+ audio_state = self ._audio_state_tracker .get_state (item_id , item_content_index )
308
+ if audio_state :
309
+ elapsed_ms = (
310
+ datetime .now () - audio_state .initial_received_time
311
+ ).total_seconds () * 1000
312
+ return {
313
+ "current_item_id" : item_id ,
314
+ "current_item_content_index" : item_content_index ,
315
+ "elapsed_ms" : elapsed_ms ,
316
+ }
317
+
318
+ return {
319
+ "current_item_id" : None ,
320
+ "current_item_content_index" : None ,
321
+ "elapsed_ms" : None ,
322
+ }
323
+
297
324
async def _send_interrupt (self , event : RealtimeModelSendInterrupt ) -> None :
298
- if not self ._current_item_id or not self ._audio_start_time :
325
+ playback_state = self ._get_playback_state ()
326
+ current_item_id = playback_state .get ("current_item_id" )
327
+ current_item_content_index = playback_state .get ("current_item_content_index" )
328
+ elapsed_ms = playback_state .get ("elapsed_ms" )
329
+ if current_item_id is None or elapsed_ms is None :
330
+ logger .info (
331
+ "Skipping interrupt. "
332
+ f"Item id: { current_item_id } , "
333
+ f"elapsed ms: { elapsed_ms } , "
334
+ f"content index: { current_item_content_index } "
335
+ )
299
336
return
300
337
301
- await self ._cancel_response ()
302
-
303
- elapsed_time_ms = (datetime .now () - self ._audio_start_time ).total_seconds () * 1000
304
- if elapsed_time_ms > 0 and elapsed_time_ms < self ._audio_length_ms :
338
+ current_item_content_index = current_item_content_index or 0
339
+ if elapsed_ms > 0 :
305
340
await self ._emit_event (
306
341
RealtimeModelAudioInterruptedEvent (
307
- item_id = self . _current_item_id ,
308
- content_index = self . _current_audio_content_index or 0 ,
342
+ item_id = current_item_id ,
343
+ content_index = current_item_content_index ,
309
344
)
310
345
)
311
346
converted = _ConversionHelper .convert_interrupt (
312
- self . _current_item_id ,
313
- self . _current_audio_content_index or 0 ,
314
- int (elapsed_time_ms ),
347
+ current_item_id ,
348
+ current_item_content_index ,
349
+ int (elapsed_ms ),
315
350
)
316
351
await self ._send_raw_message (converted )
352
+ await self ._cancel_response ()
317
353
318
- self ._current_item_id = None
319
- self ._audio_start_time = None
320
- self ._audio_length_ms = 0.0
321
- self ._current_audio_content_index = None
354
+ self ._audio_state_tracker .on_interrupted ()
355
+ if self ._playback_tracker :
356
+ self ._playback_tracker .on_interrupted ()
322
357
323
358
async def _send_session_update (self , event : RealtimeModelSendSessionUpdate ) -> None :
324
359
"""Send a session update to the model."""
325
360
await self ._update_session_config (event .session_settings )
326
361
327
362
async def _handle_audio_delta (self , parsed : ResponseAudioDeltaEvent ) -> None :
328
363
"""Handle audio delta events and update audio tracking state."""
329
- self ._current_audio_content_index = parsed .content_index
330
364
self ._current_item_id = parsed .item_id
331
- if self ._audio_start_time is None :
332
- self ._audio_start_time = datetime .now ()
333
- self ._audio_length_ms = 0.0
334
365
335
366
audio_bytes = base64 .b64decode (parsed .delta )
336
- # Calculate audio length in ms using 24KHz pcm16le
337
- self ._audio_length_ms += self ._calculate_audio_length_ms (audio_bytes )
367
+
368
+ self ._audio_state_tracker .on_audio_delta (parsed .item_id , parsed .content_index , audio_bytes )
369
+
338
370
await self ._emit_event (
339
371
RealtimeModelAudioEvent (
340
372
data = audio_bytes ,
@@ -344,10 +376,6 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
344
376
)
345
377
)
346
378
347
- def _calculate_audio_length_ms (self , audio_bytes : bytes ) -> float :
348
- """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
349
- return len (audio_bytes ) / 24 / 2
350
-
351
379
async def _handle_output_item (self , item : ConversationItem ) -> None :
352
380
"""Handle response output item events (function calls and messages)."""
353
381
if item .type == "function_call" and item .status == "completed" :
0 commit comments