Skip to content

OpenAI agents chat alt #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions openai_agents/run_customer_service_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ async def main():
# Query the workflow for the chat history
# If the workflow is not open, start a new one
start = False
history = []
try:
history = await handle.query(
CustomerServiceWorkflow.get_chat_history,
reject_condition=QueryRejectCondition.NOT_OPEN,
)
except WorkflowQueryRejectedError as e:
except WorkflowQueryRejectedError:
start = True
except RPCError as e:
if e.status == RPCStatusCode.NOT_FOUND:
Expand All @@ -64,7 +65,7 @@ async def main():
CustomerServiceWorkflow.process_user_message, message_input
)
history.extend(new_history)
print(*new_history, sep="\n")
print(*new_history[1:], sep="\n")
except WorkflowUpdateFailedError:
print("** Stale conversation. Reloading...")
length = len(history)
Expand Down
1 change: 0 additions & 1 deletion openai_agents/run_hello_world_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from temporalio.contrib.pydantic import pydantic_data_converter

from openai_agents.workflows.hello_world_workflow import HelloWorldAgent
from openai_agents.workflows.research_bot_workflow import ResearchWorkflow


async def main():
Expand Down
2 changes: 2 additions & 0 deletions openai_agents/run_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from temporalio.contrib.openai_agents import (
ModelActivity,
ModelActivityParameters,
OpenAIAgentsTracingInterceptor,
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.pydantic import pydantic_data_converter
Expand Down Expand Up @@ -46,6 +47,7 @@ async def main():
ModelActivity().invoke_model_activity,
get_weather,
],
interceptors=[OpenAIAgentsTracingInterceptor()],
)
await worker.run()

Expand Down
17 changes: 12 additions & 5 deletions openai_agents/workflows/customer_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations as _annotations

from typing import Dict, Tuple

from agents import Agent, RunContextWrapper, function_tool, handoff
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
from pydantic import BaseModel
Expand All @@ -23,19 +25,20 @@ class AirlineAgentContext(BaseModel):
description_override="Lookup frequently asked questions.",
)
async def faq_lookup_tool(question: str) -> str:
if "bag" in question or "baggage" in question:
question_lower = question.lower()
if "bag" in question_lower or "baggage" in question_lower:
return (
"You are allowed to bring one bag on the plane. "
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
)
elif "seats" in question or "plane" in question:
elif "seats" in question_lower or "plane" in question_lower:
return (
"There are 120 seats on the plane. "
"There are 22 business class seats and 98 economy seats. "
"Exit rows are rows 4 and 16. "
"Rows 5-8 are Economy Plus, with extra legroom. "
)
elif "wifi" in question:
elif "wifi" in question_lower:
return "We have free wifi on the plane, join Airline-Wifi"
return "I'm sorry, I don't know the answer to that question."

