diff --git a/Sources/Agent/chat/openaiChat.swift b/Sources/Agent/chat/openaiChat.swift index 42577dd..0813866 100644 --- a/Sources/Agent/chat/openaiChat.swift +++ b/Sources/Agent/chat/openaiChat.swift @@ -219,6 +219,13 @@ public struct OpenAIUserMessage: Hashable, Codable, Sendable { self.content = try container.decode(String.self, forKey: .content) self.createdAt = try container.decodeIfPresent(Date.self, forKey: .createdAt) ?? Date() } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(role, forKey: .role) + try container.encode(content, forKey: .content) + // Exclude id and createdAt - not part of OpenAI API spec + } } public struct OpenAIAssistantMessage: Hashable, Codable, Sendable { @@ -295,6 +302,14 @@ public struct OpenAIAssistantMessage: Hashable, Codable, Sendable { reasoningDetails: reasoningDetails ) } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(role, forKey: .role) + try container.encodeIfPresent(content, forKey: .content) + try container.encodeIfPresent(toolCalls, forKey: .toolCalls) + // Exclude id, audio, reasoning, reasoningDetails - not part of request spec + } } public struct OpenAISystemMessage: Hashable, Codable, Sendable { @@ -319,6 +334,13 @@ public struct OpenAISystemMessage: Hashable, Codable, Sendable { self.role = try container.decodeIfPresent(OpenAIRole.self, forKey: .role) ?? .system self.content = try container.decode(String.self, forKey: .content) } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(role, forKey: .role) + try container.encode(content, forKey: .content) + // Exclude id - not part of OpenAI API spec + } } public struct OpenAIToolMessage: Hashable, Codable, Sendable { @@ -351,6 +373,15 @@ public struct OpenAIToolMessage: Hashable, Codable, Sendable { self.toolCallId = try container.decode(String.self, forKey: .toolCallId) self.name = try container.decodeIfPresent(String.self, forKey: .name) } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(role, forKey: .role) + try container.encode(content, forKey: .content) + try container.encode(toolCallId, forKey: .toolCallId) + try container.encodeIfPresent(name, forKey: .name) + // Exclude id - not part of OpenAI API spec + } } public enum OpenAIMessage: Hashable, Codable, Sendable { diff --git a/Sources/AgentLayout/ChatProvider.swift b/Sources/AgentLayout/ChatProvider.swift index 812c988..1aa3775 100644 --- a/Sources/AgentLayout/ChatProvider.swift +++ b/Sources/AgentLayout/ChatProvider.swift @@ -158,7 +158,7 @@ public class ChatProvider: ChatProviderProtocol { encoder.outputFormatting = .prettyPrinted let resultString: String if let data = try? encoder.encode(AnyEncodable(result)), - let jsonString = String(data: data, encoding: .utf8) + let jsonString = String(data: data, encoding: .utf8) { resultString = jsonString } else { @@ -171,8 +171,8 @@ public class ChatProvider: ChatProviderProtocol { if let chat = self.chat { for message in chat.messages.reversed() { if case .openai(let openAIMsg) = message, - case .assistant(let assistant) = openAIMsg, - let toolCalls = assistant.toolCalls + case .assistant(let assistant) = openAIMsg, + let toolCalls = assistant.toolCalls { if let toolCall = toolCalls.first(where: { $0.id == id }) { toolName = toolCall.function?.name @@ -341,7 +341,8 @@ public class ChatProvider: ChatProviderProtocol { public func send(_ message: String) { guard generationTask == nil else { return } - guard var chat = chat, let currentSource = currentSource, let currentModel = currentModel else { return } + guard var chat = chat, let currentSource = currentSource, let currentModel = currentModel + else { return } let userMsg = Message.openai(.user(.init(content: message))) chat.messages.append(userMsg) @@ -353,14 +354,19 @@ public class ChatProvider: ChatProviderProtocol { self?.scrollToBottom?() } - let source = currentSource - let model = currentModel + startGeneration(source: currentSource, model: currentModel, userMessage: message) + } + /// Internal method to start generation without adding a user message. + /// Used by both `send` and `regenerate`. + private func startGeneration(source: Source, model: Model, userMessage: String? = nil) { generationTask = Task { [weak self] in guard let self = self else { return } self.status = .loading - try? await self.sendMessage(message: message) + if let message = userMessage { + try? await self.sendMessage(message: message) + } do { var messagesToSend = self.chat?.messages ?? [] @@ -493,7 +499,9 @@ public class ChatProvider: ChatProviderProtocol { guard generationTask == nil else { return } guard let chat = chat else { return } guard let index = chat.messages.firstIndex(where: { $0.id == messageId }) else { return } + guard let currentSource = currentSource, let currentModel = currentModel else { return } + // Find the user message content before the target message var userMessageContent: String? = nil for i in stride(from: index - 1, through: 0, by: -1) { if case .openai(let openAIMsg) = chat.messages[i], @@ -504,11 +512,14 @@ public class ChatProvider: ChatProviderProtocol { } } - guard let content = userMessageContent else { return } + guard userMessageContent != nil else { return } + // Remove the target message and all subsequent messages self.chat?.messages.removeSubrange(index...) notifyMessageChange() - send(content) + + // Start generation without adding a new user message + startGeneration(source: currentSource, model: currentModel) } public func cancel() { @@ -624,4 +635,13 @@ public class ChatProvider: ChatProviderProtocol { public func updateSystemPrompt(_ newSystemPrompt: String?) { self.systemPrompt = newSystemPrompt } + + /** + Regenerates the conversation starting at the given message. + This will remove all messages after the given message and regenerate the conversation from the given message. + - Parameter message: The message to start regenerating from. + */ + public func regenerate(startsAt message: Message) async throws { + + } } diff --git a/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift b/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift new file mode 100644 index 0000000..101093b --- /dev/null +++ b/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift @@ -0,0 +1,434 @@ +// +// ChatProviderRegenerateTests.swift +// AgentLayoutTests +// +// Tests for ChatProvider.regenerate function +// + +import Foundation +import SwiftUI +import Testing +import Vapor +import XCTest + +@testable import Agent +@testable import AgentLayout + +// MARK: - Shared Mock Server for Regenerate Tests + +@MainActor +final class RegenerateSharedMockServer { + static let shared = RegenerateSharedMockServer() + + private var app: Application? + private var isRunning = false + let controller = RegenerateMockOpenAIChatController() + private(set) var port: Int = 0 + + private init() {} + + func ensureRunning() async throws { + guard !isRunning else { return } + + // Try random ports with retry, creating fresh app each time + var lastError: Error? + for _ in 0..<10 { + // Create application with empty arguments to avoid Vapor parsing test framework args + let env = Environment(name: "testing", arguments: ["vapor"]) + let application = try await Application.make(env) + let randomPort = Int.random(in: 10000...60000) + application.http.server.configuration.port = randomPort + + controller.registerRoutes(on: application) + + do { + try await application.startup() + self.port = randomPort + self.app = application + self.isRunning = true + return + } catch { + lastError = error + // Shutdown the failed application before trying again + try? await application.asyncShutdown() + continue + } + } + + throw lastError ?? NSError( + domain: "TestError", code: 1, + userInfo: [NSLocalizedDescriptionKey: "Failed to find available port"]) + } + + func shutdown() async throws { + if let app = app { + try await app.asyncShutdown() + } + app = nil + isRunning = false + } +} + +// MARK: - Helper Functions + +@MainActor +private func createUserMessage(_ content: String, id: String? = nil) -> Message { + Message.openai(.user(.init(id: id ?? UUID().uuidString, content: content))) +} + +@MainActor +private func createAssistantMessage(_ content: String, id: String? = nil) -> Message { + Message.openai( + .assistant( + .init( + id: id ?? UUID().uuidString, + content: content, + toolCalls: nil, + audio: nil + ))) +} + +@MainActor +private func createChat(messages: [Message]) -> Chat { + Chat(id: UUID(), gameId: "test", messages: messages) +} + +@MainActor +private func waitForGenerationComplete( + provider: ChatProvider, + timeout: TimeInterval = 5.0 +) async throws { + let start = Date() + + // Wait a bit for the status to transition to loading + try await Task.sleep(nanoseconds: 50_000_000) // 50ms + + // Then wait for it to become idle again + while provider.status == .loading { + if Date().timeIntervalSince(start) > timeout { + throw NSError( + domain: "TestError", code: 1, + userInfo: [NSLocalizedDescriptionKey: "Timeout waiting for generation"]) + } + try await Task.sleep(nanoseconds: 100_000_000) // 100ms + } +} + +// MARK: - Swift Testing Suite for Regenerate Tests + +@MainActor +@Suite("ChatProvider Regenerate Tests", .serialized) +struct ChatProviderRegenerateTests { + + init() async throws { + try await RegenerateSharedMockServer.shared.ensureRunning() + } + + private func createSource() -> Source { + let port = RegenerateSharedMockServer.shared.port + return Source.openAI( + client: OpenAIClient(apiKey: "test", baseURL: URL(string: "http://localhost:\(port)")!), + models: [] + ) + } + + private func createModel() -> Model { + Model.custom(CustomModel(id: "gpt-4")) + } + + private var controller: RegenerateMockOpenAIChatController { + RegenerateSharedMockServer.shared.controller + } + + // MARK: - Test: Regenerate First Assistant Message + + @Test("Regenerating first assistant message replaces it with new response") + func testRegenerateFirstAssistantMessage() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let userMsg = createUserMessage("Hello") + let assistantMsgId = UUID().uuidString + let assistantMsg = createAssistantMessage("Hi there!", id: assistantMsgId) + let chat = createChat(messages: [userMsg, assistantMsg]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + #expect(provider.messages.count == 2) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "Hello! How can I help?", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: assistantMsgId) + try await waitForGenerationComplete(provider: provider) + + // Should have exactly 2 messages: original user + new assistant + #expect(provider.messages.count == 2) + + // User message should be preserved + #expect(provider.messages[0].id == userMsg.id) + + // New assistant message should have new content and different ID + if case .openai(let openAIMsg) = provider.messages[1], + case .assistant(let newAssistant) = openAIMsg + { + #expect(newAssistant.content == "Hello! How can I help?") + #expect(newAssistant.id != assistantMsgId) + } else { + Issue.record("Expected assistant message at index 1") + } + } + + // MARK: - Test: Regenerate Last Message + + @Test("Regenerating last message in multi-turn conversation") + func testRegenerateLastMessage() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let user1 = createUserMessage("Hello") + let assistant1 = createAssistantMessage("Hi there!") + let user2 = createUserMessage("How are you?") + let assistant2Id = UUID().uuidString + let assistant2 = createAssistantMessage("I'm doing great!", id: assistant2Id) + + let chat = createChat(messages: [user1, assistant1, user2, assistant2]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + #expect(provider.messages.count == 4) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "I'm doing fantastic!", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: assistant2Id) + try await waitForGenerationComplete(provider: provider) + + // Should have exactly 4 messages + #expect(provider.messages.count == 4) + + // First 3 messages should be preserved + #expect(provider.messages[0].id == user1.id) + #expect(provider.messages[1].id == assistant1.id) + #expect(provider.messages[2].id == user2.id) + + // Last message should have new content + if case .openai(let openAIMsg) = provider.messages[3], + case .assistant(let newAssistant) = openAIMsg + { + #expect(newAssistant.content == "I'm doing fantastic!") + } else { + Issue.record("Expected assistant message at index 3") + } + } + + // MARK: - Test: Regenerate Middle Message Deletes Subsequent + + @Test("Regenerating middle message deletes all subsequent messages") + func testRegenerateMiddleMessageDeletesSubsequent() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let user1 = createUserMessage("Hello") + let assistant1Id = UUID().uuidString + let assistant1 = createAssistantMessage("Hi there!", id: assistant1Id) + let user2 = createUserMessage("How are you?") + let assistant2 = createAssistantMessage("I'm doing great!") + + let chat = createChat(messages: [user1, assistant1, user2, assistant2]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + #expect(provider.messages.count == 4) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "New response!", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: assistant1Id) + try await waitForGenerationComplete(provider: provider) + + // Should have exactly 2 messages: user1 + new assistant + #expect(provider.messages.count == 2) + + // User1 should be preserved + #expect(provider.messages[0].id == user1.id) + + // user2 and assistant2 should be deleted + let messageIds = provider.messages.map { $0.id } + #expect(!messageIds.contains(user2.id)) + #expect(!messageIds.contains(assistant2.id)) + + // New assistant message should have new content + if case .openai(let openAIMsg) = provider.messages[1], + case .assistant(let newAssistant) = openAIMsg + { + #expect(newAssistant.content == "New response!") + } else { + Issue.record("Expected assistant message at index 1") + } + } + + // MARK: - Test: No Duplicate Messages + + @Test("Regeneration produces exactly one message, no duplicates") + func testRegenerateNoDuplicates() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let userMsg = createUserMessage("Hello") + let assistantMsgId = UUID().uuidString + let assistantMsg = createAssistantMessage("Hi there!", id: assistantMsgId) + let chat = createChat(messages: [userMsg, assistantMsg]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "New response!", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: assistantMsgId) + try await waitForGenerationComplete(provider: provider) + + // Count assistant messages + let assistantMessages = provider.messages.filter { msg in + if case .openai(let openAIMsg) = msg, + case .assistant = openAIMsg + { + return true + } + return false + } + + #expect(assistantMessages.count == 1, "Should have exactly one assistant message") + #expect(provider.messages.count == 2, "Should have exactly 2 messages total") + + // Verify no duplicate IDs + let allIds = provider.messages.map { $0.id } + let uniqueIds = Set(allIds) + #expect(allIds.count == uniqueIds.count, "Should have no duplicate message IDs") + } + + // MARK: - Test: onMessageChange Callback + + @Test("Regenerate calls onMessageChange callback") + func testRegenerateCallsOnMessageChange() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + var messageChangeCallCount = 0 + + let userMsg = createUserMessage("Hello") + let assistantMsgId = UUID().uuidString + let assistantMsg = createAssistantMessage("Hi there!", id: assistantMsgId) + let chat = createChat(messages: [userMsg, assistantMsg]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source, + onMessageChange: { _ in messageChangeCallCount += 1 } + ) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "New!", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: assistantMsgId) + try await waitForGenerationComplete(provider: provider) + + #expect(messageChangeCallCount >= 1, "onMessageChange should be called at least once") + } + + // MARK: - Guard Tests + + @Test("Regenerate does nothing when chat is nil") + func testRegenerateNoChat() async throws { + let provider = ChatProvider() + // Don't call setup - chat will be nil + + provider.regenerate(messageId: "some-id") + + #expect(provider.messages.isEmpty) + #expect(provider.status == .idle) + } + + @Test("Regenerate does nothing with invalid message ID") + func testRegenerateInvalidMessageId() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let userMsg = createUserMessage("Hello") + let assistantMsg = createAssistantMessage("Hi there!") + let chat = createChat(messages: [userMsg, assistantMsg]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + let originalMessageCount = provider.messages.count + let originalUserMsgId = provider.messages[0].id + let originalAssistantMsgId = provider.messages[1].id + + provider.regenerate(messageId: "non-existent-id") + + // Messages should be unchanged + #expect(provider.messages.count == originalMessageCount) + #expect(provider.messages[0].id == originalUserMsgId) + #expect(provider.messages[1].id == originalAssistantMsgId) + #expect(provider.status == .idle) + } + + @Test("Regenerate does nothing when no user message found before target") + func testRegenerateNoPriorUserMessage() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + // Edge case: only assistant messages (unusual but possible) + let assistant1 = createAssistantMessage("Message 1") + let assistant2Id = UUID().uuidString + let assistant2 = createAssistantMessage("Message 2", id: assistant2Id) + let chat = createChat(messages: [assistant1, assistant2]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + let originalCount = provider.messages.count + + provider.regenerate(messageId: assistant2Id) + + // Messages should be unchanged since no user message was found + #expect(provider.messages.count == originalCount) + #expect(provider.status == .idle) + } +} diff --git a/Tests/AgentLayoutTests/Mocks/RegenerateMockOpenAIChatController.swift b/Tests/AgentLayoutTests/Mocks/RegenerateMockOpenAIChatController.swift new file mode 100644 index 0000000..52de709 --- /dev/null +++ b/Tests/AgentLayoutTests/Mocks/RegenerateMockOpenAIChatController.swift @@ -0,0 +1,77 @@ +// +// RegenerateMockOpenAIChatController.swift +// AgentLayoutTests +// +// Created for testing ChatProvider.regenerate +// + +import Foundation +import Vapor + +@testable import Agent + +/// A controller that mocks OpenAI chat completion API responses for regenerate tests +@MainActor +class RegenerateMockOpenAIChatController { + private var mockResponseQueue: [[OpenAIAssistantMessage]] + + init() { + self.mockResponseQueue = [] + } + + /// Add a set of mock responses to be returned for a single request + /// - Parameter responses: List of assistant messages to be returned as chunks for one request + func mockChatResponse(_ responses: [OpenAIAssistantMessage]) { + mockResponseQueue.append(responses) + } + + /// Register routes for this controller on a Vapor router + /// - Parameter routes: The router to register routes on + func registerRoutes(on routes: RoutesBuilder) { + let chatRoutes = routes.grouped("chat") + chatRoutes.post("completions", use: handleChatCompletion) + } + + private func handleChatCompletion(request: Request) async throws -> Response { + let responses: [OpenAIAssistantMessage] + if !self.mockResponseQueue.isEmpty { + responses = self.mockResponseQueue.removeFirst() + } else { + responses = [] + } + + let body = Response.Body(stream: { writer in + Task { + let capturedResponses = responses + let id = UUID().uuidString + let created = Date().timeIntervalSince1970 + let model = "gpt-3.5-turbo" + for response in capturedResponses { + let chunk = StreamChunk( + id: id, + created: Int(created), + model: model, + choices: [ + StreamChoice( + index: 0, delta: response, finishReason: nil + ) + ] + ) + if let jsonData = try? JSONEncoder().encode(chunk), + let jsonString = String(data: jsonData, encoding: .utf8) + { + _ = writer.write(.buffer(ByteBuffer(string: "data: \(jsonString)\n\n"))) + } + } + + _ = writer.write(.end) + } + }) + + let response = Response(status: .ok, body: body) + response.headers.replaceOrAdd(name: .contentType, value: "text/event-stream") + response.headers.replaceOrAdd(name: .cacheControl, value: "no-cache") + response.headers.replaceOrAdd(name: .connection, value: "keep-alive") + return response + } +} diff --git a/Tests/AgentTests/ChatTests.swift b/Tests/AgentTests/ChatTests.swift index 897b43c..3f0d4be 100644 --- a/Tests/AgentTests/ChatTests.swift +++ b/Tests/AgentTests/ChatTests.swift @@ -15,7 +15,8 @@ struct ChatTests { let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(Message.self, from: encoded) - #expect(decoded.id == "user-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) if case .openai(.user(let decodedUser)) = decoded { #expect(decodedUser.content == "Hello") #expect(decodedUser.role == .user) @@ -32,7 +33,8 @@ struct ChatTests { let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(Message.self, from: encoded) - #expect(decoded.id == "assistant-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) if case .openai(.assistant(let decodedAssistant)) = decoded { #expect(decodedAssistant.content == "Hi there") #expect(decodedAssistant.role == .assistant) @@ -48,7 +50,8 @@ struct ChatTests { let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(Message.self, from: encoded) - #expect(decoded.id == "system-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) if case .openai(.system(let decodedSystem)) = decoded { #expect(decodedSystem.content == "Be helpful") #expect(decodedSystem.role == .system) @@ -64,7 +67,8 @@ struct ChatTests { let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(Message.self, from: encoded) - #expect(decoded.id == "tool-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) if case .openai(.tool(let decodedTool)) = decoded { #expect(decodedTool.content == "Result") #expect(decodedTool.toolCallId == "call_123") @@ -89,7 +93,8 @@ struct ChatTests { let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(Message.self, from: encoded) - #expect(decoded.id == "assistant-2") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) if case .openai(.assistant(let decodedAssistant)) = decoded { #expect(decodedAssistant.content == nil) #expect(decodedAssistant.toolCalls?.count == 1) diff --git a/Tests/AgentTests/OpenAIChatTests.swift b/Tests/AgentTests/OpenAIChatTests.swift index 54a2703..f046815 100644 --- a/Tests/AgentTests/OpenAIChatTests.swift +++ b/Tests/AgentTests/OpenAIChatTests.swift @@ -174,7 +174,8 @@ struct OpenAIChatTests { let message = OpenAIUserMessage(id: "msg-1", content: "User message") let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(OpenAIUserMessage.self, from: encoded) - #expect(decoded.id == "msg-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) #expect(decoded.content == "User message") #expect(decoded.role == .user) } @@ -224,7 +225,8 @@ struct OpenAIChatTests { id: "msg-1", content: "Response", toolCalls: nil, audio: nil) let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(OpenAIAssistantMessage.self, from: encoded) - #expect(decoded.id == "msg-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) #expect(decoded.content == "Response") #expect(decoded.role == .assistant) } @@ -258,7 +260,8 @@ struct OpenAIChatTests { let message = OpenAISystemMessage(id: "sys-1", content: "Be helpful") let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(OpenAISystemMessage.self, from: encoded) - #expect(decoded.id == "sys-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) #expect(decoded.content == "Be helpful") #expect(decoded.role == .system) } @@ -285,7 +288,8 @@ struct OpenAIChatTests { id: "tool-1", content: "{\"result\": 42}", toolCallId: "call_789") let encoded = try JSONEncoder().encode(message) let decoded = try JSONDecoder().decode(OpenAIToolMessage.self, from: encoded) - #expect(decoded.id == "tool-1") + // ID is not encoded (excluded for API compatibility), so a new one is generated on decode + #expect(!decoded.id.isEmpty) #expect(decoded.content == "{\"result\": 42}") #expect(decoded.toolCallId == "call_789") #expect(decoded.role == .tool) @@ -387,4 +391,68 @@ struct OpenAIChatTests { ) #expect(tool.strict == true) } + + // MARK: - Message Encoding Tests (ID Exclusion) + + @Test func testOpenAIUserMessageEncodingExcludesId() throws { + let message = OpenAIUserMessage(id: "test-id-123", content: "Hello") + let encoded = try JSONEncoder().encode(message) + let jsonDict = try JSONSerialization.jsonObject(with: encoded) as! [String: Any] + + // ID should NOT be included in encoded JSON (OpenAI API doesn't accept it) + #expect(jsonDict["id"] == nil, "id field should not be encoded") + #expect(jsonDict["createdAt"] == nil, "createdAt field should not be encoded") + #expect(jsonDict["role"] as? String == "user") + #expect(jsonDict["content"] as? String == "Hello") + } + + @Test func testOpenAIAssistantMessageEncodingExcludesId() throws { + let message = OpenAIAssistantMessage( + id: "test-id-456", + content: "Response", + toolCalls: nil, + audio: nil, + reasoning: "Some reasoning", + reasoningDetails: nil + ) + let encoded = try JSONEncoder().encode(message) + let jsonDict = try JSONSerialization.jsonObject(with: encoded) as! [String: Any] + + // ID and response-only fields should NOT be included + #expect(jsonDict["id"] == nil, "id field should not be encoded") + #expect(jsonDict["audio"] == nil, "audio field should not be encoded") + #expect(jsonDict["reasoning"] == nil, "reasoning field should not be encoded") + #expect(jsonDict["reasoning_details"] == nil, "reasoning_details field should not be encoded") + #expect(jsonDict["role"] as? String == "assistant") + #expect(jsonDict["content"] as? String == "Response") + } + + @Test func testOpenAISystemMessageEncodingExcludesId() throws { + let message = OpenAISystemMessage(id: "test-id-789", content: "System prompt") + let encoded = try JSONEncoder().encode(message) + let jsonDict = try JSONSerialization.jsonObject(with: encoded) as! [String: Any] + + // ID should NOT be included + #expect(jsonDict["id"] == nil, "id field should not be encoded") + #expect(jsonDict["role"] as? String == "system") + #expect(jsonDict["content"] as? String == "System prompt") + } + + @Test func testOpenAIToolMessageEncodingExcludesId() throws { + let message = OpenAIToolMessage( + id: "test-id-abc", + content: "Tool result", + toolCallId: "call_123", + name: "my_tool" + ) + let encoded = try JSONEncoder().encode(message) + let jsonDict = try JSONSerialization.jsonObject(with: encoded) as! [String: Any] + + // ID should NOT be included, but toolCallId should be + #expect(jsonDict["id"] == nil, "id field should not be encoded") + #expect(jsonDict["role"] as? String == "tool") + #expect(jsonDict["content"] as? String == "Tool result") + #expect(jsonDict["tool_call_id"] as? String == "call_123") + #expect(jsonDict["name"] as? String == "my_tool") + } } diff --git a/Tests/AgentTests/e2e/openrouter/OpenAIOpenRouterTests.swift b/Tests/AgentTests/e2e/openrouter/OpenAIOpenRouterTests.swift index 709826a..782bed8 100644 --- a/Tests/AgentTests/e2e/openrouter/OpenAIOpenRouterTests.swift +++ b/Tests/AgentTests/e2e/openrouter/OpenAIOpenRouterTests.swift @@ -83,12 +83,16 @@ struct OpenAIOpenRouterTests { // Check we have assistant messages with tool calls let assistantMessagesWithToolCalls = generatedMessages.filter { msg in - if case .assistant(let assistantMsg) = msg, let toolCalls = assistantMsg.toolCalls, !toolCalls.isEmpty { + if case .assistant(let assistantMsg) = msg, let toolCalls = assistantMsg.toolCalls, + !toolCalls.isEmpty + { return true } return false } - #expect(assistantMessagesWithToolCalls.count >= 2, "Should have at least 2 assistant messages with tool calls") + #expect( + assistantMessagesWithToolCalls.count >= 2, + "Should have at least 2 assistant messages with tool calls") // Check we have tool result messages let toolMessages: [OpenAIToolMessage] = generatedMessages.filter { $0.role == .tool } @@ -100,4 +104,91 @@ struct OpenAIOpenRouterTests { // make sure the last message is an assistant message #expect(lastMessage?.role == .assistant, "Last message should be an assistant message") } + + @Test + /** + Test that multi-turn conversation with tool calls works correctly. + This test reproduces the scenario where: + 1. User sends message + 2. Assistant responds with tool call + 3. Tool result is sent back + 4. Assistant responds + 5. User sends another message - this previously failed with "Expected an ID that begins with 'msg'" error + + The fix: Message IDs are no longer included when encoding messages for the OpenAI API. + */ + func testMultiTurnConversationAfterToolCallDoesNotFailWithInvalidId() async throws { + struct GreetInput: Decodable { + let name: String + } + + let (client, source, _) = try await setUpTests() + let model = Model.custom(CustomModel(id: "openai/gpt-4.1-mini")) + + let greetTool = AgentTool( + name: "greet", + description: "Greet someone by name", + parameters: .object(properties: ["name": .string()], required: ["name"]) + ) { (args: GreetInput) async in + return "Hello, \(args.name)!" + } + + // First turn: user message triggers tool call + let messages1: [Message] = [ + .openai(.system(.init(content: "You are a greeting assistant. Use the greet tool when asked to greet someone."))), + .openai(.user(.init(content: "Please greet Alice"))), + ] + + let stream1 = await client.process( + messages: messages1, + model: model, + source: source, + tools: [greetTool] + ) + + var allMessages: [Message] = messages1 + for try await part in stream1 { + if case .message(let msg) = part { + allMessages.append(msg) + } + } + + // Verify we got an assistant response with content + let lastAssistantMessage = allMessages.last { msg in + if case .openai(let openAIMsg) = msg, case .assistant = openAIMsg { + return true + } + return false + } + #expect(lastAssistantMessage != nil, "Should have an assistant response") + + // Second turn: send another user message using the same conversation history + // This is where the bug occurred - OpenAI rejected the message ID format + let secondUserMessage = Message.openai(.user(.init(content: "Now greet Bob"))) + allMessages.append(secondUserMessage) + + let stream2 = await client.process( + messages: allMessages, + model: model, + source: source, + tools: [greetTool] + ) + + var secondTurnMessages: [Message] = [] + for try await part in stream2 { + if case .message(let msg) = part { + secondTurnMessages.append(msg) + } + } + + // Verify the second turn completed successfully + #expect(!secondTurnMessages.isEmpty, "Second turn should produce messages") + + // The last message should be an assistant message + if let lastMsg = secondTurnMessages.last, case .openai(let openAIMsg) = lastMsg { + #expect(openAIMsg.role == .assistant, "Last message should be an assistant message") + } else { + Issue.record("Expected assistant message in second turn") + } + } }