diff --git a/Sources/AgentLayout/AgentLayout.swift b/Sources/AgentLayout/AgentLayout.swift index 044bcbe..19ddf0e 100644 --- a/Sources/AgentLayout/AgentLayout.swift +++ b/Sources/AgentLayout/AgentLayout.swift @@ -259,6 +259,10 @@ public struct AgentLayout: View { proxy.scrollTo(lastMessage.id, anchor: .top) } } + chatProvider.onError = { err in + error = err + showAlert = true + } // Scroll to bottom when view first appears DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { scrollToBottom() diff --git a/Sources/AgentLayout/ChatProvider.swift b/Sources/AgentLayout/ChatProvider.swift index 1ee3626..f86d4fa 100644 --- a/Sources/AgentLayout/ChatProvider.swift +++ b/Sources/AgentLayout/ChatProvider.swift @@ -49,6 +49,7 @@ public class ChatProvider: ChatProviderProtocol { public var onDelete: ((Int) -> Void)? public var onEdit: ((Int, Message) -> Void)? public var onMessageChange: (([Message]) -> Void)? + public var onError: ((Error) -> Void)? // MARK: - Internal State (not observed) @ObservationIgnored private var agentClient = AgentClient() @@ -360,6 +361,7 @@ public class ChatProvider: ChatProviderProtocol { self.currentStreamingMessageId = nil } catch { print("Error continuing conversation: \(error)") + self.onError?(error) if let msgId = self.currentStreamingMessageId { self.chat?.messages.removeAll { $0.id == msgId } self.notifyMessageChange() @@ -540,6 +542,7 @@ public class ChatProvider: ChatProviderProtocol { self.currentStreamingMessageId = nil } catch { print("Error sending message: \(error)") + self.onError?(error) if let msgId = self.currentStreamingMessageId { self.chat?.messages.removeAll { $0.id == msgId } self.notifyMessageChange() @@ -569,7 +572,22 @@ public class ChatProvider: ChatProviderProtocol { guard let index = chat.messages.firstIndex(where: { $0.id == messageId }) else { return } guard let currentSource = currentSource, let currentModel = currentModel else { return } - // Find the user message content before the target message + let targetMessage = chat.messages[index] + + // Check if target is a user message + if case .openai(let openAIMsg) = targetMessage, + case .user = openAIMsg + { + // User message: remove everything after it (keep the user message) + if index + 1 < chat.messages.count { + self.chat?.messages.removeSubrange((index + 1)...) + } + notifyMessageChange() + startGeneration(source: currentSource, model: currentModel) + return + } + + // Assistant/other message: find preceding user message and remove from target onwards var userMessageContent: String? = nil for i in stride(from: index - 1, through: 0, by: -1) { if case .openai(let openAIMsg) = chat.messages[i], @@ -582,7 +600,7 @@ public class ChatProvider: ChatProviderProtocol { guard userMessageContent != nil else { return } - // Remove the target message and all subsequent messages + // Remove target message and all subsequent messages self.chat?.messages.removeSubrange(index...) notifyMessageChange() diff --git a/Sources/AgentLayout/Message/ThinkingContentView.swift b/Sources/AgentLayout/Message/ThinkingContentView.swift index e8a4766..25b2346 100644 --- a/Sources/AgentLayout/Message/ThinkingContentView.swift +++ b/Sources/AgentLayout/Message/ThinkingContentView.swift @@ -22,10 +22,10 @@ struct ThinkingContentView: View { @State private var isExpanded = false - /// Title text to display (summary or fallback) + /// Title text to display (summary with markdown stripped, or fallback) private var titleText: String { if let summary = summary, !summary.isEmpty { - return summary + return MarkdownStripper.stripMarkdown(summary) } return "Thinking..." } @@ -61,8 +61,7 @@ struct ThinkingContentView: View { .foregroundColor(.orange.mix(with: .mint, by: 0.9)) .shimmering() } else { - Markdown(titleText) - .markdownTheme(.chatTheme) + Text(titleText) .lineLimit(1) .foregroundColor(.orange.mix(with: .mint, by: 0.9)) } diff --git a/Sources/AgentLayout/ModelPicker.swift b/Sources/AgentLayout/ModelPicker.swift index ad7b1b2..11c4ecb 100644 --- a/Sources/AgentLayout/ModelPicker.swift +++ b/Sources/AgentLayout/ModelPicker.swift @@ -8,6 +8,7 @@ struct ModelPicker: View { let onClose: () -> Void @State private var hoveredModel: Model? + @State private var searchText: String = "" public init( currentModel: Binding, @@ -22,54 +23,94 @@ struct ModelPicker: View { } var body: some View { - ScrollView { - VStack(alignment: .leading, spacing: 5) { - ForEach(sources) { source in - Text(source.displayName) - .foregroundColor(Color.gray) + ScrollViewReader { proxy in + ScrollView { + VStack(alignment: .leading, spacing: 5) { + ForEach(sources) { source in + Text(source.displayName) + .foregroundColor(Color.gray) - ForEach(source.models) { model in - HStack { - Text(model.displayName) - .padding(.vertical, 12) - .padding(.horizontal, 12) - .frame(maxWidth: .infinity, alignment: .leading) - // if model is custom model, show custom icon - if case .custom = model { - Image(systemName: "gear") - .padding() + ForEach(source.models) { model in + HStack { + Text(model.displayName) + .padding(.vertical, 12) + .padding(.horizontal, 12) + .frame(maxWidth: .infinity, alignment: .leading) + // if model is custom model, show custom icon + if case .custom = model { + Image(systemName: "gear") + .padding() + } + if model == currentModel { + Spacer() + Image(systemName: "checkmark") + .padding(.trailing, 12) + } } - if model == currentModel { - Spacer() - Image(systemName: "checkmark") - .padding(.trailing, 12) + .id(model) + .onHover { hovering in + if hovering { + hoveredModel = model + } else { + hoveredModel = nil + } } - } - .onHover { hovering in - if hovering { - hoveredModel = model - } else { - hoveredModel = nil - } - } - .background( - hoveredModel == model ? Color.gray.opacity(0.12) : Color.clear - ) - .cornerRadius(10) - .frame(width: 220) - .clipShape(RoundedRectangle(cornerRadius: 10)) - .onTapGesture { - withAnimation { - currentModel = model + .background( + hoveredModel == model ? Color.gray.opacity(0.12) : Color.clear + ) + .cornerRadius(10) + .frame(width: 220) + .clipShape(RoundedRectangle(cornerRadius: 10)) + .onTapGesture { + withAnimation { + currentModel = model + } + onClose() } - onClose() } } } + .padding() + } + .frame(maxHeight: 400) + .onAppear { + DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { + withAnimation { + proxy.scrollTo(currentModel, anchor: .center) + } + } + } + .onKeyPress { press in + if press.key == .escape { + searchText = "" + return .handled + } + if press.key == .delete { + if !searchText.isEmpty { + searchText.removeLast() + } + scrollToFirstMatch(proxy: proxy) + return .handled + } + if let char = press.characters.first, char.isLetter || char.isNumber { + searchText.append(char) + scrollToFirstMatch(proxy: proxy) + return .handled + } + return .ignored + } + } + } + + private func scrollToFirstMatch(proxy: ScrollViewProxy) { + let allModels = sources.flatMap { $0.models } + if let match = allModels.first(where: { + $0.displayName.localizedCaseInsensitiveContains(searchText) + }) { + withAnimation { + proxy.scrollTo(match, anchor: .center) } - .padding() } - .frame(maxHeight: 400) } } diff --git a/Sources/AgentLayout/Utils/MarkdownStripper.swift b/Sources/AgentLayout/Utils/MarkdownStripper.swift new file mode 100644 index 0000000..877d85e --- /dev/null +++ b/Sources/AgentLayout/Utils/MarkdownStripper.swift @@ -0,0 +1,92 @@ +// +// MarkdownStripper.swift +// AgentLayout +// +// Created by Claude on 12/2/25. +// + +import Foundation + +/// Utility for stripping markdown syntax from text +enum MarkdownStripper { + /// Strips common markdown syntax from text and returns plain text + /// - Parameter text: The markdown-formatted text + /// - Returns: Plain text with markdown syntax removed + static func stripMarkdown(_ text: String) -> String { + var result = text + + // Remove bold: **text** or __text__ + result = result.replacingOccurrences( + of: #"\*\*(.+?)\*\*"#, + with: "$1", + options: .regularExpression + ) + result = result.replacingOccurrences( + of: #"__(.+?)__"#, + with: "$1", + options: .regularExpression + ) + + // Remove italic: *text* or _text_ + result = result.replacingOccurrences( + of: #"\*(.+?)\*"#, + with: "$1", + options: .regularExpression + ) + result = result.replacingOccurrences( + of: #"(? alt (before links to avoid conflict) + result = result.replacingOccurrences( + of: #"!\[([^\]]*)\]\([^)]+\)"#, + with: "$1", + options: .regularExpression + ) + + // Remove links: [text](url) -> text + result = result.replacingOccurrences( + of: #"\[([^\]]+)\]\([^)]+\)"#, + with: "$1", + options: .regularExpression + ) + + // Remove any remaining unmatched markdown characters (e.g., incomplete **) + // This handles cases like "**incomplete" where there's no closing ** + result = result.replacingOccurrences( + of: #"^\*\*"#, + with: "", + options: .regularExpression + ) + result = result.replacingOccurrences( + of: #"\*\*$"#, + with: "", + options: .regularExpression + ) + + return result + } +} diff --git a/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift b/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift index 08a2b01..85d7357 100644 --- a/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift +++ b/Tests/AgentLayoutTests/ChatProviderRegenerateTests.swift @@ -140,6 +140,148 @@ struct ChatProviderRegenerateTests { RegenerateSharedMockServer.shared.controller } + // MARK: - Test: Regenerate User Message + + @Test("Regenerating user message keeps it and replaces subsequent messages") + func testRegenerateUserMessage() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let userMsgId = UUID().uuidString + let userMsg = createUserMessage("Hello", id: userMsgId) + let assistantMsg = createAssistantMessage("Hi there!") + let chat = createChat(messages: [userMsg, assistantMsg]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + #expect(provider.messages.count == 2) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "Hello! How can I help?", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: userMsgId) + try await waitForGenerationComplete(provider: provider) + + // Should have exactly 2 messages: original user + new assistant + #expect(provider.messages.count == 2) + + // User message should be preserved with same ID + #expect(provider.messages[0].id == userMsgId) + + // New assistant message should have new content + if case .openai(let openAIMsg) = provider.messages[1], + case .assistant(let newAssistant) = openAIMsg + { + #expect(newAssistant.content == "Hello! How can I help?") + } else { + Issue.record("Expected assistant message at index 1") + } + } + + @Test("Regenerating user message in middle of conversation keeps prior messages") + func testRegenerateUserMessageInMiddle() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let user1 = createUserMessage("Hello") + let assistant1 = createAssistantMessage("Hi there!") + let user2Id = UUID().uuidString + let user2 = createUserMessage("How are you?", id: user2Id) + let assistant2 = createAssistantMessage("I'm doing great!") + + let chat = createChat(messages: [user1, assistant1, user2, assistant2]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + #expect(provider.messages.count == 4) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "I'm doing fantastic!", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: user2Id) + try await waitForGenerationComplete(provider: provider) + + // Should have exactly 4 messages: user1, assistant1, user2 (preserved), new assistant + #expect(provider.messages.count == 4) + + // First 3 messages should be preserved + #expect(provider.messages[0].id == user1.id) + #expect(provider.messages[1].id == assistant1.id) + #expect(provider.messages[2].id == user2Id) + + // Last message should be new assistant with new content + if case .openai(let openAIMsg) = provider.messages[3], + case .assistant(let newAssistant) = openAIMsg + { + #expect(newAssistant.content == "I'm doing fantastic!") + } else { + Issue.record("Expected assistant message at index 3") + } + } + + @Test("Regenerating first user message removes all subsequent messages") + func testRegenerateFirstUserMessage() async throws { + let provider = ChatProvider() + let source = createSource() + let model = createModel() + + let user1Id = UUID().uuidString + let user1 = createUserMessage("Hello", id: user1Id) + let assistant1 = createAssistantMessage("Hi there!") + let user2 = createUserMessage("How are you?") + let assistant2 = createAssistantMessage("I'm doing great!") + + let chat = createChat(messages: [user1, assistant1, user2, assistant2]) + + provider.setup( + chat: chat, + currentModel: model, + currentSource: source + ) + + #expect(provider.messages.count == 4) + + controller.mockChatResponse([ + OpenAIAssistantMessage(content: "New response!", toolCalls: nil, audio: nil) + ]) + + provider.regenerate(messageId: user1Id) + try await waitForGenerationComplete(provider: provider) + + // Should have exactly 2 messages: user1 (preserved) + new assistant + #expect(provider.messages.count == 2) + + // User1 should be preserved + #expect(provider.messages[0].id == user1Id) + + // assistant1, user2, assistant2 should be deleted + let messageIds = provider.messages.map { $0.id } + #expect(!messageIds.contains(assistant1.id)) + #expect(!messageIds.contains(user2.id)) + #expect(!messageIds.contains(assistant2.id)) + + // New assistant message should have new content + if case .openai(let openAIMsg) = provider.messages[1], + case .assistant(let newAssistant) = openAIMsg + { + #expect(newAssistant.content == "New response!") + } else { + Issue.record("Expected assistant message at index 1") + } + } + // MARK: - Test: Regenerate First Assistant Message @Test("Regenerating first assistant message replaces it with new response") diff --git a/Tests/AgentLayoutTests/MarkdownStripperTests.swift b/Tests/AgentLayoutTests/MarkdownStripperTests.swift new file mode 100644 index 0000000..6b59dcf --- /dev/null +++ b/Tests/AgentLayoutTests/MarkdownStripperTests.swift @@ -0,0 +1,109 @@ +// +// MarkdownStripperTests.swift +// AgentLayout +// +// Created by Claude on 12/2/25. +// + +import Testing + +@testable import AgentLayout + +struct MarkdownStripperTests { + + @Test func testStripBoldDoubleAsterisk() { + #expect(MarkdownStripper.stripMarkdown("**bold**") == "bold") + } + + @Test func testStripBoldDoubleUnderscore() { + #expect(MarkdownStripper.stripMarkdown("__bold__") == "bold") + } + + @Test func testStripItalicAsterisk() { + #expect(MarkdownStripper.stripMarkdown("*italic*") == "italic") + } + + @Test func testStripItalicUnderscore() { + #expect(MarkdownStripper.stripMarkdown("_italic_") == "italic") + } + + @Test func testStripCombinedBoldAndItalic() { + #expect(MarkdownStripper.stripMarkdown("**bold** and *italic*") == "bold and italic") + } + + @Test func testStripInlineCode() { + #expect(MarkdownStripper.stripMarkdown("`code`") == "code") + } + + @Test func testStripLinks() { + #expect(MarkdownStripper.stripMarkdown("[text](https://example.com)") == "text") + } + + @Test func testStripImages() { + #expect(MarkdownStripper.stripMarkdown("![alt text](https://example.com/image.png)") == "alt text") + } + + @Test func testStripH1Header() { + #expect(MarkdownStripper.stripMarkdown("# Header") == "Header") + } + + @Test func testStripH2Header() { + #expect(MarkdownStripper.stripMarkdown("## Header") == "Header") + } + + @Test func testStripH3Header() { + #expect(MarkdownStripper.stripMarkdown("### Header") == "Header") + } + + @Test func testPlainTextUnchanged() { + #expect(MarkdownStripper.stripMarkdown("plain text") == "plain text") + } + + @Test func testEmptyString() { + #expect(MarkdownStripper.stripMarkdown("") == "") + } + + @Test func testIncompleteMarkdownBoldStart() { + // Should handle incomplete markdown gracefully - removes leading ** + let result = MarkdownStripper.stripMarkdown("**incomplete") + #expect(result == "incomplete") + } + + @Test func testIncompleteMarkdownBoldEnd() { + // Should handle incomplete markdown gracefully - removes trailing ** + let result = MarkdownStripper.stripMarkdown("incomplete**") + #expect(result == "incomplete") + } + + @Test func testStrikethrough() { + #expect(MarkdownStripper.stripMarkdown("~~strikethrough~~") == "strikethrough") + } + + @Test func testComplexMarkdown() { + let input = "**Bold** text with *italic* and `code` plus [link](url)" + let expected = "Bold text with italic and code plus link" + #expect(MarkdownStripper.stripMarkdown(input) == expected) + } + + @Test func testPreparingExample() { + // This is the actual use case from the bug report + #expect(MarkdownStripper.stripMarkdown("**Preparing**") == "Preparing") + } + + @Test func testNestedFormatting() { + // Test bold containing italic-like content + #expect(MarkdownStripper.stripMarkdown("**bold text**") == "bold text") + } + + @Test func testMultipleLinks() { + let input = "[link1](url1) and [link2](url2)" + let expected = "link1 and link2" + #expect(MarkdownStripper.stripMarkdown(input) == expected) + } + + @Test func testUnderscoreInWord() { + // Underscores within words should be preserved (like snake_case) + let result = MarkdownStripper.stripMarkdown("snake_case_variable") + #expect(result == "snake_case_variable") + } +}