Skip to content

Commit bdff8d5

Browse files
authored
fix: Fix session manager agent init (strands-agents#458)
1 parent 6d46291 commit bdff8d5

File tree

3 files changed

+70
-22
lines changed

3 files changed

+70
-22
lines changed

src/strands/session/repository_session_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,20 @@ def initialize(self, agent: Agent) -> None:
115115
session_agent = SessionAgent.from_agent(agent)
116116
self.session_repository.create_agent(self.session_id, session_agent)
117117
# Initialize messages with sequential indices
118+
session_message = None
118119
for i, message in enumerate(agent.messages):
119120
session_message = SessionMessage.from_message(message, i)
120121
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
122+
self._latest_agent_message[agent.agent_id] = session_message
121123
else:
122124
logger.debug(
123125
"agent_id=<%s> | session_id=<%s> | restoring agent",
124126
agent.agent_id,
125127
self.session_id,
126128
)
127-
agent.messages = [
128-
session_message.to_message()
129-
for session_message in self.session_repository.list_messages(self.session_id, agent.agent_id)
130-
]
129+
session_messages = self.session_repository.list_messages(self.session_id, agent.agent_id)
130+
if len(session_messages) > 0:
131+
self._latest_agent_message[agent.agent_id] = session_messages[-1]
132+
agent.messages = [session_message.to_message() for session_message in session_messages]
133+
131134
agent.state = AgentState(session_agent.state)

tests/fixtures/mock_session_repository.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def create_agent(self, session_id, session_agent):
3333
if agent_id in self.agents.get(session_id, {}):
3434
raise SessionException(f"Agent {agent_id} already exists in session {session_id}")
3535
self.agents.setdefault(session_id, {})[agent_id] = session_agent
36-
self.messages.setdefault(session_id, {}).setdefault(agent_id, [])
36+
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})
3737
return session_agent
3838

3939
def read_agent(self, session_id, agent_id):
@@ -53,37 +53,34 @@ def update_agent(self, session_id, session_agent):
5353

5454
def create_message(self, session_id, agent_id, session_message):
5555
"""Create a message."""
56+
message_id = session_message.message_id
5657
if session_id not in self.sessions:
5758
raise SessionException(f"Session {session_id} does not exist")
5859
if agent_id not in self.agents.get(session_id, {}):
59-
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
60-
self.messages.setdefault(session_id, {}).setdefault(agent_id, []).append(session_message)
60+
raise SessionException(f"Agent {agent_id} does not exists in session {session_id}")
61+
if message_id in self.messages.get(session_id, {}).get(agent_id, {}):
62+
raise SessionException(f"Message {message_id} already exists in agent {agent_id} in session {session_id}")
63+
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})[message_id] = session_message
6164

6265
def read_message(self, session_id, agent_id, message_id):
6366
"""Read a message."""
6467
if session_id not in self.sessions:
6568
return None
6669
if agent_id not in self.agents.get(session_id, {}):
6770
return None
68-
for message in self.messages.get(session_id, {}).get(agent_id, []):
69-
if message.message_id == message_id:
70-
return message
71-
return None
71+
return self.messages.get(session_id, {}).get(agent_id, {}).get(message_id)
7272

7373
def update_message(self, session_id, agent_id, session_message):
7474
"""Update a message."""
75+
7576
message_id = session_message.message_id
7677
if session_id not in self.sessions:
7778
raise SessionException(f"Session {session_id} does not exist")
7879
if agent_id not in self.agents.get(session_id, {}):
7980
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
80-
81-
for i, message in enumerate(self.messages.get(session_id, {}).get(agent_id, [])):
82-
if message.message_id == message_id:
83-
self.messages[session_id][agent_id][i] = session_message
84-
return
85-
86-
raise SessionException(f"Message {message_id} does not exist")
81+
if message_id not in self.messages.get(session_id, {}).get(agent_id, {}):
82+
raise SessionException(f"Message {message_id} does not exist in session {session_id}")
83+
self.messages[session_id][agent_id][message_id] = session_message
8784

