Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 127 additions & 23 deletions src/strands_tools/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,36 @@
# Set up logging
logger = logging.getLogger(__name__)

# Global session manager to avoid passing non-serializable boto3.Session objects
_SESSION_STORE = {}


def set_memory_session(session: Optional[boto3.Session] = None, key: str = "default") -> None:
"""
Store a boto3 session for use by the memory tool.

Args:
session: The boto3 Session to store
key: Optional key to store multiple sessions (default: "default")
"""
if session is not None:
_SESSION_STORE[key] = session
elif key in _SESSION_STORE:
del _SESSION_STORE[key]


def get_memory_session(key: str = "default") -> Optional[boto3.Session]:
"""
Retrieve a stored boto3 session.

Args:
key: The key used to store the session (default: "default")

Returns:
The stored boto3 Session or None
"""
return _SESSION_STORE.get(key)


class MemoryServiceClient:
"""
Expand All @@ -105,7 +135,12 @@ class MemoryServiceClient:
session: The boto3 session used for API calls
"""

def __init__(self, region: str = None, profile_name: Optional[str] = None):
def __init__(
self,
region: str = None,
profile_name: Optional[str] = None,
session: Optional[boto3.Session] = None,
):
"""
Initialize the memory service client.

Expand All @@ -115,14 +150,13 @@ def __init__(self, region: str = None, profile_name: Optional[str] = None):
"""
self.region = region or os.getenv("AWS_REGION", "us-west-2")
self.profile_name = profile_name
self.session = session
self._agent_client = None
self._runtime_client = None

# Set up session if profile is provided
if profile_name:
self.session = boto3.Session(profile_name=profile_name)
else:
self.session = boto3.Session()