Expand Down Expand Up @@ -74,7 +77,9 @@ async def on_seat_booking_handoff(
### AGENTS


def init_agents() -> Agent[AirlineAgentContext]:
def init_agents() -> Tuple[
Agent[AirlineAgentContext], Dict[str, Agent[AirlineAgentContext]]
]:
"""
Initialize the agents for the airline customer service workflow.
:return: triage agent
Expand Down Expand Up @@ -121,7 +126,9 @@ def init_agents() -> Agent[AirlineAgentContext]:

faq_agent.handoffs.append(triage_agent)
seat_booking_agent.handoffs.append(triage_agent)
return triage_agent
return triage_agent, {
agent.name: agent for agent in [faq_agent, seat_booking_agent, triage_agent]
}


class ProcessUserMessageInput(BaseModel):
Expand Down
163 changes: 115 additions & 48 deletions openai_agents/workflows/customer_service_workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations as _annotations

import asyncio
from datetime import timedelta

from agents import (
Agent,
HandoffCallItem,
HandoffOutputItem,
ItemHelpers,
MessageOutputItem,
Expand All @@ -12,6 +15,7 @@
TResponseInputItem,
trace,
)
from pydantic import BaseModel, dataclasses
from temporalio import workflow

from openai_agents.workflows.customer_service import (
Expand All @@ -21,71 +25,134 @@
)


@dataclasses.dataclass
class CustomerServiceWorkflowState:
printed_history: list[str]
current_agent_name: str
context: AirlineAgentContext
input_items: list[dict] # Store as plain dictionaries to avoid serialization issues


@workflow.defn
class CustomerServiceWorkflow:
@workflow.init
def __init__(self, input_items: list[TResponseInputItem] | None = None):
def __init__(
self, customer_service_state: CustomerServiceWorkflowState | None = None
):
self.run_config = RunConfig()
self.chat_history: list[str] = []
self.current_agent: Agent[AirlineAgentContext] = init_agents()
self.context = AirlineAgentContext()
self.input_items = [] if input_items is None else input_items

@workflow.run
async def run(self, input_items: list[TResponseInputItem] | None = None):
await workflow.wait_condition(
lambda: workflow.info().is_continue_as_new_suggested()
and workflow.all_handlers_finished()
starting_agent, self.agent_map = init_agents()
self.current_agent = (
self.agent_map[customer_service_state.current_agent_name]
if customer_service_state
else starting_agent
)
self.context = (
customer_service_state.context
if customer_service_state
else AirlineAgentContext()
)
workflow.continue_as_new(self.input_items)

self.printed_history: list[str] = (
customer_service_state.printed_history if customer_service_state else []
)

self.input_items = (
customer_service_state.input_items if customer_service_state else []
)

# Communication channels
self.user_input_queue: asyncio.Queue[str] = asyncio.Queue()
self.update_condition: asyncio.Condition = asyncio.Condition()

@workflow.run
async def run(
self, customer_service_state: CustomerServiceWorkflowState | None = None
):
while True:
with trace("Customer service", group_id=workflow.info().workflow_id):
user_input = await self.user_input_queue.get()
self.input_items.append({"content": user_input, "role": "user"})
result = await Runner.run(
self.current_agent,
self.input_items,
context=self.context,
run_config=self.run_config,
)
self.printed_history.append(f"Enter your message: {user_input}")
for new_item in result.new_items:
agent_name = new_item.agent.name
if isinstance(new_item, MessageOutputItem):
self.printed_history.append(
f"{agent_name}: {ItemHelpers.text_message_output(new_item)}"
)
elif isinstance(new_item, HandoffOutputItem):
self.printed_history.append(
f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}"
)
elif isinstance(new_item, HandoffCallItem):
self.printed_history.append(
f"{agent_name}: Handed off to tool {new_item.raw_item.name}"
)
elif isinstance(new_item, ToolCallItem):
self.printed_history.append(f"{agent_name}: Calling a tool")
elif isinstance(new_item, ToolCallOutputItem):
self.printed_history.append(
f"{agent_name}: Tool call output: {new_item.output}"
)
else:
self.printed_history.append(
f"{agent_name}: Skipping item: {new_item.__class__.__name__}"
)
self.input_items = result.to_input_list()
self.current_agent = result.last_agent
async with self.update_condition:
self.update_condition.notify_all()

if workflow.info().is_continue_as_new_suggested():
await workflow.wait_condition(
lambda: workflow.all_handlers_finished(),
timeout=timedelta(seconds=10),
timeout_summary="Continue as new timeout - deadlock avoidance",
)

# Convert input_items to plain dictionaries for serialization
serializable_input_items = []
for item in self.input_items:
if hasattr(item, "model_dump"):
# Convert Pydantic objects to dictionaries
serializable_input_items.append(item.model_dump())
else:
# Already a plain Python object
serializable_input_items.append(item)
workflow.continue_as_new(
CustomerServiceWorkflowState(
printed_history=self.printed_history,
current_agent_name=self.current_agent.name,
context=self.context,
input_items=serializable_input_items,
)
)

@workflow.query
def get_chat_history(self) -> list[str]:
return self.chat_history
return self.printed_history

@workflow.update
async def process_user_message(self, input: ProcessUserMessageInput) -> list[str]:
length = len(self.chat_history)
self.chat_history.append(f"User: {input.user_input}")
with trace("Customer service", group_id=workflow.info().workflow_id):
self.input_items.append({"content": input.user_input, "role": "user"})
result = await Runner.run(
self.current_agent,
self.input_items,
context=self.context,
run_config=self.run_config,
length = len(self.printed_history)
self.user_input_queue.put_nowait(input.user_input)
async with self.update_condition:
await self.update_condition.wait_for(
lambda: len(self.printed_history) > length
)

for new_item in result.new_items:
agent_name = new_item.agent.name
if isinstance(new_item, MessageOutputItem):
self.chat_history.append(
f"{agent_name}: {ItemHelpers.text_message_output(new_item)}"
)
elif isinstance(new_item, HandoffOutputItem):
self.chat_history.append(
f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}"
)
elif isinstance(new_item, ToolCallItem):
self.chat_history.append(f"{agent_name}: Calling a tool")
elif isinstance(new_item, ToolCallOutputItem):
self.chat_history.append(
f"{agent_name}: Tool call output: {new_item.output}"
)
else:
self.chat_history.append(
f"{agent_name}: Skipping item: {new_item.__class__.__name__}"
)
self.input_items = result.to_input_list()
self.current_agent = result.last_agent
workflow.set_current_details("\n\n".join(self.chat_history))
return self.chat_history[length:]
return self.printed_history[length:]

@process_user_message.validator
def validate_process_user_message(self, input: ProcessUserMessageInput) -> None:
if not input.user_input:
raise ValueError("User input cannot be empty.")
if len(input.user_input) > 1000:
raise ValueError("User input is too long. Please limit to 1000 characters.")
if input.chat_length != len(self.chat_history):
if input.chat_length != len(self.printed_history):
raise ValueError("Stale chat history. Please refresh the chat.")
5 changes: 2 additions & 3 deletions openai_agents/workflows/research_agents/research_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

with workflow.unsafe.imports_passed_through():
# TODO: Restore progress updates
from agents import RunConfig, Runner, custom_span, gen_trace_id, trace
from agents import RunConfig, Runner, custom_span, trace

from openai_agents.workflows.research_agents.planner_agent import (
WebSearchItem,
Expand All @@ -28,8 +28,7 @@ def __init__(self):
self.writer_agent = new_writer_agent()

async def run(self, query: str) -> str:
trace_id = gen_trace_id()
with trace("Research trace", trace_id=trace_id):
with trace("Research trace"):
search_plan = await self._plan_searches(query)
search_results = await self._perform_searches(search_plan)
report = await self._write_report(query, search_results)
Expand Down
25 changes: 13 additions & 12 deletions openai_agents/workflows/tools_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import timedelta

from agents import Agent, Runner
from agents import Agent, Runner, trace
from temporalio import workflow
from temporalio.contrib import openai_agents as temporal_agents

Expand All @@ -13,15 +13,16 @@
class ToolsWorkflow:
@workflow.run
async def run(self, question: str) -> str:
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[
temporal_agents.workflow.activity_as_tool(
get_weather, start_to_close_timeout=timedelta(seconds=10)
)
],
)
with trace("Activity as tool"):
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[
temporal_agents.workflow.activity_as_tool(
get_weather, start_to_close_timeout=timedelta(seconds=10)
)
],
)

result = await Runner.run(agent, input=question)
return result.final_output
result = await Runner.run(agent, input=question)
return result.final_output
Loading