8885
def list_messages(self, session_id, agent_id, limit=None, offset=0):
8986
"""List messages."""
@@ -92,7 +89,9 @@ def list_messages(self, session_id, agent_id, limit=None, offset=0):
9289
if agent_id not in self.agents.get(session_id, {}):
9390
return []
9491

95-
messages = self.messages.get(session_id, {}).get(agent_id, [])
92+
messages = self.messages.get(session_id, {}).get(agent_id, {})
93+
sorted_messages = [messages[key] for key in sorted(messages.keys())]
94+
9695
if limit is not None:
97-
return messages[offset : offset + limit]
98-
return messages[offset:]
96+
return sorted_messages[offset : offset + limit]
97+
return sorted_messages[offset:]

tests/strands/agent/test_agent.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from strands.session.repository_session_manager import RepositorySessionManager
2121
from strands.types.content import Messages
2222
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
23-
from strands.types.session import Session, SessionAgent, SessionType
23+
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
2424
from tests.fixtures.mock_session_repository import MockedSessionRepository
2525
from tests.fixtures.mocked_model_provider import MockedModelProvider
2626

@@ -1428,6 +1428,26 @@ def test_agent_restored_from_session_management():
14281428
assert agent.state.get("foo") == "bar"
14291429

14301430

1431+
def test_agent_restored_from_session_management_with_message():
1432+
mock_session_repository = MockedSessionRepository()
1433+
mock_session_repository.create_session(Session(session_id="123", session_type=SessionType.AGENT))
1434+
mock_session_repository.create_agent(
1435+
"123",
1436+
SessionAgent(
1437+
agent_id="default",
1438+
state={"foo": "bar"},
1439+
),
1440+
)
1441+
mock_session_repository.create_message(
1442+
"123", "default", SessionMessage({"role": "user", "content": [{"text": "Hello!"}]}, 0)
1443+
)
1444+
session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository)
1445+
1446+
agent = Agent(session_manager=session_manager)
1447+
1448+
assert agent.state.get("foo") == "bar"
1449+
1450+
14311451
def test_agent_redacts_input_on_triggered_guardrail():
14321452
mocked_model = MockedModelProvider(
14331453
[{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}]
@@ -1484,3 +1504,29 @@ def test_agent_restored_from_session_management_with_redacted_input():
14841504

14851505
# Assert that the restored agent redacted message is equal to the original agent
14861506
assert agent.messages[0] == agent_2.messages[0]
1507+
1508+
1509+
def test_agent_restored_from_session_management_with_correct_index():
1510+
mock_model_provider = MockedModelProvider(
1511+
[{"role": "assistant", "content": [{"text": "hello!"}]}, {"role": "assistant", "content": [{"text": "world!"}]}]
1512+
)
1513+
mock_session_repository = MockedSessionRepository()
1514+
session_manager = RepositorySessionManager(session_id="test", session_repository=mock_session_repository)
1515+
agent = Agent(session_manager=session_manager, model=mock_model_provider)
1516+
agent("Hello!")
1517+
1518+
assert len(mock_session_repository.list_messages("test", agent.agent_id)) == 2
1519+
1520+
session_manager_2 = RepositorySessionManager(session_id="test", session_repository=mock_session_repository)
1521+
agent_2 = Agent(session_manager=session_manager_2, model=mock_model_provider)
1522+
1523+
assert len(agent_2.messages) == 2
1524+
assert agent_2.messages[1]["content"][0]["text"] == "hello!"
1525+
1526+
agent_2("Hello!")
1527+
1528+
assert len(agent_2.messages) == 4
1529+
session_messages = mock_session_repository.list_messages("test", agent_2.agent_id)
1530+
assert (len(session_messages)) == 4
1531+
assert session_messages[1].message["content"][0]["text"] == "hello!"
1532+
assert session_messages[3].message["content"][0]["text"] == "world!"

0 commit comments

Comments
 (0)