diff --git a/docs/api-reference/openapi.json b/docs/api-reference/openapi.json index fe9ce054..46441d33 100644 --- a/docs/api-reference/openapi.json +++ b/docs/api-reference/openapi.json @@ -788,7 +788,7 @@ "type" : "string" } }, { - "description" : "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned.", + "description" : "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned. \n\nWARNING!\n Use `limit` only for read-only/display purposes (pagination, viewing). Do NOT use `limit` to truncate messages before sending to LLM as it may cause tool-call and tool-result unpairing issues. Instead, use the `token_limit` edit strategy in `edit_strategies` parameter to safely manage message context size.", "in" : "query", "name" : "limit", "schema" : { diff --git a/docs/store/editing.mdx b/docs/store/editing.mdx index b012fac6..0c177df3 100644 --- a/docs/store/editing.mdx +++ b/docs/store/editing.mdx @@ -55,7 +55,7 @@ Please do not use the token count to calculate the cost of LLM, as the actual to ## Context Editing On-the-fly Acontext supports to edit the session context when you obtain the current messages. -The basic usage is to pass the `edit_strategy` and `edit_params` to the `get_messages` method to get the edited session messages without modifying the original session storage: +The basic usage is to pass the `edit_strategies` to the `get_messages` method to get the edited session messages without modifying the original session storage: ```python Python @@ -99,6 +99,46 @@ const originalSession = await client.sessions.getMessages('session-uuid'); ``` +### Token Limit +This strategy truncates messages based on token count, removing the oldest messages until the total token count is within the specified limit. + +It's useful for managing context window limits and ensuring your session stays within model constraints. + +It will: +- Removes messages from oldest to newest +- Maintains tool-call/tool-result pairing (when removing a tool-call, its corresponding tool-result is also removed) + + +```python Python +# Limit session to 20,000 tokens +edited_session = client.sessions.get_messages( + session_id="session-uuid", + edit_strategies=[ + { + "type": "token_limit", + "params": { + "limit_tokens": 20000 + } + } + ], +) +``` +```typescript TypeScript +// Limit session to 20,000 tokens +const editedSession = await client.sessions.getMessages('session-uuid', { + editStrategies: [ + { + type: 'token_limit' as const, + params: { + limit_tokens: 20000 + } + } + ], +}); +``` + + + ### Remove Tool Result This strategy will replace the oldest tool results' content with a placeholder text to reduce the session context, while keeping the most recent N tool results intact. @@ -153,6 +193,8 @@ const editedSession = await client.sessions.getMessages('session-uuid', { ``` + + ## Context Engineering and Editing Context Engineering is an emerging discipline focused on designing, managing, and optimizing the information provided to large language models (LLMs) and AI agents to enhance their performance, reliability, and consistency. diff --git a/docs/store/messages/multi-modal.mdx b/docs/store/messages/multi-modal.mdx index 189c115d..a10f6563 100644 --- a/docs/store/messages/multi-modal.mdx +++ b/docs/store/messages/multi-modal.mdx @@ -558,7 +558,6 @@ client = AcontextClient( result = client.sessions.get_messages( session_id="session_uuid", format="anthropic", # or "openai" - limit=50 ) print(f"Retrieved {len(result.items)} messages") diff --git a/docs/store/messages/multi-provider.mdx b/docs/store/messages/multi-provider.mdx index fbe64c5e..21733f51 100644 --- a/docs/store/messages/multi-provider.mdx +++ b/docs/store/messages/multi-provider.mdx @@ -223,7 +223,7 @@ Each message receives a unique ID upon creation. You can use these IDs to refere ### Get all messages from a session -Retrieve messages from a session with pagination support: +Retrieve messages from a session: ```python Python @@ -237,22 +237,12 @@ client = AcontextClient( # Get messages from a session result = client.sessions.get_messages( session_id="session_uuid", - limit=50, - format="openai", - time_desc=True # Most recent first + format="openai" ) print(f"Retrieved {len(result.items)} messages") for msg in result.items: print(f"- {msg.role}: {msg.content[:50]}...") - -# Handle pagination if there are more messages -if result.next_cursor: - next_page = client.sessions.get_messages( - session_id="session_uuid", - cursor=result.next_cursor, - limit=50 - ) ``` ```typescript TypeScript @@ -265,23 +255,13 @@ const client = new AcontextClient({ // Get messages from a session const result = await client.sessions.getMessages('session_uuid', { - limit: 50, - format: 'openai', - timeDesc: true // Most recent first + format: 'openai' }); console.log(`Retrieved ${result.items.length} messages`); result.items.forEach(msg => { console.log(`- ${msg.role}: ${msg.content.substring(0, 50)}...`); }); - -// Handle pagination if there are more messages -if (result.nextCursor) { - const nextPage = await client.sessions.getMessages('session_uuid', { - cursor: result.nextCursor, - limit: 50 - }); -} ``` @@ -341,8 +321,7 @@ try: # 3. Retrieve messages later result = client.sessions.get_messages( session_id=session.id, - format="openai", - time_desc=False # Chronological order + format="openai" ) print(f"\nRetrieved conversation ({len(result.items)} messages):") @@ -386,8 +365,7 @@ async function storeAndRetrieveConversation() { // 3. Retrieve messages later const result = await client.sessions.getMessages(session.id, { - format: 'openai', - timeDesc: false // Chronological order + format: 'openai' }); console.log(`\nRetrieved conversation (${result.items.length} messages):`); @@ -404,45 +382,6 @@ storeAndRetrieveConversation(); ``` -## Pagination and limits - -When retrieving large message histories, use pagination to efficiently process results: - - - -Start with a reasonable page size (e.g., 50-100 messages) based on your use case. - -```python -result = client.sessions.get_messages(session_id="session_uuid", limit=50) -``` - - - -Look for the `next_cursor` field in the response to determine if more messages exist. - -```python -if result.next_cursor: - print("More messages available") -``` - - - -Use the cursor to retrieve the next page of results. - -```python -next_page = client.sessions.get_messages( - session_id="session_uuid", - cursor=result.next_cursor, - limit=50 -) -``` - - - - -The maximum limit per request is typically 100 messages. Check your plan's specific limits in the dashboard. - - ## Managing sessions ### Delete a session @@ -543,13 +482,6 @@ for (const session of sessions.items.slice(10)) { - Use **Anthropic format** if you're primarily using Claude models - You can convert between formats when retrieving messages - - -- Set appropriate page sizes (50-100 messages typically works well) -- Cache results when possible to reduce API calls -- Use `time_desc=True` to get most recent messages first -- Process pages asynchronously for better performance with large histories - ## Next steps diff --git a/src/client/acontext-py/src/acontext/resources/async_sessions.py b/src/client/acontext-py/src/acontext/resources/async_sessions.py index 7672e695..a57b9788 100644 --- a/src/client/acontext-py/src/acontext/resources/async_sessions.py +++ b/src/client/acontext-py/src/acontext/resources/async_sessions.py @@ -278,7 +278,9 @@ async def get_messages( time_desc: Order by created_at descending if True, ascending if False. Defaults to None. edit_strategies: Optional list of edit strategies to apply before format conversion. Each strategy is a dict with 'type' and 'params' keys. - Example: [{"type": "remove_tool_result", "params": {"keep_recent_n_tool_results": 3}}] + Examples: + - Remove tool results: [{"type": "remove_tool_result", "params": {"keep_recent_n_tool_results": 3}}] + - Token limit: [{"type": "token_limit", "params": {"limit_tokens": 20000}}] Defaults to None. Returns: diff --git a/src/client/acontext-py/src/acontext/resources/sessions.py b/src/client/acontext-py/src/acontext/resources/sessions.py index d73d53fd..496395e2 100644 --- a/src/client/acontext-py/src/acontext/resources/sessions.py +++ b/src/client/acontext-py/src/acontext/resources/sessions.py @@ -278,7 +278,9 @@ def get_messages( time_desc: Order by created_at descending if True, ascending if False. Defaults to None. edit_strategies: Optional list of edit strategies to apply before format conversion. Each strategy is a dict with 'type' and 'params' keys. - Example: [{"type": "remove_tool_result", "params": {"keep_recent_n_tool_results": 3}}] + Examples: + - Remove tool results: [{"type": "remove_tool_result", "params": {"keep_recent_n_tool_results": 3}}] + - Token limit: [{"type": "token_limit", "params": {"limit_tokens": 20000}}] Defaults to None. Returns: diff --git a/src/client/acontext-py/src/acontext/types/session.py b/src/client/acontext-py/src/acontext/types/session.py index 233b3e76..9e165dce 100644 --- a/src/client/acontext-py/src/acontext/types/session.py +++ b/src/client/acontext-py/src/acontext/types/session.py @@ -30,9 +30,36 @@ class RemoveToolResultStrategy(TypedDict): params: RemoveToolResultParams +class TokenLimitParams(TypedDict): + """Parameters for the token_limit edit strategy. + + Attributes: + limit_tokens: Maximum number of tokens to keep. Required parameter. + Messages will be removed from oldest to newest until total tokens <= limit_tokens. + Tool-call and tool-result pairs are always removed together. + """ + + limit_tokens: int + + +class TokenLimitStrategy(TypedDict): + """Edit strategy to truncate messages based on token count. + + Removes oldest messages until the total token count is within the specified limit. + Maintains tool-call/tool-result pairing - when removing a message with tool-calls, + the corresponding tool-result messages are also removed. + + Example: + {"type": "token_limit", "params": {"limit_tokens": 20000}} + """ + + type: Literal["token_limit"] + params: TokenLimitParams + + # Union type for all edit strategies # When adding new strategies, add them to this Union: EditStrategy = Union[RemoveToolResultStrategy, OtherStrategy, ...] -EditStrategy = Union[RemoveToolResultStrategy] +EditStrategy = Union[RemoveToolResultStrategy, TokenLimitStrategy] class Asset(BaseModel): diff --git a/src/client/acontext-ts/src/resources/sessions.ts b/src/client/acontext-ts/src/resources/sessions.ts index b48c8ddb..35bf3f6e 100644 --- a/src/client/acontext-ts/src/resources/sessions.ts +++ b/src/client/acontext-ts/src/resources/sessions.ts @@ -179,6 +179,22 @@ export class SessionsAPI { } } + /** + * Get messages for a session. + * + * @param sessionId - The UUID of the session. + * @param options - Options for retrieving messages. + * @param options.limit - Maximum number of messages to return. + * @param options.cursor - Cursor for pagination. + * @param options.withAssetPublicUrl - Whether to include presigned URLs for assets. + * @param options.format - The format of the messages ('acontext', 'openai', or 'anthropic'). + * @param options.timeDesc - Order by created_at descending if true, ascending if false. + * @param options.editStrategies - Optional list of edit strategies to apply before format conversion. + * Examples: + * - Remove tool results: [{ type: 'remove_tool_result', params: { keep_recent_n_tool_results: 3 } }] + * - Token limit: [{ type: 'token_limit', params: { limit_tokens: 20000 } }] + * @returns GetMessagesOutput containing the list of messages and pagination information. + */ async getMessages( sessionId: string, options?: { diff --git a/src/client/acontext-ts/src/types/session.ts b/src/client/acontext-ts/src/types/session.ts index a231b842..49b81e70 100644 --- a/src/client/acontext-ts/src/types/session.ts +++ b/src/client/acontext-ts/src/types/session.ts @@ -142,11 +142,44 @@ export const RemoveToolResultStrategySchema = z.object({ export type RemoveToolResultStrategy = z.infer; +/** + * Parameters for the token_limit edit strategy. + */ +export const TokenLimitParamsSchema = z.object({ + /** + * Maximum number of tokens to keep. Required parameter. + * Messages will be removed from oldest to newest until total tokens <= limit_tokens. + * Tool-call and tool-result pairs are always removed together. + */ + limit_tokens: z.number(), +}); + +export type TokenLimitParams = z.infer; + +/** + * Edit strategy to truncate messages based on token count. + * + * Removes oldest messages until the total token count is within the specified limit. + * Maintains tool-call/tool-result pairing - when removing a message with tool-calls, + * the corresponding tool-result messages are also removed. + * + * Example: { type: 'token_limit', params: { limit_tokens: 20000 } } + */ +export const TokenLimitStrategySchema = z.object({ + type: z.literal('token_limit'), + params: TokenLimitParamsSchema, +}); + +export type TokenLimitStrategy = z.infer; + /** * Union schema for all edit strategies. * When adding new strategies, extend this union: z.union([RemoveToolResultStrategySchema, OtherStrategySchema, ...]) */ -export const EditStrategySchema = z.union([RemoveToolResultStrategySchema]); +export const EditStrategySchema = z.union([ + RemoveToolResultStrategySchema, + TokenLimitStrategySchema, +]); export type EditStrategy = z.infer; diff --git a/src/server/api/go/docs/docs.go b/src/server/api/go/docs/docs.go index d1444a4d..92f355c9 100644 --- a/src/server/api/go/docs/docs.go +++ b/src/server/api/go/docs/docs.go @@ -1086,7 +1086,7 @@ const docTemplate = `{ }, { "type": "integer", - "description": "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned.", + "description": "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned. \n\nWARNING!\n Use ` + "`" + `limit` + "`" + ` only for read-only/display purposes (pagination, viewing). Do NOT use ` + "`" + `limit` + "`" + ` to truncate messages before sending to LLM as it may cause tool-call and tool-result unpairing issues. Instead, use the ` + "`" + `token_limit` + "`" + ` edit strategy in ` + "`" + `edit_strategies` + "`" + ` parameter to safely manage message context size.", "name": "limit", "in": "query" }, diff --git a/src/server/api/go/docs/swagger.json b/src/server/api/go/docs/swagger.json index fa3a0d83..c924072e 100644 --- a/src/server/api/go/docs/swagger.json +++ b/src/server/api/go/docs/swagger.json @@ -1083,7 +1083,7 @@ }, { "type": "integer", - "description": "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned.", + "description": "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned. \n\nWARNING!\n Use `limit` only for read-only/display purposes (pagination, viewing). Do NOT use `limit` to truncate messages before sending to LLM as it may cause tool-call and tool-result unpairing issues. Instead, use the `token_limit` edit strategy in `edit_strategies` parameter to safely manage message context size.", "name": "limit", "in": "query" }, diff --git a/src/server/api/go/docs/swagger.yaml b/src/server/api/go/docs/swagger.yaml index 32965435..ae45e9d4 100644 --- a/src/server/api/go/docs/swagger.yaml +++ b/src/server/api/go/docs/swagger.yaml @@ -1399,8 +1399,12 @@ paths: name: session_id required: true type: string - - description: Limit of messages to return. Max 200. If limit is 0 or not provided, - all messages will be returned. + - description: "Limit of messages to return. Max 200. If limit is 0 or not provided, + all messages will be returned. \n\nWARNING!\n Use `limit` only for read-only/display + purposes (pagination, viewing). Do NOT use `limit` to truncate messages + before sending to LLM as it may cause tool-call and tool-result unpairing + issues. Instead, use the `token_limit` edit strategy in `edit_strategies` + parameter to safely manage message context size." in: query name: limit type: integer diff --git a/src/server/api/go/internal/modules/handler/session.go b/src/server/api/go/internal/modules/handler/session.go index dd827da9..99cb272b 100644 --- a/src/server/api/go/internal/modules/handler/session.go +++ b/src/server/api/go/internal/modules/handler/session.go @@ -480,7 +480,7 @@ type GetMessagesReq struct { // @Accept json // @Produce json // @Param session_id path string true "Session ID" format(uuid) -// @Param limit query integer false "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned." +// @Param limit query integer false "Limit of messages to return. Max 200. If limit is 0 or not provided, all messages will be returned. \n\nWARNING!\n Use `limit` only for read-only/display purposes (pagination, viewing). Do NOT use `limit` to truncate messages before sending to LLM as it may cause tool-call and tool-result unpairing issues. Instead, use the `token_limit` edit strategy in `edit_strategies` parameter to safely manage message context size." // @Param cursor query string false "Cursor for pagination. Use the cursor from the previous response to get the next page." // @Param with_asset_public_url query string false "Whether to return asset public url, default is true" example(true) // @Param format query string false "Format to convert messages to: acontext (original), openai (default), anthropic." enums(acontext,openai,anthropic) diff --git a/src/server/api/go/internal/pkg/editor/editor.go b/src/server/api/go/internal/pkg/editor/editor.go index 8dd5c9ba..d2247e2d 100644 --- a/src/server/api/go/internal/pkg/editor/editor.go +++ b/src/server/api/go/internal/pkg/editor/editor.go @@ -2,7 +2,7 @@ package editor import ( "fmt" - "strings" + "sort" "github.com/memodb-io/Acontext/internal/modules/model" ) @@ -19,112 +19,63 @@ type StrategyConfig struct { Params map[string]interface{} `json:"params"` } -// RemoveToolResultStrategy replaces old tool-result parts' text with a placeholder -type RemoveToolResultStrategy struct { - KeepRecentN int - Placeholder string -} - -// Name returns the strategy name -func (s *RemoveToolResultStrategy) Name() string { - return "remove_tool_result" -} - -// Apply replaces old tool-result parts' text with a placeholder -// Keeps the most recent N tool-result parts with their original content -func (s *RemoveToolResultStrategy) Apply(messages []model.Message) ([]model.Message, error) { - if s.KeepRecentN < 0 { - return nil, fmt.Errorf("keep_recent_n_tool_results must be >= 0, got %d", s.KeepRecentN) - } - - // First, collect all tool-result parts with their positions - type toolResultPosition struct { - messageIdx int - partIdx int - } - var toolResultPositions []toolResultPosition - - for msgIdx, msg := range messages { - for partIdx, part := range msg.Parts { - if part.Type == "tool-result" { - toolResultPositions = append(toolResultPositions, toolResultPosition{ - messageIdx: msgIdx, - partIdx: partIdx, - }) - } - } - } - - // Calculate how many to replace (all except the most recent KeepRecentN) - totalToolResults := len(toolResultPositions) - if totalToolResults <= s.KeepRecentN { - // Nothing to replace - return messages, nil - } - - numToReplace := totalToolResults - s.KeepRecentN - - // Use the placeholder text (defaults to "Done" if not set) - placeholder := s.Placeholder - if placeholder == "" { - placeholder = "Done" - } - - // Replace the text of the oldest tool-result parts - for i := 0; i < numToReplace; i++ { - pos := toolResultPositions[i] - messages[pos.messageIdx].Parts[pos.partIdx].Text = placeholder - } - - return messages, nil -} - // CreateStrategy creates a strategy from a config func CreateStrategy(config StrategyConfig) (EditStrategy, error) { switch config.Type { case "remove_tool_result": - // Default to keeping 3 most recent tool results if parameter not provided - keepRecentNInt := 3 - - if keepRecentN, ok := config.Params["keep_recent_n_tool_results"]; ok { - // Handle both float64 (from JSON unmarshaling) and int - switch v := keepRecentN.(type) { - case float64: - keepRecentNInt = int(v) - case int: - keepRecentNInt = v - default: - return nil, fmt.Errorf("keep_recent_n_tool_results must be an integer, got %T", keepRecentN) - } - } - - // Get placeholder text (defaults to "Done" if not provided) - placeholder := "Done" - if placeholderValue, ok := config.Params["tool_result_placeholder"]; ok { - if placeholderStr, ok := placeholderValue.(string); ok { - placeholder = strings.TrimSpace(placeholderStr) - } else { - return nil, fmt.Errorf("tool_result_placeholder must be a string, got %T", placeholderValue) - } - } - - return &RemoveToolResultStrategy{ - KeepRecentN: keepRecentNInt, - Placeholder: placeholder, - }, nil + return createRemoveToolResultStrategy(config.Params) + case "token_limit": + return createTokenLimitStrategy(config.Params) default: return nil, fmt.Errorf("unknown strategy type: %s", config.Type) } } -// ApplyStrategies applies multiple editing strategies in sequence +// getStrategyPriority returns the priority of a strategy type. +// Lower numbers are applied first, higher numbers are applied last. +// This ensures strategies are executed in an optimal order. +func getStrategyPriority(strategyType string) int { + switch strategyType { + case "remove_tool_result": + return 1 // Content reduction strategies go first + case "token_limit": + return 100 // Token limit always goes last + default: + return 50 // unmarked strategies go in the middle + } +} + +// sortStrategies sorts strategy configs by their priority. +// This ensures strategies are applied in the optimal order: +// 1. Content reduction strategies (e.g., remove_tool_result) +// 2. Other strategies +// 3. Token limit (always last) +func sortStrategies(configs []StrategyConfig) []StrategyConfig { + // Create a copy to avoid modifying the original slice + sorted := make([]StrategyConfig, len(configs)) + copy(sorted, configs) + + // Sort by priority + sort.SliceStable(sorted, func(i, j int) bool { + return getStrategyPriority(sorted[i].Type) < getStrategyPriority(sorted[j].Type) + }) + + return sorted +} + +// ApplyStrategies applies multiple editing strategies in sequence. +// Strategies are automatically sorted to ensure optimal execution order, +// with token_limit always applied last. func ApplyStrategies(messages []model.Message, configs []StrategyConfig) ([]model.Message, error) { if len(configs) == 0 { return messages, nil } + // Sort strategies to ensure optimal execution order + sortedConfigs := sortStrategies(configs) + result := messages - for _, config := range configs { + for _, config := range sortedConfigs { strategy, err := CreateStrategy(config) if err != nil { return nil, fmt.Errorf("failed to create strategy: %w", err) diff --git a/src/server/api/go/internal/pkg/editor/editor_test.go b/src/server/api/go/internal/pkg/editor/editor_test.go index 881f7baa..f1899187 100644 --- a/src/server/api/go/internal/pkg/editor/editor_test.go +++ b/src/server/api/go/internal/pkg/editor/editor_test.go @@ -8,306 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestRemoveToolResultStrategy_Apply(t *testing.T) { - t.Run("replace oldest tool results", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "text", Text: "What's the weather?"}, - }, - }, - { - Role: "assistant", - Parts: []model.Part{ - {Type: "tool-call", Meta: map[string]interface{}{"id": "call1", "name": "get_weather"}}, - }, - }, - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Sunny, 75°F", Meta: map[string]interface{}{"tool_call_id": "call1"}}, - }, - }, - { - Role: "assistant", - Parts: []model.Part{ - {Type: "tool-call", Meta: map[string]interface{}{"id": "call2", "name": "get_forecast"}}, - }, - }, - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Clear skies tomorrow", Meta: map[string]interface{}{"tool_call_id": "call2"}}, - }, - }, - { - Role: "assistant", - Parts: []model.Part{ - {Type: "tool-call", Meta: map[string]interface{}{"id": "call3", "name": "get_temperature"}}, - }, - }, - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Current temp: 75°F", Meta: map[string]interface{}{"tool_call_id": "call3"}}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 1} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - assert.Len(t, result, 7) - - // First tool result should be replaced - assert.Equal(t, "Done", result[2].Parts[0].Text) - assert.Equal(t, "tool-result", result[2].Parts[0].Type) - assert.NotNil(t, result[2].Parts[0].Meta) - - // Second tool result should be replaced - assert.Equal(t, "Done", result[4].Parts[0].Text) - assert.Equal(t, "tool-result", result[4].Parts[0].Type) - - // Most recent tool result should keep original text - assert.Equal(t, "Current temp: 75°F", result[6].Parts[0].Text) - assert.Equal(t, "tool-result", result[6].Parts[0].Type) - }) - - t.Run("keep all when KeepRecentN >= total", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 1"}, - }, - }, - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 2"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 5} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - // Both should keep original text - assert.Equal(t, "Result 1", result[0].Parts[0].Text) - assert.Equal(t, "Result 2", result[1].Parts[0].Text) - }) - - t.Run("replace all when KeepRecentN is 0", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 1"}, - }, - }, - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 2"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 0} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - // All should be replaced - assert.Equal(t, "Done", result[0].Parts[0].Text) - assert.Equal(t, "Done", result[1].Parts[0].Text) - }) - - t.Run("no tool results in messages", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "text", Text: "Hello"}, - }, - }, - { - Role: "assistant", - Parts: []model.Part{ - {Type: "text", Text: "Hi there"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 1} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - assert.Len(t, result, 2) - assert.Equal(t, "Hello", result[0].Parts[0].Text) - assert.Equal(t, "Hi there", result[1].Parts[0].Text) - }) - - t.Run("multiple parts with some tool-results", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "text", Text: "Question"}, - {Type: "tool-result", Text: "Old result"}, - }, - }, - { - Role: "assistant", - Parts: []model.Part{ - {Type: "text", Text: "Answer"}, - {Type: "tool-result", Text: "Recent result"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 1} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - // First part should remain unchanged - assert.Equal(t, "Question", result[0].Parts[0].Text) - // First tool-result should be replaced - assert.Equal(t, "Done", result[0].Parts[1].Text) - // Second message first part should remain unchanged - assert.Equal(t, "Answer", result[1].Parts[0].Text) - // Recent tool-result should keep original text - assert.Equal(t, "Recent result", result[1].Parts[1].Text) - }) - - t.Run("negative KeepRecentN returns error", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: -1} - _, err := strategy.Apply(messages) - - require.Error(t, err) - assert.Contains(t, err.Error(), "must be >= 0") - }) - - t.Run("custom placeholder text", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 1"}, - }, - }, - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 2"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 1, Placeholder: "Removed"} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - // First should be replaced with custom placeholder - assert.Equal(t, "Removed", result[0].Parts[0].Text) - // Second should keep original - assert.Equal(t, "Result 2", result[1].Parts[0].Text) - }) - - t.Run("empty placeholder defaults to Done", func(t *testing.T) { - messages := []model.Message{ - { - Role: "user", - Parts: []model.Part{ - {Type: "tool-result", Text: "Result 1"}, - }, - }, - } - - strategy := &RemoveToolResultStrategy{KeepRecentN: 0, Placeholder: ""} - result, err := strategy.Apply(messages) - - require.NoError(t, err) - assert.Equal(t, "Done", result[0].Parts[0].Text) - }) -} - func TestCreateStrategy(t *testing.T) { - t.Run("create remove_tool_result strategy", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{ - "keep_recent_n_tool_results": float64(3), - }, - } - - strategy, err := CreateStrategy(config) - - require.NoError(t, err) - assert.NotNil(t, strategy) - assert.Equal(t, "remove_tool_result", strategy.Name()) - - rtr, ok := strategy.(*RemoveToolResultStrategy) - require.True(t, ok) - assert.Equal(t, 3, rtr.KeepRecentN) - }) - - t.Run("create with int parameter", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{ - "keep_recent_n_tool_results": 5, - }, - } - - strategy, err := CreateStrategy(config) - - require.NoError(t, err) - rtr, ok := strategy.(*RemoveToolResultStrategy) - require.True(t, ok) - assert.Equal(t, 5, rtr.KeepRecentN) - }) - - t.Run("use default value when parameter not provided", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{}, - } - - strategy, err := CreateStrategy(config) - - require.NoError(t, err) - assert.NotNil(t, strategy) - rtr, ok := strategy.(*RemoveToolResultStrategy) - require.True(t, ok) - assert.Equal(t, 3, rtr.KeepRecentN, "should default to 3 when parameter not provided") - }) - - t.Run("invalid parameter type", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{ - "keep_recent_n_tool_results": "invalid", - }, - } - - _, err := CreateStrategy(config) - - require.Error(t, err) - assert.Contains(t, err.Error(), "must be an integer") - }) - t.Run("unknown strategy type", func(t *testing.T) { config := StrategyConfig{ Type: "unknown_strategy", @@ -319,54 +20,6 @@ func TestCreateStrategy(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "unknown strategy type") }) - - t.Run("create with custom placeholder", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{ - "keep_recent_n_tool_results": 5, - "tool_result_placeholder": "Cleared", - }, - } - - strategy, err := CreateStrategy(config) - - require.NoError(t, err) - rtr, ok := strategy.(*RemoveToolResultStrategy) - require.True(t, ok) - assert.Equal(t, 5, rtr.KeepRecentN) - assert.Equal(t, "Cleared", rtr.Placeholder) - }) - - t.Run("invalid placeholder type returns error", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{ - "tool_result_placeholder": 123, // Not a string - }, - } - - _, err := CreateStrategy(config) - - require.Error(t, err) - assert.Contains(t, err.Error(), "tool_result_placeholder must be a string") - }) - - t.Run("placeholder string is trimmed", func(t *testing.T) { - config := StrategyConfig{ - Type: "remove_tool_result", - Params: map[string]interface{}{ - "tool_result_placeholder": " Trimmed ", - }, - } - - strategy, err := CreateStrategy(config) - - require.NoError(t, err) - rtr, ok := strategy.(*RemoveToolResultStrategy) - require.True(t, ok) - assert.Equal(t, "Trimmed", rtr.Placeholder, "should trim whitespace from placeholder") - }) } func TestApplyStrategies(t *testing.T) { diff --git a/src/server/api/go/internal/pkg/editor/strategy_remove_tool_result.go b/src/server/api/go/internal/pkg/editor/strategy_remove_tool_result.go new file mode 100644 index 00000000..150a87fc --- /dev/null +++ b/src/server/api/go/internal/pkg/editor/strategy_remove_tool_result.go @@ -0,0 +1,101 @@ +package editor + +import ( + "fmt" + "strings" + + "github.com/memodb-io/Acontext/internal/modules/model" +) + +// RemoveToolResultStrategy replaces old tool-result parts' text with a placeholder +type RemoveToolResultStrategy struct { + KeepRecentN int + Placeholder string +} + +// Name returns the strategy name +func (s *RemoveToolResultStrategy) Name() string { + return "remove_tool_result" +} + +// Apply replaces old tool-result parts' text with a placeholder +// Keeps the most recent N tool-result parts with their original content +func (s *RemoveToolResultStrategy) Apply(messages []model.Message) ([]model.Message, error) { + if s.KeepRecentN < 0 { + return nil, fmt.Errorf("keep_recent_n_tool_results must be >= 0, got %d", s.KeepRecentN) + } + + // First, collect all tool-result parts with their positions + type toolResultPosition struct { + messageIdx int + partIdx int + } + var toolResultPositions []toolResultPosition + + for msgIdx, msg := range messages { + for partIdx, part := range msg.Parts { + if part.Type == "tool-result" { + toolResultPositions = append(toolResultPositions, toolResultPosition{ + messageIdx: msgIdx, + partIdx: partIdx, + }) + } + } + } + + // Calculate how many to replace (all except the most recent KeepRecentN) + totalToolResults := len(toolResultPositions) + if totalToolResults <= s.KeepRecentN { + // Nothing to replace + return messages, nil + } + + numToReplace := totalToolResults - s.KeepRecentN + + // Use the placeholder text (defaults to "Done" if not set) + placeholder := s.Placeholder + if placeholder == "" { + placeholder = "Done" + } + + // Replace the text of the oldest tool-result parts + for i := 0; i < numToReplace; i++ { + pos := toolResultPositions[i] + messages[pos.messageIdx].Parts[pos.partIdx].Text = placeholder + } + + return messages, nil +} + +// createRemoveToolResultStrategy creates a RemoveToolResultStrategy from config params +func createRemoveToolResultStrategy(params map[string]interface{}) (EditStrategy, error) { + // Default to keeping 3 most recent tool results if parameter not provided + keepRecentNInt := 3 + + if keepRecentN, ok := params["keep_recent_n_tool_results"]; ok { + // Handle both float64 (from JSON unmarshaling) and int + switch v := keepRecentN.(type) { + case float64: + keepRecentNInt = int(v) + case int: + keepRecentNInt = v + default: + return nil, fmt.Errorf("keep_recent_n_tool_results must be an integer, got %T", keepRecentN) + } + } + + // Get placeholder text (defaults to "Done" if not provided) + placeholder := "Done" + if placeholderValue, ok := params["tool_result_placeholder"]; ok { + if placeholderStr, ok := placeholderValue.(string); ok { + placeholder = strings.TrimSpace(placeholderStr) + } else { + return nil, fmt.Errorf("tool_result_placeholder must be a string, got %T", placeholderValue) + } + } + + return &RemoveToolResultStrategy{ + KeepRecentN: keepRecentNInt, + Placeholder: placeholder, + }, nil +} diff --git a/src/server/api/go/internal/pkg/editor/strategy_remove_tool_result_test.go b/src/server/api/go/internal/pkg/editor/strategy_remove_tool_result_test.go new file mode 100644 index 00000000..3622be99 --- /dev/null +++ b/src/server/api/go/internal/pkg/editor/strategy_remove_tool_result_test.go @@ -0,0 +1,358 @@ +package editor + +import ( + "testing" + + "github.com/memodb-io/Acontext/internal/modules/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRemoveToolResultStrategy_Apply(t *testing.T) { + t.Run("replace oldest tool results", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "What's the weather?"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call1", "name": "get_weather"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Sunny, 75°F", Meta: map[string]interface{}{"tool_call_id": "call1"}}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call2", "name": "get_forecast"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Clear skies tomorrow", Meta: map[string]interface{}{"tool_call_id": "call2"}}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call3", "name": "get_temperature"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Current temp: 75°F", Meta: map[string]interface{}{"tool_call_id": "call3"}}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 1} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Len(t, result, 7) + + // First tool result should be replaced + assert.Equal(t, "Done", result[2].Parts[0].Text) + assert.Equal(t, "tool-result", result[2].Parts[0].Type) + assert.NotNil(t, result[2].Parts[0].Meta) + + // Second tool result should be replaced + assert.Equal(t, "Done", result[4].Parts[0].Text) + assert.Equal(t, "tool-result", result[4].Parts[0].Type) + + // Most recent tool result should keep original text + assert.Equal(t, "Current temp: 75°F", result[6].Parts[0].Text) + assert.Equal(t, "tool-result", result[6].Parts[0].Type) + }) + + t.Run("keep all when KeepRecentN >= total", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 1"}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 2"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 5} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + // Both should keep original text + assert.Equal(t, "Result 1", result[0].Parts[0].Text) + assert.Equal(t, "Result 2", result[1].Parts[0].Text) + }) + + t.Run("replace all when KeepRecentN is 0", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 1"}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 2"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 0} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + // All should be replaced + assert.Equal(t, "Done", result[0].Parts[0].Text) + assert.Equal(t, "Done", result[1].Parts[0].Text) + }) + + t.Run("no tool results in messages", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Hello"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "Hi there"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 1} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Len(t, result, 2) + assert.Equal(t, "Hello", result[0].Parts[0].Text) + assert.Equal(t, "Hi there", result[1].Parts[0].Text) + }) + + t.Run("multiple parts with some tool-results", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Question"}, + {Type: "tool-result", Text: "Old result"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "Answer"}, + {Type: "tool-result", Text: "Recent result"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 1} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + // First part should remain unchanged + assert.Equal(t, "Question", result[0].Parts[0].Text) + // First tool-result should be replaced + assert.Equal(t, "Done", result[0].Parts[1].Text) + // Second message first part should remain unchanged + assert.Equal(t, "Answer", result[1].Parts[0].Text) + // Recent tool-result should keep original text + assert.Equal(t, "Recent result", result[1].Parts[1].Text) + }) + + t.Run("negative KeepRecentN returns error", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: -1} + _, err := strategy.Apply(messages) + + require.Error(t, err) + assert.Contains(t, err.Error(), "must be >= 0") + }) + + t.Run("custom placeholder text", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 1"}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 2"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 1, Placeholder: "Removed"} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + // First should be replaced with custom placeholder + assert.Equal(t, "Removed", result[0].Parts[0].Text) + // Second should keep original + assert.Equal(t, "Result 2", result[1].Parts[0].Text) + }) + + t.Run("empty placeholder defaults to Done", func(t *testing.T) { + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 1"}, + }, + }, + } + + strategy := &RemoveToolResultStrategy{KeepRecentN: 0, Placeholder: ""} + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Equal(t, "Done", result[0].Parts[0].Text) + }) +} + +func TestCreateRemoveToolResultStrategy(t *testing.T) { + t.Run("create remove_tool_result strategy", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{ + "keep_recent_n_tool_results": float64(3), + }, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + assert.NotNil(t, strategy) + assert.Equal(t, "remove_tool_result", strategy.Name()) + + rtr, ok := strategy.(*RemoveToolResultStrategy) + require.True(t, ok) + assert.Equal(t, 3, rtr.KeepRecentN) + }) + + t.Run("create with int parameter", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{ + "keep_recent_n_tool_results": 5, + }, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + rtr, ok := strategy.(*RemoveToolResultStrategy) + require.True(t, ok) + assert.Equal(t, 5, rtr.KeepRecentN) + }) + + t.Run("use default value when parameter not provided", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{}, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + assert.NotNil(t, strategy) + rtr, ok := strategy.(*RemoveToolResultStrategy) + require.True(t, ok) + assert.Equal(t, 3, rtr.KeepRecentN, "should default to 3 when parameter not provided") + }) + + t.Run("invalid parameter type", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{ + "keep_recent_n_tool_results": "invalid", + }, + } + + _, err := CreateStrategy(config) + + require.Error(t, err) + assert.Contains(t, err.Error(), "must be an integer") + }) + + t.Run("create with custom placeholder", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{ + "keep_recent_n_tool_results": 5, + "tool_result_placeholder": "Cleared", + }, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + rtr, ok := strategy.(*RemoveToolResultStrategy) + require.True(t, ok) + assert.Equal(t, 5, rtr.KeepRecentN) + assert.Equal(t, "Cleared", rtr.Placeholder) + }) + + t.Run("invalid placeholder type returns error", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{ + "tool_result_placeholder": 123, // Not a string + }, + } + + _, err := CreateStrategy(config) + + require.Error(t, err) + assert.Contains(t, err.Error(), "tool_result_placeholder must be a string") + }) + + t.Run("placeholder string is trimmed", func(t *testing.T) { + config := StrategyConfig{ + Type: "remove_tool_result", + Params: map[string]interface{}{ + "tool_result_placeholder": " Trimmed ", + }, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + rtr, ok := strategy.(*RemoveToolResultStrategy) + require.True(t, ok) + assert.Equal(t, "Trimmed", rtr.Placeholder, "should trim whitespace from placeholder") + }) +} diff --git a/src/server/api/go/internal/pkg/editor/strategy_token_limit.go b/src/server/api/go/internal/pkg/editor/strategy_token_limit.go new file mode 100644 index 00000000..e044d86e --- /dev/null +++ b/src/server/api/go/internal/pkg/editor/strategy_token_limit.go @@ -0,0 +1,133 @@ +package editor + +import ( + "context" + "fmt" + + "github.com/memodb-io/Acontext/internal/modules/model" + "github.com/memodb-io/Acontext/internal/pkg/tokenizer" +) + +// TokenLimitStrategy removes oldest messages until total token count is within limit +type TokenLimitStrategy struct { + LimitTokens int +} + +// Name returns the strategy name +func (s *TokenLimitStrategy) Name() string { + return "token_limit" +} + +// Apply removes oldest messages until total token count is within the limit +// Maintains tool-call/tool-result pairing +func (s *TokenLimitStrategy) Apply(messages []model.Message) ([]model.Message, error) { + if s.LimitTokens <= 0 { + return nil, fmt.Errorf("limit_tokens must be > 0, got %d", s.LimitTokens) + } + + if len(messages) == 0 { + return messages, nil + } + + ctx := context.Background() + + // Count total tokens + totalTokens, err := tokenizer.CountMessagePartsTokens(ctx, messages) + if err != nil { + return nil, fmt.Errorf("failed to count tokens: %w", err) + } + + // If already within limit, return as-is + if totalTokens <= s.LimitTokens { + return messages, nil + } + + // Build a map of tool-call IDs to their corresponding tool-result message indices + // This allows O(1) lookup when we need to remove paired tool-results + toolCallIDToResultIndex := make(map[string]int) + for i, msg := range messages { + for _, part := range msg.Parts { + if part.Type == "tool-result" && part.Meta != nil { + if toolCallID, ok := part.Meta["tool_call_id"].(string); ok { + toolCallIDToResultIndex[toolCallID] = i + } + } + } + } + + // Mark messages to remove, starting from the oldest + toRemove := make(map[int]bool) + + // Remove messages one by one until we're within the limit + for i := 0; i < len(messages) && totalTokens > s.LimitTokens; i++ { + if toRemove[i] { + continue // Already marked for removal + } + + // Count tokens for this message + msgTokens, err := tokenizer.CountSingleMessageTokens(ctx, messages[i]) + if err != nil { + return nil, fmt.Errorf("failed to count tokens for message %d: %w", i, err) + } + + // Mark this message for removal + toRemove[i] = true + totalTokens -= msgTokens + + // Check if this message has tool-call parts and remove corresponding tool-results + for _, part := range messages[i].Parts { + if part.Type == "tool-call" && part.Meta != nil { + if id, ok := part.Meta["id"].(string); ok { + // Use the map to find the corresponding tool-result message (O(1) lookup) + if resultIdx, found := toolCallIDToResultIndex[id]; found && !toRemove[resultIdx] { + // Mark the tool-result message for removal + resultTokens, err := tokenizer.CountSingleMessageTokens(ctx, messages[resultIdx]) + if err != nil { + return nil, fmt.Errorf("failed to count tokens for message %d: %w", resultIdx, err) + } + toRemove[resultIdx] = true + totalTokens -= resultTokens + } + } + } + } + } + + // Build the result by excluding removed messages + result := make([]model.Message, 0, len(messages)-len(toRemove)) + for i, msg := range messages { + if !toRemove[i] { + result = append(result, msg) + } + } + + return result, nil +} + +// createTokenLimitStrategy creates a TokenLimitStrategy from config params +func createTokenLimitStrategy(params map[string]interface{}) (EditStrategy, error) { + // Extract limit_tokens parameter (required) + limitTokens, ok := params["limit_tokens"] + if !ok { + return nil, fmt.Errorf("token_limit strategy requires 'limit_tokens' parameter") + } + + var limitTokensInt int + // Handle both float64 (from JSON unmarshaling) and int + switch v := limitTokens.(type) { + case float64: + limitTokensInt = int(v) + case int: + limitTokensInt = v + default: + return nil, fmt.Errorf("limit_tokens must be an integer, got %T", limitTokens) + } + + if limitTokensInt <= 0 { + return nil, fmt.Errorf("limit_tokens must be > 0, got %d", limitTokensInt) + } + + return &TokenLimitStrategy{ + LimitTokens: limitTokensInt, + }, nil +} diff --git a/src/server/api/go/internal/pkg/editor/strategy_token_limit_test.go b/src/server/api/go/internal/pkg/editor/strategy_token_limit_test.go new file mode 100644 index 00000000..519ab23d --- /dev/null +++ b/src/server/api/go/internal/pkg/editor/strategy_token_limit_test.go @@ -0,0 +1,591 @@ +package editor + +import ( + "context" + "testing" + + "github.com/memodb-io/Acontext/internal/modules/model" + "github.com/memodb-io/Acontext/internal/pkg/tokenizer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +// initTokenizer is a helper to initialize the tokenizer for tests +func initTokenizer(t *testing.T) { + t.Helper() + log := zaptest.NewLogger(t) + err := tokenizer.Init(log) + require.NoError(t, err, "failed to initialize tokenizer") +} + +// TestCreateTokenLimitStrategy tests the factory function for TokenLimitStrategy +func TestCreateTokenLimitStrategy(t *testing.T) { + t.Run("create with valid parameters", func(t *testing.T) { + config := StrategyConfig{ + Type: "token_limit", + Params: map[string]interface{}{ + "limit_tokens": float64(1000), + }, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + assert.NotNil(t, strategy) + assert.Equal(t, "token_limit", strategy.Name()) + + tls, ok := strategy.(*TokenLimitStrategy) + require.True(t, ok) + assert.Equal(t, 1000, tls.LimitTokens) + }) + + t.Run("create with int parameter", func(t *testing.T) { + config := StrategyConfig{ + Type: "token_limit", + Params: map[string]interface{}{ + "limit_tokens": 2000, + }, + } + + strategy, err := CreateStrategy(config) + + require.NoError(t, err) + tls, ok := strategy.(*TokenLimitStrategy) + require.True(t, ok) + assert.Equal(t, 2000, tls.LimitTokens) + }) + + t.Run("missing limit_tokens parameter", func(t *testing.T) { + config := StrategyConfig{ + Type: "token_limit", + Params: map[string]interface{}{}, + } + + _, err := CreateStrategy(config) + + require.Error(t, err) + assert.Contains(t, err.Error(), "requires 'limit_tokens' parameter") + }) + + t.Run("invalid parameter type", func(t *testing.T) { + config := StrategyConfig{ + Type: "token_limit", + Params: map[string]interface{}{ + "limit_tokens": "invalid", + }, + } + + _, err := CreateStrategy(config) + + require.Error(t, err) + assert.Contains(t, err.Error(), "must be an integer") + }) + + t.Run("zero limit_tokens", func(t *testing.T) { + config := StrategyConfig{ + Type: "token_limit", + Params: map[string]interface{}{ + "limit_tokens": 0, + }, + } + + _, err := CreateStrategy(config) + + require.Error(t, err) + assert.Contains(t, err.Error(), "must be > 0") + }) + + t.Run("negative limit_tokens", func(t *testing.T) { + config := StrategyConfig{ + Type: "token_limit", + Params: map[string]interface{}{ + "limit_tokens": -100, + }, + } + + _, err := CreateStrategy(config) + + require.Error(t, err) + assert.Contains(t, err.Error(), "must be > 0") + }) +} + +// TestTokenLimitStrategy_EmptyMessages tests handling of empty message arrays +func TestTokenLimitStrategy_EmptyMessages(t *testing.T) { + t.Run("empty messages array", func(t *testing.T) { + strategy := &TokenLimitStrategy{LimitTokens: 1000} + messages := []model.Message{} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Empty(t, result) + assert.Len(t, result, 0) + }) + + t.Run("nil messages array", func(t *testing.T) { + strategy := &TokenLimitStrategy{LimitTokens: 1000} + var messages []model.Message + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Nil(t, result) + }) +} + +// TestTokenLimitStrategy_MessagesWithinLimit tests that messages under the limit are unchanged +func TestTokenLimitStrategy_MessagesWithinLimit(t *testing.T) { + t.Run("small messages under high limit", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Hello"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "Hi there!"}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "How are you?"}, + }, + }, + } + + // Count actual tokens + ctx := context.Background() + actualTokens, err := tokenizer.CountMessagePartsTokens(ctx, messages) + require.NoError(t, err) + + // Set limit well above actual token count + strategy := &TokenLimitStrategy{LimitTokens: actualTokens + 1000} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Len(t, result, len(messages), "all messages should be kept") + assert.Equal(t, messages, result, "messages should be unchanged") + }) + + t.Run("messages exactly at limit", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Testing exact boundary"}, + }, + }, + } + + // Count actual tokens and set limit to exact amount + ctx := context.Background() + actualTokens, err := tokenizer.CountMessagePartsTokens(ctx, messages) + require.NoError(t, err) + + strategy := &TokenLimitStrategy{LimitTokens: actualTokens} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Len(t, result, len(messages), "all messages should be kept when exactly at limit") + assert.Equal(t, messages, result) + }) +} + +// TestTokenLimitStrategy_MessagesExceedingLimit tests that oldest messages are removed when limit exceeded +func TestTokenLimitStrategy_MessagesExceedingLimit(t *testing.T) { + t.Run("remove oldest messages to get under limit", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "First message - this should be removed"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "Second message - this should be removed too"}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Third message - this should be kept"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "Fourth message - this should be kept as well"}, + }, + }, + } + + // Count tokens for messages we want to keep (last 2) + ctx := context.Background() + messagesToKeep := messages[2:] + tokensToKeep, err := tokenizer.CountMessagePartsTokens(ctx, messagesToKeep) + require.NoError(t, err) + + // Set limit to keep only last 2 messages (with small buffer) + strategy := &TokenLimitStrategy{LimitTokens: tokensToKeep + 5} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Less(t, len(result), len(messages), "some messages should be removed") + + // Verify result is under token limit + resultTokens, err := tokenizer.CountMessagePartsTokens(ctx, result) + require.NoError(t, err) + assert.LessOrEqual(t, resultTokens, strategy.LimitTokens, "result should be under token limit") + + // Verify the oldest messages were removed (check by content) + if len(result) > 0 { + assert.NotContains(t, result[0].Parts[0].Text, "First message", "oldest message should be removed") + } + }) + + t.Run("remove multiple messages when needed", func(t *testing.T) { + initTokenizer(t) + + // Create many messages + messages := []model.Message{ + {Role: "user", Parts: []model.Part{{Type: "text", Text: "Message 1 - remove"}}}, + {Role: "assistant", Parts: []model.Part{{Type: "text", Text: "Message 2 - remove"}}}, + {Role: "user", Parts: []model.Part{{Type: "text", Text: "Message 3 - remove"}}}, + {Role: "assistant", Parts: []model.Part{{Type: "text", Text: "Message 4 - remove"}}}, + {Role: "user", Parts: []model.Part{{Type: "text", Text: "Message 5 - keep"}}}, + {Role: "assistant", Parts: []model.Part{{Type: "text", Text: "Message 6 - keep"}}}, + } + + // Count tokens for last message only + ctx := context.Background() + lastMessage := messages[len(messages)-1:] + tokensForLast, err := tokenizer.CountMessagePartsTokens(ctx, lastMessage) + require.NoError(t, err) + + // Set a very low limit to force removal of most messages + strategy := &TokenLimitStrategy{LimitTokens: tokensForLast + 10} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + assert.Less(t, len(result), len(messages), "multiple messages should be removed") + + // Verify result is under limit + resultTokens, err := tokenizer.CountMessagePartsTokens(ctx, result) + require.NoError(t, err) + assert.LessOrEqual(t, resultTokens, strategy.LimitTokens) + }) + + t.Run("very low limit removes all or nearly all messages", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + {Role: "user", Parts: []model.Part{{Type: "text", Text: "This is a relatively long message that will definitely exceed a very small token limit"}}}, + {Role: "assistant", Parts: []model.Part{{Type: "text", Text: "Another message"}}}, + } + + // Set an extremely low limit + strategy := &TokenLimitStrategy{LimitTokens: 5} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + // Result should have very few or no messages + assert.LessOrEqual(t, len(result), len(messages)) + + // If there are any messages, verify under limit + if len(result) > 0 { + ctx := context.Background() + resultTokens, err := tokenizer.CountMessagePartsTokens(ctx, result) + require.NoError(t, err) + assert.LessOrEqual(t, resultTokens, strategy.LimitTokens) + } + }) +} + +// TestTokenLimitStrategy_ToolCallPairing tests that tool-call and tool-result pairs are removed together +func TestTokenLimitStrategy_ToolCallPairing(t *testing.T) { + t.Run("remove tool-call with its paired tool-result", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "What's the weather?"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call_123", "name": "get_weather"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Sunny, 75°F", Meta: map[string]interface{}{"tool_call_id": "call_123"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Thank you!"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "You're welcome!"}, + }, + }, + } + + // Count tokens for last 2 messages only + ctx := context.Background() + lastTwo := messages[3:] + tokensForLastTwo, err := tokenizer.CountMessagePartsTokens(ctx, lastTwo) + require.NoError(t, err) + + // Set limit to keep only last 2 messages, forcing removal of tool-call pair + strategy := &TokenLimitStrategy{LimitTokens: tokensForLastTwo + 5} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + + // Verify neither the tool-call nor its result are in the output + hasToolCall := false + hasToolResult := false + for _, msg := range result { + for _, part := range msg.Parts { + if part.Type == "tool-call" { + if meta, ok := part.Meta["id"].(string); ok && meta == "call_123" { + hasToolCall = true + } + } + if part.Type == "tool-result" { + if meta, ok := part.Meta["tool_call_id"].(string); ok && meta == "call_123" { + hasToolResult = true + } + } + } + } + + assert.False(t, hasToolCall, "tool-call should be removed") + assert.False(t, hasToolResult, "tool-result should be removed with its tool-call") + }) + + t.Run("multiple tool-call pairs removed together", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call_1", "name": "tool1"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 1", Meta: map[string]interface{}{"tool_call_id": "call_1"}}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call_2", "name": "tool2"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result 2", Meta: map[string]interface{}{"tool_call_id": "call_2"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Final message to keep"}, + }, + }, + } + + // Count tokens for last message only + ctx := context.Background() + lastMessage := messages[4:] + tokensForLast, err := tokenizer.CountMessagePartsTokens(ctx, lastMessage) + require.NoError(t, err) + + // Set very low limit to remove tool pairs + strategy := &TokenLimitStrategy{LimitTokens: tokensForLast + 10} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + + // Verify no tool-calls or tool-results remain + for _, msg := range result { + for _, part := range msg.Parts { + assert.NotEqual(t, "tool-call", part.Type, "all tool-calls should be removed") + assert.NotEqual(t, "tool-result", part.Type, "all tool-results should be removed") + } + } + }) + + t.Run("assistant message with multiple tool-calls", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "assistant", + Parts: []model.Part{ + {Type: "tool-call", Meta: map[string]interface{}{"id": "call_a", "name": "tool_a"}}, + {Type: "tool-call", Meta: map[string]interface{}{"id": "call_b", "name": "tool_b"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result A", Meta: map[string]interface{}{"tool_call_id": "call_a"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "tool-result", Text: "Result B", Meta: map[string]interface{}{"tool_call_id": "call_b"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Keep this message"}, + }, + }, + } + + // Count tokens for last message + ctx := context.Background() + lastMessage := messages[3:] + tokensForLast, err := tokenizer.CountMessagePartsTokens(ctx, lastMessage) + require.NoError(t, err) + + // Set limit to keep only last message + strategy := &TokenLimitStrategy{LimitTokens: tokensForLast + 5} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + + // When assistant message with multiple tool-calls is removed, + // both tool-results should also be removed + hasCallA := false + hasCallB := false + hasResultA := false + hasResultB := false + + for _, msg := range result { + for _, part := range msg.Parts { + if part.Meta != nil { + if id, ok := part.Meta["id"].(string); ok { + if id == "call_a" { + hasCallA = true + } + if id == "call_b" { + hasCallB = true + } + } + if id, ok := part.Meta["tool_call_id"].(string); ok { + if id == "call_a" { + hasResultA = true + } + if id == "call_b" { + hasResultB = true + } + } + } + } + } + + assert.False(t, hasCallA, "tool-call A should be removed") + assert.False(t, hasCallB, "tool-call B should be removed") + assert.False(t, hasResultA, "tool-result A should be removed with its call") + assert.False(t, hasResultB, "tool-result B should be removed with its call") + }) + + t.Run("orphaned tool-result without matching call", func(t *testing.T) { + initTokenizer(t) + + messages := []model.Message{ + { + Role: "user", + Parts: []model.Part{ + // Orphaned tool-result (no matching tool-call in messages) + {Type: "tool-result", Text: "Orphaned result with some text to make it have tokens", Meta: map[string]interface{}{"tool_call_id": "nonexistent"}}, + }, + }, + { + Role: "user", + Parts: []model.Part{ + {Type: "text", Text: "Second message with content"}, + }, + }, + { + Role: "assistant", + Parts: []model.Part{ + {Type: "text", Text: "Final message"}, + }, + }, + } + + // Count tokens for last message only + ctx := context.Background() + lastMessage := messages[2:] + tokensForLast, err := tokenizer.CountMessagePartsTokens(ctx, lastMessage) + require.NoError(t, err) + + // Set limit to keep only last message, forcing removal of first two + strategy := &TokenLimitStrategy{LimitTokens: tokensForLast + 2} + + result, err := strategy.Apply(messages) + + require.NoError(t, err) + + // Orphaned tool-result should be removable independently (not kept just because it's a tool-result) + hasOrphanedResult := false + for _, msg := range result { + for _, part := range msg.Parts { + if part.Type == "tool-result" { + if part.Meta != nil { + if meta, ok := part.Meta["tool_call_id"].(string); ok && meta == "nonexistent" { + hasOrphanedResult = true + } + } + } + } + } + + assert.False(t, hasOrphanedResult, "orphaned tool-result can be removed independently") + + // Verify we kept fewer messages + assert.Less(t, len(result), len(messages), "some messages should be removed") + }) +} diff --git a/src/server/api/go/internal/pkg/tokenizer/tokenizer.go b/src/server/api/go/internal/pkg/tokenizer/tokenizer.go index bd3fcc0b..e10b5b16 100644 --- a/src/server/api/go/internal/pkg/tokenizer/tokenizer.go +++ b/src/server/api/go/internal/pkg/tokenizer/tokenizer.go @@ -80,23 +80,35 @@ func ExtractTextAndToolContent(parts []model.Part) (string, error) { return content.String(), nil } +// CountSingleMessageTokens counts tokens for a single message +func CountSingleMessageTokens(ctx context.Context, message model.Message) (int, error) { + content, err := ExtractTextAndToolContent(message.Parts) + if err != nil { + return 0, fmt.Errorf("failed to extract content from message %s: %w", message.ID, err) + } + + if content == "" { + return 0, nil + } + + count, err := CountTokens(content) + if err != nil { + return 0, fmt.Errorf("failed to count tokens for message %s: %w", message.ID, err) + } + + return count, nil +} + // CountMessagePartsTokens counts tokens for all text and tool-call parts in messages func CountMessagePartsTokens(ctx context.Context, messages []model.Message) (int, error) { totalTokens := 0 for _, msg := range messages { - content, err := ExtractTextAndToolContent(msg.Parts) + count, err := CountSingleMessageTokens(ctx, msg) if err != nil { - return 0, fmt.Errorf("failed to extract content from message %s: %w", msg.ID, err) - } - - if content != "" { - count, err := CountTokens(content) - if err != nil { - return 0, fmt.Errorf("failed to count tokens for message %s: %w", msg.ID, err) - } - totalTokens += count + return 0, err } + totalTokens += count } return totalTokens, nil