@property
def agent_client(self):
Expand Down Expand Up @@ -225,7 +259,13 @@ def get_document(self, kb_id: str, data_source_id: str = None, document_id: str

return self.agent_client.get_knowledge_base_documents(**get_request)

def store_document(self, kb_id: str, data_source_id: str = None, content: str = None, title: str = None):
def store_document(
self,
kb_id: str,
data_source_id: str = None,
content: str = None,
title: str = None,
):
"""
Store a document in the knowledge base.

Expand Down Expand Up @@ -522,7 +562,11 @@ def format_retrieve_response(self, response: Dict, min_score: float = 0.0) -> Li


# Factory functions for dependency injection
def get_memory_service_client(region: str = None, profile_name: str = None) -> MemoryServiceClient:
def get_memory_service_client(
region: str = None,
profile_name: str = None,
session: Optional[boto3.Session] = None,
) -> MemoryServiceClient:
"""
Factory function to create a memory service client.

Expand All @@ -535,7 +579,7 @@ def get_memory_service_client(region: str = None, profile_name: str = None) -> M
Returns:
An initialized MemoryServiceClient instance
"""
return MemoryServiceClient(region=region, profile_name=profile_name)
return MemoryServiceClient(region=region, profile_name=profile_name, session=session)


def get_memory_formatter() -> MemoryFormatter:
Expand All @@ -562,6 +606,7 @@ def memory(
next_token: Optional[str] = None,
min_score: float = None,
region_name: str = None,
session_key: str = "default", # Use a key to retrieve the session instead
) -> Dict[str, Any]:
"""
Manage content in a Bedrock Knowledge Base (store, delete, list, get, or retrieve).
Expand All @@ -585,6 +630,8 @@ def memory(
min_score: Minimum relevance score threshold (0.0-1.0) for 'retrieve' action. Default is 0.4.
region_name: Optional AWS region name. If not provided, will use the AWS_REGION env variable.
If AWS_REGION is not specified, it will default to us-west-2.
session_key: Key to retrieve a pre-stored boto3 session (default: "default").
Use set_memory_session() to store a session before calling this tool.

Returns:
A dictionary containing the result of the operation.
Expand All @@ -596,11 +643,15 @@ def memory(
- Operation can be cancelled by the user during confirmation
- Retrieve provides semantic search across all documents in the knowledge base
- Knowledge base IDs must contain only alphanumeric characters (no hyphens or special characters)
- To use a custom boto3 session, call set_memory_session(session) before using this tool
"""
console = console_util.create()

# Retrieve the session from the global store
session = get_memory_session(session_key)

# Initialize the client and formatter using factory functions
client = get_memory_service_client(region=region_name)
client = get_memory_service_client(region=region_name, session=session)
formatter = get_memory_formatter()

# Get environment variables at runtime
Expand Down Expand Up @@ -659,18 +710,30 @@ def memory(
if action == "store":
# Validate content
if not content or not content.strip():
return {"status": "error", "content": [{"text": "❌ Content cannot be empty"}]}
return {
"status": "error",
"content": [{"text": "❌ Content cannot be empty"}],
}

# Preview what will be stored
doc_title = title or f"Memory {time.strftime('%Y%m%d_%H%M%S')}"
content_preview = content[:15000] + "..." if len(content) > 15000 else content

console.print(Panel(content_preview, title=f"[bold green]{doc_title}", border_style="green"))
console.print(
Panel(
content_preview,
title=f"[bold green]{doc_title}",
border_style="green",
)
)

elif action == "delete":
# Validate document_id
if not document_id:
return {"status": "error", "content": [{"text": "❌ Document ID cannot be empty for delete operation"}]}
return {
"status": "error",
"content": [{"text": "❌ Document ID cannot be empty for delete operation"}],
}

# Try to get document info first for better context
try:
Expand Down Expand Up @@ -738,7 +801,10 @@ def memory(
if action == "store":
# Validate content if not already done in confirmation step
if not needs_confirmation and (not content or not content.strip()):
return {"status": "error", "content": [{"text": "❌ Content cannot be empty"}]}
return {
"status": "error",
"content": [{"text": "❌ Content cannot be empty"}],
}

# Generate a title if none provided
store_title = title
Expand All @@ -754,7 +820,10 @@ def memory(
elif action == "delete":
# Validate document_id if not already done in confirmation step
if not needs_confirmation and not document_id:
return {"status": "error", "content": [{"text": "❌ Document ID cannot be empty for delete operation"}]}
return {
"status": "error",
"content": [{"text": "❌ Document ID cannot be empty for delete operation"}],
}

# Delete the document
response = client.delete_document(kb_id, data_source_id, document_id)
Expand All @@ -779,7 +848,10 @@ def memory(
elif action == "get":
# Validate document_id
if not document_id:
return {"status": "error", "content": [{"text": "❌ Document ID cannot be empty for get operation"}]}
return {
"status": "error",
"content": [{"text": "❌ Document ID cannot be empty for get operation"}],
}

try:
# Get document
Expand All @@ -788,7 +860,10 @@ def memory(
# Check if document exists
document_details = response.get("documentDetails", [])
if not document_details:
return {"status": "error", "content": [{"text": f"❌ Document not found: {document_id}"}]}
return {
"status": "error",
"content": [{"text": f"❌ Document not found: {document_id}"}],
}

# Get the first document detail
document_detail = document_details[0]
Expand Down Expand Up @@ -956,15 +1031,24 @@ def memory(
],
}
except Exception as e:
return {"status": "error", "content": [{"text": f"❌ Error retrieving document content: {str(e)}"}]}
return {
"status": "error",
"content": [{"text": f"❌ Error retrieving document content: {str(e)}"}],
}

except Exception as e:
return {"status": "error", "content": [{"text": f"❌ Error retrieving document: {str(e)}"}]}
return {
"status": "error",
"content": [{"text": f"❌ Error retrieving document: {str(e)}"}],
}

elif action == "list":
# Validate max_results
if max_results < 1 or max_results > 1000:
return {"status": "error", "content": [{"text": "❌ max_results must be between 1 and 1000"}]}
return {
"status": "error",
"content": [{"text": "❌ max_results must be between 1 and 1000"}],
}

response = client.list_documents(kb_id, data_source_id, max_results, next_token)
formatted_content = formatter.format_list_response(response)
Expand All @@ -983,22 +1067,36 @@ def memory(

elif action == "retrieve":
if not query:
return {"status": "error", "content": [{"text": "❌ No query provided for retrieval."}]}
return {
"status": "error",
"content": [{"text": "❌ No query provided for retrieval."}],
}

# Validate parameters
if min_score < 0.0 or min_score > 1.0:
return {"status": "error", "content": [{"text": "❌ min_score must be between 0.0 and 1.0"}]}
return {
"status": "error",
"content": [{"text": "❌ min_score must be between 0.0 and 1.0"}],
}

if max_results < 1 or max_results > 1000:
return {"status": "error", "content": [{"text": "❌ max_results must be between 1 and 1000"}]}
return {
"status": "error",
"content": [{"text": "❌ max_results must be between 1 and 1000"}],
}

# Set default max results if not provided
if max_results is None:
max_results = 5

try:
# Perform retrieval
response = client.retrieve(kb_id=kb_id, query=query, max_results=max_results, next_token=next_token)
response = client.retrieve(
kb_id=kb_id,
query=query,
max_results=max_results,
next_token=next_token,
)

# Format and filter response
formatted_content = formatter.format_retrieve_response(response, min_score)
Expand All @@ -1023,7 +1121,13 @@ def memory(
},
],
}
return {"status": "error", "content": [{"text": f"❌ Error during retrieval: {str(e)}"}]}
return {
"status": "error",
"content": [{"text": f"❌ Error during retrieval: {str(e)}"}],
}

except Exception as e:
return {"status": "error", "content": [{"text": f"❌ Error during {action} operation: {str(e)}"}]}
return {
"status": "error",
"content": [{"text": f"❌ Error during {action} operation: {str(e)}"}],
}
Loading