Skip to content

Commit ce3fe9e

Browse files
authored
feat: Add kwargs to session interfaces for future extensibility (strands-agents#464)
1 parent 4cf3d72 commit ce3fe9e

File tree

5 files changed

+66
-55
lines changed

5 files changed

+66
-55
lines changed

src/strands/session/file_session_manager.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
3535
3636
"""
3737

38-
def __init__(self, session_id: str, storage_dir: Optional[str] = None):
38+
def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any):
3939
"""Initialize FileSession with filesystem storage.
4040
4141
Args:
4242
session_id: ID for the session
4343
storage_dir: Directory for local filesystem storage (defaults to temp dir)
44+
**kwargs: Additional keyword arguments for future extensibility.
4445
"""
4546
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
4647
os.makedirs(self.storage_dir, exist_ok=True)
@@ -83,7 +84,7 @@ def _write_file(self, path: str, data: dict[str, Any]) -> None:
8384
with open(path, "w", encoding="utf-8") as f:
8485
json.dump(data, f, indent=2, ensure_ascii=False)
8586

86-
def create_session(self, session: Session) -> Session:
87+
def create_session(self, session: Session, **kwargs: Any) -> Session:
8788
"""Create a new session."""
8889
session_dir = self._get_session_path(session.session_id)
8990
if os.path.exists(session_dir):
@@ -100,7 +101,7 @@ def create_session(self, session: Session) -> Session:
100101

101102
return session
102103

103-
def read_session(self, session_id: str) -> Optional[Session]:
104+
def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
104105
"""Read session data."""
105106
session_file = os.path.join(self._get_session_path(session_id), "session.json")
106107
if not os.path.exists(session_file):
@@ -109,7 +110,15 @@ def read_session(self, session_id: str) -> Optional[Session]:
109110
session_data = self._read_file(session_file)
110111
return Session.from_dict(session_data)
111112

112-
def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
113+
def delete_session(self, session_id: str, **kwargs: Any) -> None:
114+
"""Delete session and all associated data."""
115+
session_dir = self._get_session_path(session_id)
116+
if not os.path.exists(session_dir):
117+
raise SessionException(f"Session {session_id} does not exist")
118+
119+
shutil.rmtree(session_dir)
120+
121+
def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
113122
"""Create a new agent in the session."""
114123
agent_id = session_agent.agent_id
115124

@@ -121,15 +130,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
121130
session_data = session_agent.to_dict()
122131
self._write_file(agent_file, session_data)
123132

124-
def delete_session(self, session_id: str) -> None:
125-
"""Delete session and all associated data."""
126-
session_dir = self._get_session_path(session_id)
127-
if not os.path.exists(session_dir):
128-
raise SessionException(f"Session {session_id} does not exist")
129-
130-
shutil.rmtree(session_dir)
131-
132-
def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]:
133+
def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
133134
"""Read agent data."""
134135
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
135136
if not os.path.exists(agent_file):
@@ -138,7 +139,7 @@ def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]:
138139
agent_data = self._read_file(agent_file)
139140
return SessionAgent.from_dict(agent_data)
140141

141-
def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
142+
def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
142143
"""Update agent data."""
143144
agent_id = session_agent.agent_id
144145
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
@@ -149,7 +150,7 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
149150
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
150151
self._write_file(agent_file, session_agent.to_dict())
151152

152-
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
153+
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
153154
"""Create a new message for the agent."""
154155
message_file = self._get_message_path(
155156
session_id,
@@ -159,15 +160,15 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio
159160
session_dict = session_message.to_dict()
160161
self._write_file(message_file, session_dict)
161162

162-
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
163+
def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]:
163164
"""Read message data."""
164165
message_path = self._get_message_path(session_id, agent_id, message_id)
165166
if not os.path.exists(message_path):
166167
return None
167168
message_data = self._read_file(message_path)
168169
return SessionMessage.from_dict(message_data)
169170

170-
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
171+
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
171172
"""Update message data."""
172173
message_id = session_message.message_id
173174
previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id)
@@ -180,7 +181,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
180181
self._write_file(message_file, session_message.to_dict())
181182

182183
def list_messages(
183-
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
184+
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
184185
) -> list[SessionMessage]:
185186
"""List messages for an agent with pagination."""
186187
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")

src/strands/session/repository_session_manager.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Repository session manager implementation."""
22

33
import logging
4-
from typing import Optional
4+
from typing import Any, Optional
55

66
from ..agent.agent import Agent
77
from ..agent.state import AgentState
@@ -22,20 +22,18 @@
2222
class RepositorySessionManager(SessionManager):
2323
"""Session manager for persisting agents in a SessionRepository."""
2424

25-
def __init__(
26-
self,
27-
session_id: str,
28-
session_repository: SessionRepository,
29-
):
25+
def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any):
3026
"""Initialize the RepositorySessionManager.
3127
3228
If no session with the specified session_id exists yet, it will be created
3329
in the session_repository.
3430
3531
Args:
36-
session_id: ID to use for the session. A new session with this id will be created if it does
37-
not exist in the reposiory yet
38-
session_repository: Underlying session repository to use to store the sessions state.
32+
session_id: ID to use for the session. A new session with this id will be created if it does
33+
not exist in the reposiory yet
34+
session_repository: Underlying session repository to use to store the sessions state.
35+
**kwargs: Additional keyword arguments for future extensibility.
36+
3937
"""
4038
self.session_repository = session_repository
4139
self.session_id = session_id
@@ -51,12 +49,13 @@ def __init__(
5149
# Keep track of the latest message of each agent in case we need to redact it.
5250
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
5351

54-
def append_message(self, message: Message, agent: Agent) -> None:
52+
def append_message(self, message: Message, agent: Agent, **kwargs: Any) -> None:
5553
"""Append a message to the agent's session.
5654
5755
Args:
5856
message: Message to add to the agent in the session
5957
agent: Agent to append the message to
58+
**kwargs: Additional keyword arguments for future extensibility.
6059
"""
6160
# Calculate the next index (0 if this is the first message, otherwise increment the previous index)
6261
latest_agent_message = self._latest_agent_message[agent.agent_id]
@@ -69,35 +68,38 @@ def append_message(self, message: Message, agent: Agent) -> None:
6968
self._latest_agent_message[agent.agent_id] = session_message
7069
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
7170

72-
def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
71+
def redact_latest_message(self, redact_message: Message, agent: Agent, **kwargs: Any) -> None:
7372
"""Redact the latest message appended to the session.
7473
7574
Args:
7675
redact_message: New message to use that contains the redact content
7776
agent: Agent to apply the message redaction to
77+
**kwargs: Additional keyword arguments for future extensibility.
7878
"""
7979
latest_agent_message = self._latest_agent_message[agent.agent_id]
8080
if latest_agent_message is None:
8181
raise SessionException("No message to redact.")
8282
latest_agent_message.redact_message = redact_message
8383
return self.session_repository.update_message(self.session_id, agent.agent_id, latest_agent_message)
8484

85-
def sync_agent(self, agent: Agent) -> None:
85+
def sync_agent(self, agent: Agent, **kwargs: Any) -> None:
8686
"""Serialize and update the agent into the session repository.
8787
8888
Args:
8989
agent: Agent to sync to the session.
90+
**kwargs: Additional keyword arguments for future extensibility.
9091
"""
9192
self.session_repository.update_agent(
9293
self.session_id,
9394
SessionAgent.from_agent(agent),
9495
)
9596

96-
def initialize(self, agent: Agent) -> None:
97+
def initialize(self, agent: Agent, **kwargs: Any) -> None:
9798
"""Initialize an agent with a session.
9899
99100
Args:
100101
agent: Agent to initialize from the session
102+
**kwargs: Additional keyword arguments for future extensibility.
101103
"""
102104
if agent.agent_id in self._latest_agent_message:
103105
raise SessionException("The `agent_id` of an agent must be unique in a session.")

src/strands/session/s3_session_manager.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
boto_session: Optional[boto3.Session] = None,
4545
boto_client_config: Optional[BotocoreConfig] = None,
4646
region_name: Optional[str] = None,
47+
**kwargs: Any,
4748
):
4849
"""Initialize S3SessionManager with S3 storage.
4950
@@ -54,6 +55,7 @@ def __init__(
5455
boto_session: Optional boto3 session
5556
boto_client_config: Optional boto3 client configuration
5657
region_name: AWS region for S3 storage
58+
**kwargs: Additional keyword arguments for future extensibility.
5759
"""
5860
self.bucket = bucket
5961
self.prefix = prefix
@@ -91,6 +93,8 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
9193
session_id: ID of the session
9294
agent_id: ID of the agent
9395
message_id: Index of the message
96+
**kwargs: Additional keyword arguments for future extensibility.
97+
9498
Returns:
9599
The key for the message
96100
"""
@@ -121,7 +125,7 @@ def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None:
121125
except ClientError as e:
122126
raise SessionException(f"Failed to write S3 object {key}: {e}") from e
123127

124-
def create_session(self, session: Session) -> Session:
128+
def create_session(self, session: Session, **kwargs: Any) -> Session:
125129
"""Create a new session in S3."""
126130
session_key = f"{self._get_session_path(session.session_id)}session.json"
127131

@@ -138,15 +142,15 @@ def create_session(self, session: Session) -> Session:
138142
self._write_s3_object(session_key, session_dict)
139143
return session
140144

141-
def read_session(self, session_id: str) -> Optional[Session]:
145+
def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
142146
"""Read session data from S3."""
143147
session_key = f"{self._get_session_path(session_id)}session.json"
144148
session_data = self._read_s3_object(session_key)
145149
if session_data is None:
146150
return None
147151
return Session.from_dict(session_data)
148152

149-
def delete_session(self, session_id: str) -> None:
153+
def delete_session(self, session_id: str, **kwargs: Any) -> None:
150154
"""Delete session and all associated data from S3."""
151155
session_prefix = self._get_session_path(session_id)
152156
try:
@@ -169,22 +173,22 @@ def delete_session(self, session_id: str) -> None:
169173
except ClientError as e:
170174
raise SessionException(f"S3 error deleting session {session_id}: {e}") from e
171175

172-
def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
176+
def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
173177
"""Create a new agent in S3."""
174178
agent_id = session_agent.agent_id
175179
agent_dict = session_agent.to_dict()
176180
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
177181
self._write_s3_object(agent_key, agent_dict)
178182

179-
def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]:
183+
def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
180184
"""Read agent data from S3."""
181185
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
182186
agent_data = self._read_s3_object(agent_key)
183187
if agent_data is None:
184188
return None
185189
return SessionAgent.from_dict(agent_data)
186190

187-
def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
191+
def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
188192
"""Update agent data in S3."""
189193
agent_id = session_agent.agent_id
190194
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
@@ -196,22 +200,22 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
196200
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
197201
self._write_s3_object(agent_key, session_agent.to_dict())
198202

199-
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
203+
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
200204
"""Create a new message in S3."""
201205
message_id = session_message.message_id
202206
message_dict = session_message.to_dict()
203207
message_key = self._get_message_path(session_id, agent_id, message_id)
204208
self._write_s3_object(message_key, message_dict)
205209

206-
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
210+
def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]:
207211
"""Read message data from S3."""
208212
message_key = self._get_message_path(session_id, agent_id, message_id)
209213
message_data = self._read_s3_object(message_key)
210214
if message_data is None:
211215
return None
212216
return SessionMessage.from_dict(message_data)
213217

214-
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
218+
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
215219
"""Update message data in S3."""
216220
message_id = session_message.message_id
217221
previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id)
@@ -224,7 +228,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
224228
self._write_s3_object(message_key, session_message.to_dict())
225229

226230
def list_messages(
227-
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
231+
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
228232
) -> List[SessionMessage]:
229233
"""List messages for an agent with pagination from S3."""
230234
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"

src/strands/session/session_manager.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,35 +35,39 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
3535
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
3636

3737
@abstractmethod
38-
def redact_latest_message(self, redact_message: Message, agent: "Agent") -> None:
38+
def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
3939
"""Redact the message most recently appended to the agent in the session.
4040
4141
Args:
4242
redact_message: New message to use that contains the redact content
4343
agent: Agent to apply the message redaction to
44+
**kwargs: Additional keyword arguments for future extensibility.
4445
"""
4546

4647
@abstractmethod
47-
def append_message(self, message: Message, agent: "Agent") -> None:
48+
def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
4849
"""Append a message to the agent's session.
4950
5051
Args:
5152
message: Message to add to the agent in the session
5253
agent: Agent to append the message to
54+
**kwargs: Additional keyword arguments for future extensibility.
5355
"""
5456

5557
@abstractmethod
56-
def sync_agent(self, agent: "Agent") -> None:
58+
def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
5759
"""Serialize and sync the agent with the session storage.
5860
5961
Args:
6062
agent: Agent who should be synchronized with the session storage
63+
**kwargs: Additional keyword arguments for future extensibility.
6164
"""
6265

6366
@abstractmethod
64-
def initialize(self, agent: "Agent") -> None:
67+
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
6568
"""Initialize an agent with a session.
6669
6770
Args:
6871
agent: Agent to initialize
72+
**kwargs: Additional keyword arguments for future extensibility.
6973
"""

0 commit comments

Comments
 (0)