Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions Sources/Agent/chat/openaiChat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ public struct OpenAIUserMessage: Hashable, Codable, Sendable {
self.content = try container.decode(String.self, forKey: .content)
self.createdAt = try container.decodeIfPresent(Date.self, forKey: .createdAt) ?? Date()
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(role, forKey: .role)
try container.encode(content, forKey: .content)
// Exclude id and createdAt - not part of OpenAI API spec
}
}

public struct OpenAIAssistantMessage: Hashable, Codable, Sendable {
Expand Down Expand Up @@ -295,6 +302,14 @@ public struct OpenAIAssistantMessage: Hashable, Codable, Sendable {
reasoningDetails: reasoningDetails
)
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(role, forKey: .role)
try container.encodeIfPresent(content, forKey: .content)
try container.encodeIfPresent(toolCalls, forKey: .toolCalls)
// Exclude id, audio, reasoning, reasoningDetails - not part of request spec
}
}

public struct OpenAISystemMessage: Hashable, Codable, Sendable {
Expand All @@ -319,6 +334,13 @@ public struct OpenAISystemMessage: Hashable, Codable, Sendable {
self.role = try container.decodeIfPresent(OpenAIRole.self, forKey: .role) ?? .system
self.content = try container.decode(String.self, forKey: .content)
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(role, forKey: .role)
try container.encode(content, forKey: .content)
// Exclude id - not part of OpenAI API spec
}
}

public struct OpenAIToolMessage: Hashable, Codable, Sendable {
Expand Down Expand Up @@ -351,6 +373,15 @@ public struct OpenAIToolMessage: Hashable, Codable, Sendable {
self.toolCallId = try container.decode(String.self, forKey: .toolCallId)
self.name = try container.decodeIfPresent(String.self, forKey: .name)
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(role, forKey: .role)
try container.encode(content, forKey: .content)
try container.encode(toolCallId, forKey: .toolCallId)
try container.encodeIfPresent(name, forKey: .name)
// Exclude id - not part of OpenAI API spec
}
}

public enum OpenAIMessage: Hashable, Codable, Sendable {
Expand Down
38 changes: 29 additions & 9 deletions Sources/AgentLayout/ChatProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public class ChatProvider: ChatProviderProtocol {
encoder.outputFormatting = .prettyPrinted
let resultString: String
if let data = try? encoder.encode(AnyEncodable(result)),
let jsonString = String(data: data, encoding: .utf8)
let jsonString = String(data: data, encoding: .utf8)
{
resultString = jsonString
} else {
Expand All @@ -171,8 +171,8 @@ public class ChatProvider: ChatProviderProtocol {
if let chat = self.chat {
for message in chat.messages.reversed() {
if case .openai(let openAIMsg) = message,
case .assistant(let assistant) = openAIMsg,
let toolCalls = assistant.toolCalls
case .assistant(let assistant) = openAIMsg,
let toolCalls = assistant.toolCalls
{
if let toolCall = toolCalls.first(where: { $0.id == id }) {
toolName = toolCall.function?.name
Expand Down Expand Up @@ -341,7 +341,8 @@ public class ChatProvider: ChatProviderProtocol {

public func send(_ message: String) {
guard generationTask == nil else { return }
guard var chat = chat, let currentSource = currentSource, let currentModel = currentModel else { return }
guard var chat = chat, let currentSource = currentSource, let currentModel = currentModel
else { return }

let userMsg = Message.openai(.user(.init(content: message)))
chat.messages.append(userMsg)
Expand All @@ -353,14 +354,19 @@ public class ChatProvider: ChatProviderProtocol {
self?.scrollToBottom?()
}

let source = currentSource
let model = currentModel
startGeneration(source: currentSource, model: currentModel, userMessage: message)
}

/// Internal method to start generation without adding a user message.
/// Used by both `send` and `regenerate`.
private func startGeneration(source: Source, model: Model, userMessage: String? = nil) {
generationTask = Task { [weak self] in
guard let self = self else { return }
self.status = .loading

try? await self.sendMessage(message: message)
if let message = userMessage {
try? await self.sendMessage(message: message)
}

do {
var messagesToSend = self.chat?.messages ?? []
Expand Down Expand Up @@ -493,7 +499,9 @@ public class ChatProvider: ChatProviderProtocol {
guard generationTask == nil else { return }
guard let chat = chat else { return }
guard let index = chat.messages.firstIndex(where: { $0.id == messageId }) else { return }
guard let currentSource = currentSource, let currentModel = currentModel else { return }

// Find the user message content before the target message
var userMessageContent: String? = nil
for i in stride(from: index - 1, through: 0, by: -1) {
if case .openai(let openAIMsg) = chat.messages[i],
Expand All @@ -504,11 +512,14 @@ public class ChatProvider: ChatProviderProtocol {
}
}

guard let content = userMessageContent else { return }
guard userMessageContent != nil else { return }

// Remove the target message and all subsequent messages
self.chat?.messages.removeSubrange(index...)
notifyMessageChange()
send(content)

// Start generation without adding a new user message
startGeneration(source: currentSource, model: currentModel)
}

public func cancel() {
Expand Down Expand Up @@ -624,4 +635,13 @@ public class ChatProvider: ChatProviderProtocol {
public func updateSystemPrompt(_ newSystemPrompt: String?) {
self.systemPrompt = newSystemPrompt
}

/**
Regenerates the conversation starting at the given message.
This will remove all messages after the given message and regenerate the conversation from the given message.
- Parameter message: The message to start regenerating from.
*/
public func regenerate(startsAt message: Message) async throws {

}
}
Loading