diff --git a/Directory.Packages.props b/Directory.Packages.props
index cec2ac886..38e68d3ad 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -26,6 +26,7 @@
+
diff --git a/src/Infrastructure/BotSharp.Abstraction/BotSharp.Abstraction.csproj b/src/Infrastructure/BotSharp.Abstraction/BotSharp.Abstraction.csproj
index 97e645d0a..2008c6a2e 100644
--- a/src/Infrastructure/BotSharp.Abstraction/BotSharp.Abstraction.csproj
+++ b/src/Infrastructure/BotSharp.Abstraction/BotSharp.Abstraction.csproj
@@ -36,6 +36,7 @@
+
diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/Dtos/ChatResponseDto.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/Dtos/ChatResponseDto.cs
index b391bc7b1..19b5e8b28 100644
--- a/src/Infrastructure/BotSharp.Abstraction/Conversations/Dtos/ChatResponseDto.cs
+++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/Dtos/ChatResponseDto.cs
@@ -35,6 +35,9 @@ public class ChatResponseDto : InstructResult
     [JsonPropertyName("has_message_files")]
     public bool HasMessageFiles { get; set; }
 
+    [JsonPropertyName("is_streaming")]
+    public bool IsStreaming { get; set; }
+
     [JsonPropertyName("created_at")]
     public DateTime CreatedAt { get; set; } = DateTime.UtcNow;
 }
diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationService.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationService.cs
index 656b397d8..322d86b6f 100644
--- a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationService.cs
+++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationService.cs
@@ -36,7 +36,7 @@ public interface IConversationService
     /// Received the response from AI Agent
     /// </param>
    Task<bool> SendMessage(string agentId,
-        RoleDialogModel lastDialog,
+        RoleDialogModel message,
        PostbackMessageModel? replyMessage,
        Func<RoleDialogModel, Task> onResponseReceived);
diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs
index de2697595..6a8af2d62 100644
--- a/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs
+++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs
@@ -117,6 +117,9 @@ public class RoleDialogModel : ITrackableMessage
     [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
     public string RenderedInstruction { get; set; } = string.Empty;
 
+    [JsonIgnore(Condition = JsonIgnoreCondition.Always)]
+    public bool IsStreaming { get; set; }
+
     private RoleDialogModel()
     {
     }
@@ -159,7 +162,8 @@ public static RoleDialogModel From(RoleDialogModel source,
             Payload = source.Payload,
             StopCompletion = source.StopCompletion,
             Instruction = source.Instruction,
-            Data = source.Data
+            Data = source.Data,
+            IsStreaming = source.IsStreaming
         };
     }
 }
diff --git a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IChatCompletion.cs b/src/Infrastructure/BotSharp.Abstraction/MLTasks/IChatCompletion.cs
index 7cf52fe47..51f89375e 100644
--- a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IChatCompletion.cs
+++ b/src/Infrastructure/BotSharp.Abstraction/MLTasks/IChatCompletion.cs
@@ -23,7 +23,6 @@ Task<bool> GetChatCompletionsAsync(Agent agent,
         Func<RoleDialogModel, Task> onMessageReceived,
         Func<RoleDialogModel, Task> onFunctionExecuting);
 
-    Task<bool> GetChatCompletionsStreamingAsync(Agent agent,
-        List<RoleDialogModel> conversations,
-        Func<RoleDialogModel, Task> onMessageReceived);
+    Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent,
+        List<RoleDialogModel> conversations);
 }
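For orientation (this sketch is not part of the patch): under the reworked contract, the streaming overload returns the completed `RoleDialogModel` instead of reporting deltas through an `onMessageReceived` callback; the deltas instead flow through the `MessageHub<HubObserveData>` queue introduced below. A minimal, hypothetical call-site sketch — `chatCompletion` and `agent` are assumed to be in scope:

```csharp
// Hypothetical call site, for illustration only: the streaming call now
// returns the final message; incremental tokens are observed via MessageHub.
var dialogs = new List<RoleDialogModel>
{
    new RoleDialogModel(AgentRole.User, "hello")
};
RoleDialogModel reply = await chatCompletion.GetChatCompletionsStreamingAsync(agent, dialogs);
Console.WriteLine($"{reply.Role}: {reply.Content} (IsStreaming: {reply.IsStreaming})");
```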
diff --git a/src/Infrastructure/BotSharp.Abstraction/Observables/Models/HubObserveData.cs b/src/Infrastructure/BotSharp.Abstraction/Observables/Models/HubObserveData.cs
new file mode 100644
index 000000000..bf771cb4d
--- /dev/null
+++ b/src/Infrastructure/BotSharp.Abstraction/Observables/Models/HubObserveData.cs
@@ -0,0 +1,7 @@
+namespace BotSharp.Abstraction.Observables.Models;
+
+public class HubObserveData : ObserveDataBase
+{
+    public string EventName { get; set; } = null!;
+    public RoleDialogModel Data { get; set; } = null!;
+}
diff --git a/src/Infrastructure/BotSharp.Abstraction/Observables/Models/ObserveDataBase.cs b/src/Infrastructure/BotSharp.Abstraction/Observables/Models/ObserveDataBase.cs
new file mode 100644
index 000000000..177732726
--- /dev/null
+++ b/src/Infrastructure/BotSharp.Abstraction/Observables/Models/ObserveDataBase.cs
@@ -0,0 +1,6 @@
+namespace BotSharp.Abstraction.Observables.Models;
+
+public abstract class ObserveDataBase
+{
+    public IServiceProvider ServiceProvider { get; set; } = null!;
+}
diff --git a/src/Infrastructure/BotSharp.Abstraction/Routing/IRoutingService.cs b/src/Infrastructure/BotSharp.Abstraction/Routing/IRoutingService.cs
index 97be88741..33803a98e 100644
--- a/src/Infrastructure/BotSharp.Abstraction/Routing/IRoutingService.cs
+++ b/src/Infrastructure/BotSharp.Abstraction/Routing/IRoutingService.cs
@@ -26,11 +26,7 @@ public interface IRoutingService
     /// </summary>
     RoutingRule[] GetRulesByAgentId(string id);
 
-    //void ResetRecursiveCounter();
-    //int GetRecursiveCounter();
-    //void SetRecursiveCounter(int counter);
-
-    Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs);
+    Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs, bool useStream = false);
 
     Task<bool> InvokeFunction(string name, RoleDialogModel messages);
     Task<bool> InstructLoop(Agent agent, RoleDialogModel message, List<RoleDialogModel> dialogs);
diff --git a/src/Infrastructure/BotSharp.Core/Conversations/ConversationPlugin.cs b/src/Infrastructure/BotSharp.Core/Conversations/ConversationPlugin.cs
index 13ee1de62..a559877f8 100644
--- a/src/Infrastructure/BotSharp.Core/Conversations/ConversationPlugin.cs
+++ b/src/Infrastructure/BotSharp.Core/Conversations/ConversationPlugin.cs
@@ -10,7 +10,9 @@
 using BotSharp.Core.Routing.Reasoning;
 using BotSharp.Core.Templating;
 using BotSharp.Core.Translation;
+using BotSharp.Core.Observables.Queues;
 using Microsoft.Extensions.Configuration;
+using BotSharp.Abstraction.Observables.Models;
 
 namespace BotSharp.Core.Conversations;
 
@@ -41,6 +43,8 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
             return settingService.Bind("GoogleApi");
         });
 
+        services.AddSingleton<MessageHub<HubObserveData>>();
+
         services.AddScoped();
         services.AddScoped();
         services.AddScoped();
diff --git a/src/Infrastructure/BotSharp.Core/Demo/Functions/GetWeatherFn.cs b/src/Infrastructure/BotSharp.Core/Demo/Functions/GetWeatherFn.cs
index a78ad4ec7..cb6bcd886 100644
--- a/src/Infrastructure/BotSharp.Core/Demo/Functions/GetWeatherFn.cs
+++ b/src/Infrastructure/BotSharp.Core/Demo/Functions/GetWeatherFn.cs
@@ -17,7 +17,7 @@ public GetWeatherFn(IServiceProvider services)
     public async Task<bool> Execute(RoleDialogModel message)
     {
         message.Content = $"It is a sunny day!";
-        //message.StopCompletion = true;
+        message.StopCompletion = false;
         return true;
     }
 }
\ No newline at end of file
diff --git a/src/Plugins/BotSharp.Plugin.GoogleAI/Models/Realtime/RealtimeTranscriptionResponse.cs b/src/Infrastructure/BotSharp.Core/Infrastructures/Streams/RealtimeTextStream.cs
similarity index 81%
rename from src/Plugins/BotSharp.Plugin.GoogleAI/Models/Realtime/RealtimeTranscriptionResponse.cs
rename to src/Infrastructure/BotSharp.Core/Infrastructures/Streams/RealtimeTextStream.cs
index 0a383c80a..e041a53b5 100644
--- a/src/Plugins/BotSharp.Plugin.GoogleAI/Models/Realtime/RealtimeTranscriptionResponse.cs
+++ b/src/Infrastructure/BotSharp.Core/Infrastructures/Streams/RealtimeTextStream.cs
@@ -1,12 +1,12 @@
 using System.IO;
 
-namespace BotSharp.Plugin.GoogleAI.Models.Realtime;
+namespace BotSharp.Core.Infrastructures.Streams;
 
-internal class RealtimeTranscriptionResponse : IDisposable
+public class RealtimeTextStream : IDisposable
 {
-    public RealtimeTranscriptionResponse()
+    public RealtimeTextStream()
     {
-        
+
     }
 
     private bool _disposed = false;
@@ -20,6 +20,13 @@ public Stream? ContentStream
         }
     }
 
+    public long Length => _contentStream.Length;
+
+    public bool IsNullOrEmpty()
+    {
+        return _contentStream == null || Length == 0;
+    }
+
     public void Collect(string text)
     {
         if (_disposed) return;
diff --git a/src/Infrastructure/BotSharp.Core/Observables/Queues/MessageHub.cs b/src/Infrastructure/BotSharp.Core/Observables/Queues/MessageHub.cs
new file mode 100644
index 000000000..affb142da
--- /dev/null
+++ b/src/Infrastructure/BotSharp.Core/Observables/Queues/MessageHub.cs
@@ -0,0 +1,43 @@
+using System.Reactive.Subjects;
+
+namespace BotSharp.Core.Observables.Queues;
+
+public class MessageHub<T> where T : class
+{
+    private readonly ILogger<MessageHub<T>> _logger;
+    private readonly ISubject<T> _observable = new Subject<T>();
+    public IObservable<T> Events => _observable;
+
+    public MessageHub(ILogger<MessageHub<T>> logger)
+    {
+        _logger = logger;
+    }
+
+    /// <summary>
+    /// Push an item to the observers.
+    /// </summary>
+    /// <param name="item"></param>
+    public void Push(T item)
+    {
+        _observable.OnNext(item);
+    }
+
+    /// <summary>
+    /// Send a complete notification to the observers.
+    /// This will stop the observers from receiving data.
+    /// </summary>
+    public void Complete()
+    {
+        _observable.OnCompleted();
+    }
+
+    /// <summary>
+    /// Send an error notification to the observers.
+    /// This will stop the observers from receiving data.
+    /// </summary>
+    /// <param name="error"></param>
+    public void Error(Exception error)
+    {
+        _observable.OnError(error);
+    }
+}
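`MessageHub<T>` is a thin wrapper over an Rx `Subject<T>`: every observer subscribed to `Events` sees each `Push` until `Complete` or `Error` terminates the sequence. A standalone sketch of that contract, assuming only the `System.Reactive` and `Microsoft.Extensions.Logging.Abstractions` packages — `ConsoleObserver` is a throwaway observer for illustration; the real wiring happens in `ChatHubPlugin.Configure` further down:

```csharp
using System;
using Microsoft.Extensions.Logging.Abstractions; // NullLogger for a quick demo
// MessageHub<T> as defined above (BotSharp.Core.Observables.Queues)

class ConsoleObserver : IObserver<string>
{
    public void OnNext(string value) => Console.WriteLine($"received: {value}");
    public void OnError(Exception error) => Console.WriteLine($"error: {error.Message}");
    public void OnCompleted() => Console.WriteLine("completed");
}

class Demo
{
    static void Main()
    {
        var hub = new MessageHub<string>(NullLogger<MessageHub<string>>.Instance);
        using var subscription = hub.Events.Subscribe(new ConsoleObserver());

        hub.Push("hello");   // observer prints "received: hello"
        hub.Complete();      // observers stop receiving after this
        hub.Push("ignored"); // a completed Subject drops further items
    }
}
```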
diff --git a/src/Infrastructure/BotSharp.Core/Routing/Reasoning/InstructExecutor.cs b/src/Infrastructure/BotSharp.Core/Routing/Reasoning/InstructExecutor.cs
index c93e133ce..18d87cb9a 100644
--- a/src/Infrastructure/BotSharp.Core/Routing/Reasoning/InstructExecutor.cs
+++ b/src/Infrastructure/BotSharp.Core/Routing/Reasoning/InstructExecutor.cs
@@ -57,7 +57,9 @@ await HookEmitter.Emit(_services, async hook => await hook.OnRouti
         }
         else
         {
-            var ret = await routing.InvokeAgent(agentId, dialogs);
+            var state = _services.GetRequiredService<IConversationStateService>();
+            var useStreamMsg = state.GetState("use_stream_message");
+            var ret = await routing.InvokeAgent(agentId, dialogs, bool.TryParse(useStreamMsg, out var useStream) && useStream);
         }
 
         var response = dialogs.Last();
diff --git a/src/Infrastructure/BotSharp.Core/Routing/RoutingService.InvokeAgent.cs b/src/Infrastructure/BotSharp.Core/Routing/RoutingService.InvokeAgent.cs
index 7bfd93523..eb60e9255 100644
--- a/src/Infrastructure/BotSharp.Core/Routing/RoutingService.InvokeAgent.cs
+++ b/src/Infrastructure/BotSharp.Core/Routing/RoutingService.InvokeAgent.cs
@@ -4,7 +4,7 @@ namespace BotSharp.Core.Routing;
 
 public partial class RoutingService
 {
-    public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs)
+    public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialogs, bool useStream = false)
     {
         var agentService = _services.GetRequiredService<IAgentService>();
         var agent = await agentService.LoadAgent(agentId);
@@ -30,8 +30,16 @@ public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialog
             provider: provider,
             model: model);
 
+        RoleDialogModel response;
         var message = dialogs.Last();
-        var response = await chatCompletion.GetChatCompletions(agent, dialogs);
+        if (useStream)
+        {
+            response = await chatCompletion.GetChatCompletionsStreamingAsync(agent, dialogs);
+        }
+        else
+        {
+            response = await chatCompletion.GetChatCompletions(agent, dialogs);
+        }
 
         if (response.Role == AgentRole.Function)
         {
@@ -45,8 +53,9 @@ public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialog
             message.FunctionArgs = response.FunctionArgs;
             message.Indication = response.Indication;
             message.CurrentAgentId = agent.Id;
+            message.IsStreaming = response.IsStreaming;
 
-            await InvokeFunction(message, dialogs);
+            await InvokeFunction(message, dialogs, useStream);
         }
         else
         {
@@ -59,6 +68,7 @@ public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialog
             message = RoleDialogModel.From(message, role: AgentRole.Assistant, content: response.Content);
             message.CurrentAgentId = agent.Id;
+            message.IsStreaming = response.IsStreaming;
             dialogs.Add(message);
             Context.SetDialogs(dialogs);
         }
@@ -66,7 +76,7 @@ public async Task<bool> InvokeAgent(string agentId, List<RoleDialogModel> dialog
         return true;
     }
 
-    private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialogModel> dialogs)
+    private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialogModel> dialogs, bool useStream)
    {
         // execute function
         // Save states
@@ -102,7 +112,7 @@ private async Task<bool> InvokeFunction(RoleDialogModel message, List<RoleDialogModel
@@ public async Task<RoleDialogModel> InstructDirect(Agent agent, RoleDialogModel m
         }
         else
         {
-            var ret = await routing.InvokeAgent(agentId, dialogs);
+            var state = _services.GetRequiredService<IConversationStateService>();
+            var useStreamMsg = state.GetState("use_stream_message");
+            var ret = await routing.InvokeAgent(agentId, dialogs, bool.TryParse(useStreamMsg, out var useStream) && useStream);
         }
 
         var response = dialogs.Last();
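Streaming is therefore opt-in per conversation through the `use_stream_message` state key read above. A hedged sketch of flipping that switch before routing runs — `SetState` is assumed to be the setter counterpart of the `GetState` call used here, not something this patch adds:

```csharp
// Assumed usage: mark the current conversation as streaming-enabled so that
// InstructExecutor/InstructDirect pass useStream: true into InvokeAgent.
var state = _services.GetRequiredService<IConversationStateService>();
state.SetState("use_stream_message", "true");
```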
diff --git a/src/Infrastructure/BotSharp.OpenAPI/Controllers/ConversationController.cs b/src/Infrastructure/BotSharp.OpenAPI/Controllers/ConversationController.cs
index d3da550ff..bd131ca1e 100644
--- a/src/Infrastructure/BotSharp.OpenAPI/Controllers/ConversationController.cs
+++ b/src/Infrastructure/BotSharp.OpenAPI/Controllers/ConversationController.cs
@@ -377,6 +377,7 @@ await conv.SendMessage(agentId, inputMsg,
         return response;
     }
 
+    [HttpPost("/conversation/{agentId}/{conversationId}/sse")]
     public async Task SendMessageSse([FromRoute] string agentId, [FromRoute] string conversationId, [FromBody] NewMessageModel input)
     {
diff --git a/src/Plugins/BotSharp.Plugin.AnthropicAI/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.AnthropicAI/Providers/ChatCompletionProvider.cs
index 285e8abdb..50d67a1bb 100644
--- a/src/Plugins/BotSharp.Plugin.AnthropicAI/Providers/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.AnthropicAI/Providers/ChatCompletionProvider.cs
@@ -96,8 +96,7 @@ public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> con
         throw new NotImplementedException();
     }
 
-    public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations,
-        Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         throw new NotImplementedException();
     }
diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/BotSharp.Plugin.AzureOpenAI.csproj b/src/Plugins/BotSharp.Plugin.AzureOpenAI/BotSharp.Plugin.AzureOpenAI.csproj
index 494326fac..372fb3de5 100644
--- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/BotSharp.Plugin.AzureOpenAI.csproj
+++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/BotSharp.Plugin.AzureOpenAI.csproj
@@ -16,7 +16,7 @@
-
+
diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Chat/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Chat/ChatCompletionProvider.cs
index ab7135ef8..605bceb85 100644
--- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Chat/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Chat/ChatCompletionProvider.cs
@@ -1,6 +1,9 @@
 using Azure;
 using BotSharp.Abstraction.Files.Utilities;
 using BotSharp.Abstraction.Hooks;
+using BotSharp.Abstraction.Observables.Models;
+using BotSharp.Core.Infrastructures.Streams;
+using BotSharp.Core.Observables.Queues;
 using OpenAI.Chat;
 using System.ClientModel;
@@ -203,39 +206,133 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent,
         return true;
     }
 
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public async Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         var client = ProviderHelper.GetClient(Provider, _model, _services);
         var chatClient = client.GetChatClient(_model);
         var (prompt, messages, options) = PrepareOptions(agent, conversations);
 
-        var response = chatClient.CompleteChatStreamingAsync(messages, options);
+        var hub = _services.GetRequiredService<MessageHub<HubObserveData>>();
+        var messageId = conversations.LastOrDefault()?.MessageId ?? string.Empty;
 
-        await foreach (var choice in response)
+        var contentHooks = _services.GetHooks<IContentGeneratingHook>(agent.Id);
+        // Before chat completion hook
+        foreach (var hook in contentHooks)
+        {
+            await hook.BeforeGenerating(agent, conversations);
+        }
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "BeforeReceiveLlmStreamMessage",
+            Data = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+            {
+                CurrentAgentId = agent.Id,
+                MessageId = messageId
+            }
+        });
+
+        using var textStream = new RealtimeTextStream();
+        var toolCalls = new List<StreamingChatToolCallUpdate>();
+        ChatTokenUsage? tokenUsage = null;
+
+        var responseMessage = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+        {
+            CurrentAgentId = agent.Id,
+            MessageId = messageId
+        };
+
+        await foreach (var choice in chatClient.CompleteChatStreamingAsync(messages, options))
         {
-            if (choice.FinishReason == ChatFinishReason.FunctionCall || choice.FinishReason == ChatFinishReason.ToolCalls)
+            tokenUsage = choice.Usage;
+
+            if (!choice.ToolCallUpdates.IsNullOrEmpty())
+            {
+                toolCalls.AddRange(choice.ToolCallUpdates);
+            }
+
+            if (!choice.ContentUpdate.IsNullOrEmpty())
             {
-                var update = choice.ToolCallUpdates?.FirstOrDefault()?.FunctionArgumentsUpdate?.ToString() ?? string.Empty;
-                Console.Write(update);
+                var text = choice.ContentUpdate[0]?.Text ?? string.Empty;
+                textStream.Collect(text);
 
-                await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, update)
+#if DEBUG
+                _logger.LogCritical($"Content update: {text}");
+#endif
+
+                var content = new RoleDialogModel(AgentRole.Assistant, text)
                 {
-                    RenderedInstruction = string.Join("\r\n", renderedInstructions)
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId
+                };
+                hub.Push(new()
+                {
+                    ServiceProvider = _services,
+                    EventName = "OnReceiveLlmStreamMessage",
+                    Data = content
                 });
-                continue;
             }
 
-            if (choice.ContentUpdate.IsNullOrEmpty()) continue;
+            if (choice.FinishReason == ChatFinishReason.ToolCalls || choice.FinishReason == ChatFinishReason.FunctionCall)
+            {
+                var meta = toolCalls.FirstOrDefault(x => !string.IsNullOrEmpty(x.FunctionName));
+                var functionName = meta?.FunctionName;
+                var toolCallId = meta?.ToolCallId;
+                var args = toolCalls.Where(x => x.FunctionArgumentsUpdate != null).Select(x => x.FunctionArgumentsUpdate.ToString()).ToList();
+                var functionArgument = string.Join(string.Empty, args);
 
-            _logger.LogInformation(choice.ContentUpdate[0]?.Text);
+#if DEBUG
+                _logger.LogCritical($"Tool Call (id: {toolCallId}) => {functionName}({functionArgument})");
+#endif
 
-            await onMessageReceived(new RoleDialogModel(choice.Role?.ToString() ?? ChatMessageRole.Assistant.ToString(), choice.ContentUpdate[0]?.Text ?? string.Empty)
+                responseMessage = new RoleDialogModel(AgentRole.Function, string.Empty)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    ToolCallId = toolCallId,
+                    FunctionName = functionName,
+                    FunctionArgs = functionArgument
+                };
+            }
+            else if (choice.FinishReason.HasValue)
             {
-                RenderedInstruction = string.Join("\r\n", renderedInstructions)
+                var allText = textStream.GetText();
+                _logger.LogCritical($"Text Content: {allText}");
+
+                responseMessage = new RoleDialogModel(AgentRole.Assistant, allText)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    IsStreaming = true
+                };
+            }
+        }
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "AfterReceiveLlmStreamMessage",
+            Data = responseMessage
+        });
+
+
+        var inputTokenDetails = tokenUsage?.InputTokenDetails;
+        // After chat completion hook
+        foreach (var hook in contentHooks)
+        {
+            await hook.AfterGenerated(responseMessage, new TokenStatsModel
+            {
+                Prompt = prompt,
+                Provider = Provider,
+                Model = _model,
+                TextInputTokens = (tokenUsage?.InputTokenCount ?? 0) - (inputTokenDetails?.CachedTokenCount ?? 0),
+                CachedTextInputTokens = inputTokenDetails?.CachedTokenCount ?? 0,
+                TextOutputTokens = tokenUsage?.OutputTokenCount ?? 0
             });
         }
 
-        return true;
+        return responseMessage;
     }
 
     protected (string, IEnumerable<ChatMessage>, ChatCompletionOptions) PrepareOptions(Agent agent, List<RoleDialogModel> conversations)
diff --git a/src/Plugins/BotSharp.Plugin.ChatHub/ChatHubPlugin.cs b/src/Plugins/BotSharp.Plugin.ChatHub/ChatHubPlugin.cs
index 725655fce..d3c3f2acb 100644
--- a/src/Plugins/BotSharp.Plugin.ChatHub/ChatHubPlugin.cs
+++ b/src/Plugins/BotSharp.Plugin.ChatHub/ChatHubPlugin.cs
@@ -1,5 +1,9 @@
 using BotSharp.Abstraction.Crontab;
+using BotSharp.Abstraction.Observables.Models;
+using BotSharp.Core.Observables.Queues;
 using BotSharp.Plugin.ChatHub.Hooks;
+using BotSharp.Plugin.ChatHub.Observers;
+using Microsoft.AspNetCore.Builder;
 using Microsoft.Extensions.Configuration;
 
 namespace BotSharp.Plugin.ChatHub;
@@ -7,7 +11,7 @@ namespace BotSharp.Plugin.ChatHub;
 /// <summary>
 /// The dialogue channel connects users, AI assistants and customer service representatives.
 /// </summary>
-public class ChatHubPlugin : IBotSharpPlugin
+public class ChatHubPlugin : IBotSharpPlugin, IBotSharpAppPlugin
 {
     public string Id => "6e52d42d-1e23-406b-8599-36af36c83209";
     public string Name => "Chat Hub";
@@ -28,4 +32,12 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
         services.AddScoped();
         services.AddScoped();
     }
+
+    public void Configure(IApplicationBuilder app)
+    {
+        var services = app.ApplicationServices;
+        var queue = services.GetRequiredService<MessageHub<HubObserveData>>();
+        var logger = services.GetRequiredService<ILogger<MessageHub<HubObserveData>>>();
+        queue.Events.Subscribe(new ChatHubObserver(logger));
+    }
 }
diff --git a/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/ChatHubConversationHook.cs b/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/ChatHubConversationHook.cs
index a41abe874..7bc4600af 100644
--- a/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/ChatHubConversationHook.cs
+++ b/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/ChatHubConversationHook.cs
@@ -118,6 +118,7 @@ public override async Task OnResponseGenerated(RoleDialogModel message)
             RichContent = message.SecondaryRichContent ?? message.RichContent,
             Data = message.Data,
             States = state.GetStates(),
+            IsStreaming = message.IsStreaming,
             Sender = new()
             {
                 FirstName = "AI",
@@ -133,7 +134,11 @@ public override async Task OnResponseGenerated(RoleDialogModel message)
             SenderAction = SenderActionEnum.TypingOff
         };
 
-        await GenerateSenderAction(conv.ConversationId, action);
+        if (!message.IsStreaming)
+        {
+            await GenerateSenderAction(conv.ConversationId, action);
+        }
+
         await ReceiveAssistantMessage(conv.ConversationId, json);
         await base.OnResponseGenerated(message);
     }
diff --git a/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/StreamingLogHook.cs b/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/StreamingLogHook.cs
index 720f3fd0a..b499438e7 100644
--- a/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/StreamingLogHook.cs
+++ b/src/Plugins/BotSharp.Plugin.ChatHub/Hooks/StreamingLogHook.cs
@@ -74,7 +74,7 @@ public override async Task OnPostbackMessageReceived(RoleDialogModel message, Po
         var log = $"{GetMessageContent(message)}";
         var replyContent = JsonSerializer.Serialize(replyMsg, _options.JsonSerializerOptions);
-        log += $"\r\n```json\r\n{replyContent}\r\n```";
+        log += $"\r\n\r\n```json\r\n{replyContent}\r\n```";
 
         var input = new ContentLogInputModel(conversationId, message)
         {
@@ -233,7 +233,7 @@ public override async Task OnResponseGenerated(RoleDialogModel message)
         if (message.RichContent != null || message.SecondaryRichContent != null)
         {
             var richContent = JsonSerializer.Serialize(message.SecondaryRichContent ?? message.RichContent, _localJsonOptions);
-            log += $"\r\n```json\r\n{richContent}\r\n```";
+            log += $"\r\n\r\n```json\r\n{richContent}\r\n```";
         }
 
         var input = new ContentLogInputModel(conv.ConversationId, message)
"Assistant", + Role = AgentRole.Assistant + } + }; + + var action = new ConversationSenderActionModel + { + ConversationId = conv.ConversationId, + SenderAction = SenderActionEnum.TypingOn + }; + + GenerateSenderAction(conv.ConversationId, action).ConfigureAwait(false).GetAwaiter().GetResult(); + } + else if (value.EventName == AFTER_RECEIVE_LLM_STREAM_MESSAGE && message.IsStreaming) + { + var conv = _services.GetRequiredService(); + model = new ChatResponseDto() + { + ConversationId = conv.ConversationId, + MessageId = message.MessageId, + Text = message.Content, + Sender = new() + { + FirstName = "AI", + LastName = "Assistant", + Role = AgentRole.Assistant + } + }; + + var action = new ConversationSenderActionModel + { + ConversationId = conv.ConversationId, + SenderAction = SenderActionEnum.TypingOff + }; + + GenerateSenderAction(conv.ConversationId, action).ConfigureAwait(false).GetAwaiter().GetResult(); + } + else if (value.EventName == ON_RECEIVE_LLM_STREAM_MESSAGE) + { + var conv = _services.GetRequiredService(); + model = new ChatResponseDto() + { + ConversationId = conv.ConversationId, + MessageId = message.MessageId, + Text = !string.IsNullOrEmpty(message.SecondaryContent) ? message.SecondaryContent : message.Content, + Function = message.FunctionName, + RichContent = message.SecondaryRichContent ?? message.RichContent, + Data = message.Data, + Sender = new() + { + FirstName = "AI", + LastName = "Assistant", + Role = AgentRole.Assistant + } + }; + } + + OnReceiveAssistantMessage(value.EventName, model.ConversationId, model).ConfigureAwait(false).GetAwaiter().GetResult(); + } + + private async Task OnReceiveAssistantMessage(string @event, string conversationId, ChatResponseDto model) + { + try + { + var settings = _services.GetRequiredService(); + var chatHub = _services.GetRequiredService>(); + + if (settings.EventDispatchBy == EventDispatchType.Group) + { + await chatHub.Clients.Group(conversationId).SendAsync(@event, model); + } + else + { + var user = _services.GetRequiredService(); + await chatHub.Clients.User(user.Id).SendAsync(@event, model); + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, $"Failed to receive assistant message in {nameof(ChatHubConversationHook)} (conversation id: {conversationId})"); + } + } + + private bool AllowSendingMessage() + { + var sidecar = _services.GetService(); + return sidecar == null || !sidecar.IsEnabled(); + } + + private async Task GenerateSenderAction(string conversationId, ConversationSenderActionModel action) + { + try + { + var settings = _services.GetRequiredService(); + var chatHub = _services.GetRequiredService>(); + if (settings.EventDispatchBy == EventDispatchType.Group) + { + await chatHub.Clients.Group(conversationId).SendAsync(GENERATE_SENDER_ACTION, action); + } + else + { + var user = _services.GetRequiredService(); + await chatHub.Clients.User(user.Id).SendAsync(GENERATE_SENDER_ACTION, action); + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, $"Failed to generate sender action in {nameof(ChatHubConversationHook)} (conversation id: {conversationId})"); + } + } +} diff --git a/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj b/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj index 2f7e326ad..3f9a26ce0 100644 --- a/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj +++ b/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj @@ -15,7 +15,7 @@ - + diff --git 
diff --git a/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj b/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj
index 2f7e326ad..3f9a26ce0 100644
--- a/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj
+++ b/src/Plugins/BotSharp.Plugin.DeepSeekAI/BotSharp.Plugin.DeepSeekAI.csproj
@@ -15,7 +15,7 @@
-
+
diff --git a/src/Plugins/BotSharp.Plugin.DeepSeekAI/Providers/Chat/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.DeepSeekAI/Providers/Chat/ChatCompletionProvider.cs
index eb14ac994..1335d4dc3 100644
--- a/src/Plugins/BotSharp.Plugin.DeepSeekAI/Providers/Chat/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.DeepSeekAI/Providers/Chat/ChatCompletionProvider.cs
@@ -1,8 +1,11 @@
-using Microsoft.Extensions.Logging;
-using OpenAI.Chat;
 using BotSharp.Abstraction.Files;
-using BotSharp.Plugin.DeepSeek.Providers;
 using BotSharp.Abstraction.Hooks;
+using BotSharp.Abstraction.Observables.Models;
+using BotSharp.Core.Infrastructures.Streams;
+using BotSharp.Core.Observables.Queues;
+using BotSharp.Plugin.DeepSeek.Providers;
+using Microsoft.Extensions.Logging;
+using OpenAI.Chat;
 
 namespace BotSharp.Plugin.DeepSeekAI.Providers.Chat;
 
@@ -170,39 +173,133 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public async Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         var client = ProviderHelper.GetClient(Provider, _model, _services);
         var chatClient = client.GetChatClient(_model);
         var (prompt, messages, options) = PrepareOptions(agent, conversations);
 
-        var response = chatClient.CompleteChatStreamingAsync(messages, options);
+        var hub = _services.GetRequiredService<MessageHub<HubObserveData>>();
+        var messageId = conversations.LastOrDefault()?.MessageId ?? string.Empty;
+
+        var contentHooks = _services.GetHooks<IContentGeneratingHook>(agent.Id);
+        // Before chat completion hook
+        foreach (var hook in contentHooks)
+        {
+            await hook.BeforeGenerating(agent, conversations);
+        }
 
-        await foreach (var choice in response)
+        hub.Push(new()
         {
-            if (choice.FinishReason == ChatFinishReason.FunctionCall || choice.FinishReason == ChatFinishReason.ToolCalls)
+            ServiceProvider = _services,
+            EventName = "BeforeReceiveLlmStreamMessage",
+            Data = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+            {
+                CurrentAgentId = agent.Id,
+                MessageId = messageId
+            }
+        });
+
+        using var textStream = new RealtimeTextStream();
+        var toolCalls = new List<StreamingChatToolCallUpdate>();
+        ChatTokenUsage? tokenUsage = null;
+
+        var responseMessage = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+        {
+            CurrentAgentId = agent.Id,
+            MessageId = messageId
+        };
+
+        await foreach (var choice in chatClient.CompleteChatStreamingAsync(messages, options))
+        {
+            tokenUsage = choice.Usage;
+
+            if (!choice.ToolCallUpdates.IsNullOrEmpty())
+            {
+                toolCalls.AddRange(choice.ToolCallUpdates);
+            }
+
+            if (!choice.ContentUpdate.IsNullOrEmpty())
             {
-                var update = choice.ToolCallUpdates?.FirstOrDefault()?.FunctionArgumentsUpdate?.ToString() ?? string.Empty;
-                _logger.LogInformation(update);
+                var text = choice.ContentUpdate[0]?.Text ?? string.Empty;
+                textStream.Collect(text);
 
-                await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, update)
+#if DEBUG
+                _logger.LogCritical($"Content update: {text}");
+#endif
+
+                var content = new RoleDialogModel(AgentRole.Assistant, text)
                 {
-                    RenderedInstruction = string.Join("\r\n", renderedInstructions)
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId
+                };
+                hub.Push(new()
                 {
+                    ServiceProvider = _services,
+                    EventName = "OnReceiveLlmStreamMessage",
+                    Data = content
                 });
-                continue;
             }
 
-            if (choice.ContentUpdate.IsNullOrEmpty()) continue;
+            if (choice.FinishReason == ChatFinishReason.ToolCalls || choice.FinishReason == ChatFinishReason.FunctionCall)
+            {
+                var meta = toolCalls.FirstOrDefault(x => !string.IsNullOrEmpty(x.FunctionName));
+                var functionName = meta?.FunctionName;
+                var toolCallId = meta?.ToolCallId;
+                var args = toolCalls.Where(x => x.FunctionArgumentsUpdate != null).Select(x => x.FunctionArgumentsUpdate.ToString()).ToList();
+                var functionArgument = string.Join(string.Empty, args);
 
-            _logger.LogInformation(choice.ContentUpdate[0]?.Text);
+#if DEBUG
+                _logger.LogCritical($"Tool Call (id: {toolCallId}) => {functionName}({functionArgument})");
+#endif
 
-            await onMessageReceived(new RoleDialogModel(choice.Role?.ToString() ?? ChatMessageRole.Assistant.ToString(), choice.ContentUpdate[0]?.Text ?? string.Empty)
+                responseMessage = new RoleDialogModel(AgentRole.Function, string.Empty)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    ToolCallId = toolCallId,
+                    FunctionName = functionName,
+                    FunctionArgs = functionArgument
+                };
+            }
+            else if (choice.FinishReason.HasValue)
             {
-                RenderedInstruction = string.Join("\r\n", renderedInstructions)
+                var allText = textStream.GetText();
+                _logger.LogCritical($"Text Content: {allText}");
+
+                responseMessage = new RoleDialogModel(AgentRole.Assistant, allText)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    IsStreaming = true
+                };
+            }
+        }
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "AfterReceiveLlmStreamMessage",
+            Data = responseMessage
+        });
+
+
+        var inputTokenDetails = tokenUsage?.InputTokenDetails;
+        // After chat completion hook
+        foreach (var hook in contentHooks)
+        {
+            await hook.AfterGenerated(responseMessage, new TokenStatsModel
+            {
+                Prompt = prompt,
+                Provider = Provider,
+                Model = _model,
+                TextInputTokens = (tokenUsage?.InputTokenCount ?? 0) - (inputTokenDetails?.CachedTokenCount ?? 0),
+                CachedTextInputTokens = inputTokenDetails?.CachedTokenCount ?? 0,
+                TextOutputTokens = tokenUsage?.OutputTokenCount ?? 0
             });
         }
 
-        return true;
+        return responseMessage;
     }
 
     public void SetModelName(string model)
diff --git a/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/GeminiChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/GeminiChatCompletionProvider.cs
index 1ead28e65..1d67eac8a 100644
--- a/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/GeminiChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/GeminiChatCompletionProvider.cs
@@ -159,40 +159,9 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
-        var client = ProviderHelper.GetGeminiClient(Provider, _model, _services);
-        var chatClient = client.CreateGenerativeModel(_model.ToModelId());
-        var (prompt, messages) = PrepareOptions(chatClient, agent, conversations);
-
-        var asyncEnumerable = chatClient.StreamContentAsync(messages);
-
-        await foreach (var response in asyncEnumerable)
-        {
-            if (response.GetFunction() != null)
-            {
-                var func = response.GetFunction();
-                var update = func?.Args?.ToJsonString().ToString() ?? string.Empty;
-                _logger.LogInformation(update);
-
-                await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, update)
-                {
-                    RenderedInstruction = string.Join("\r\n", renderedInstructions)
-                });
-                continue;
-            }
-
-            if (response.Text().IsNullOrEmpty()) continue;
-
-            _logger.LogInformation(response.Text());
-
-            await onMessageReceived(new RoleDialogModel(response.Candidates?.LastOrDefault()?.Content?.Role?.ToString() ?? AgentRole.Assistant.ToString(), response.Text() ?? string.Empty)
-            {
-                RenderedInstruction = string.Join("\r\n", renderedInstructions)
-            });
-        }
-
-        return true;
+        throw new NotImplementedException();
     }
 
     public void SetModelName(string model)
diff --git a/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/PalmChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/PalmChatCompletionProvider.cs
index 72a47adc0..e992fd603 100644
--- a/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/PalmChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Chat/PalmChatCompletionProvider.cs
@@ -145,7 +145,7 @@ public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> con
-    public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         throw new NotImplementedException();
     }
diff --git a/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Realtime/RealTimeCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Realtime/RealTimeCompletionProvider.cs
index 2e95cfa24..f262efd72 100644
--- a/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Realtime/RealTimeCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.GoogleAI/Providers/Realtime/RealTimeCompletionProvider.cs
@@ -1,11 +1,12 @@
-using System.Threading;
 using BotSharp.Abstraction.Hooks;
 using BotSharp.Abstraction.Realtime.Models.Session;
+using BotSharp.Core.Infrastructures.Streams;
 using BotSharp.Core.Session;
 using BotSharp.Plugin.GoogleAI.Models.Realtime;
 using GenerativeAI;
 using GenerativeAI.Types;
 using GenerativeAI.Types.Converters;
+using System.Threading;
 
 namespace BotSharp.Plugin.GoogleAi.Providers.Realtime;
 
@@ -33,8 +34,8 @@ public class GoogleRealTimeProvider : IRealTimeCompletion
         UnknownTypeHandling = JsonUnknownTypeHandling.JsonElement
     };
 
-    private RealtimeTranscriptionResponse _inputStream = new();
-    private RealtimeTranscriptionResponse _outputStream = new();
+    private RealtimeTextStream _inputStream = new();
+    private RealtimeTextStream _outputStream = new();
 
     private bool _isBlocking = false;
     private RealtimeHubConnection _conn;
diff --git a/src/Plugins/BotSharp.Plugin.HuggingFace/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.HuggingFace/Providers/ChatCompletionProvider.cs
index 460677ba5..5a38b0a9c 100644
--- a/src/Plugins/BotSharp.Plugin.HuggingFace/Providers/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.HuggingFace/Providers/ChatCompletionProvider.cs
@@ -76,9 +76,9 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
-        return true;
+        throw new NotImplementedException();
     }
 
     public void SetModelName(string model)
diff --git a/src/Plugins/BotSharp.Plugin.LLamaSharp/BotSharp.Plugin.LLamaSharp.csproj b/src/Plugins/BotSharp.Plugin.LLamaSharp/BotSharp.Plugin.LLamaSharp.csproj
index af7a6237a..80807a265 100644
--- a/src/Plugins/BotSharp.Plugin.LLamaSharp/BotSharp.Plugin.LLamaSharp.csproj
+++ b/src/Plugins/BotSharp.Plugin.LLamaSharp/BotSharp.Plugin.LLamaSharp.csproj
@@ -15,7 +15,7 @@
-
+
diff --git a/src/Plugins/BotSharp.Plugin.LLamaSharp/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.LLamaSharp/Providers/ChatCompletionProvider.cs
index 45f41e95d..321b9aee8 100644
--- a/src/Plugins/BotSharp.Plugin.LLamaSharp/Providers/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.LLamaSharp/Providers/ChatCompletionProvider.cs
@@ -1,5 +1,12 @@
 using BotSharp.Abstraction.Agents;
+using BotSharp.Abstraction.Hooks;
 using BotSharp.Abstraction.Loggers;
+using BotSharp.Abstraction.Observables.Models;
+using BotSharp.Core.Infrastructures.Streams;
+using BotSharp.Core.Observables.Queues;
+using Microsoft.AspNetCore.SignalR;
+using static LLama.Common.ChatHistory;
+using static System.Net.Mime.MediaTypeNames;
 
 namespace BotSharp.Plugin.LLamaSharp.Providers;
 
@@ -159,12 +166,8 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent,
         return true;
     }
 
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public async Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
-        string totalResponse = "";
-        var content = string.Join("\r\n", conversations.Select(x => $"{x.Role}: {x.Content}")).Trim();
-        content += $"\r\n{AgentRole.Assistant}: ";
-
         var state = _services.GetRequiredService<IConversationStateService>();
         var model = state.GetState("model", "llama-2-7b-chat.Q8_0");
 
@@ -180,13 +183,60 @@ public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel
+        var hub = _services.GetRequiredService<MessageHub<HubObserveData>>();
+        var messageId = conversations.LastOrDefault()?.MessageId ?? string.Empty;
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "BeforeReceiveLlmStreamMessage",
+            Data = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+            {
+                CurrentAgentId = agent.Id,
+                MessageId = messageId
+            }
+        });
+
+        using var textStream = new RealtimeTextStream();
+        var responseMessage = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+        {
+            CurrentAgentId = agent.Id,
+            MessageId = messageId
+        };
+
         await foreach (var response in executor.InferAsync(agent.Instruction, inferenceParams))
         {
             Console.Write(response);
-            totalResponse += response;
+            textStream.Collect(response);
+
+            var content = new RoleDialogModel(AgentRole.Assistant, response)
+            {
+                CurrentAgentId = agent.Id,
+                MessageId = messageId
+            };
+            hub.Push(new()
+            {
+                ServiceProvider = _services,
+                EventName = "OnReceiveLlmStreamMessage",
+                Data = content
+            });
         }
 
-        return true;
+        responseMessage = new RoleDialogModel(AgentRole.Assistant, textStream.GetText())
+        {
+            CurrentAgentId = agent.Id,
+            MessageId = messageId,
+            IsStreaming = true
+        };
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "AfterReceiveLlmStreamMessage",
+            Data = responseMessage
+        });
+
+        return responseMessage;
     }
 
     public void SetModelName(string model)
diff --git a/src/Plugins/BotSharp.Plugin.LangChain/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.LangChain/Providers/ChatCompletionProvider.cs
index aa6b1175f..9e2aa1d46 100644
--- a/src/Plugins/BotSharp.Plugin.LangChain/Providers/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.LangChain/Providers/ChatCompletionProvider.cs
@@ -65,7 +65,7 @@ public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> con
-    public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         throw new NotImplementedException();
     }
diff --git a/src/Plugins/BotSharp.Plugin.MetaGLM/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.MetaGLM/Providers/ChatCompletionProvider.cs
index c1d4bef1f..8b3ecea37 100644
--- a/src/Plugins/BotSharp.Plugin.MetaGLM/Providers/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.MetaGLM/Providers/ChatCompletionProvider.cs
@@ -235,7 +235,7 @@ public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> con
-    public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         throw new NotImplementedException();
     }
diff --git a/src/Plugins/BotSharp.Plugin.MicrosoftExtensionsAI/MicrosoftExtensionsAIChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.MicrosoftExtensionsAI/MicrosoftExtensionsAIChatCompletionProvider.cs
index d95a25ae0..e510a9a24 100644
--- a/src/Plugins/BotSharp.Plugin.MicrosoftExtensionsAI/MicrosoftExtensionsAIChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.MicrosoftExtensionsAI/MicrosoftExtensionsAIChatCompletionProvider.cs
@@ -169,8 +169,10 @@ public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> con
         throw new NotImplementedException();
 
     /// <inheritdoc />
-    public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived) =>
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
+    {
         throw new NotImplementedException();
+    }
 
     private sealed class NopAIFunction(string name, string description, JsonElement schema) : AIFunction
     {
diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Chat/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Chat/ChatCompletionProvider.cs
index 17ee10124..314697c83 100644
--- a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Chat/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Chat/ChatCompletionProvider.cs
@@ -1,4 +1,10 @@
+using Azure;
 using BotSharp.Abstraction.Hooks;
+using BotSharp.Abstraction.Observables.Models;
+using BotSharp.Core.Infrastructures.Streams;
+using BotSharp.Core.Observables.Queues;
+using BotSharp.Plugin.OpenAI.Models.Realtime;
+using Fluid;
 using OpenAI.Chat;
 
 namespace BotSharp.Plugin.OpenAI.Providers.Chat;
 
@@ -179,39 +185,133 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent,
         return true;
     }
 
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public async Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         var client = ProviderHelper.GetClient(Provider, _model, _services);
         var chatClient = client.GetChatClient(_model);
         var (prompt, messages, options) = PrepareOptions(agent, conversations);
 
-        var response = chatClient.CompleteChatStreamingAsync(messages, options);
+        var hub = _services.GetRequiredService<MessageHub<HubObserveData>>();
+        var messageId = conversations.LastOrDefault()?.MessageId ?? string.Empty;
 
-        await foreach (var choice in response)
+        var contentHooks = _services.GetHooks<IContentGeneratingHook>(agent.Id);
+        // Before chat completion hook
+        foreach (var hook in contentHooks)
+        {
+            await hook.BeforeGenerating(agent, conversations);
+        }
+
+        hub.Push(new()
         {
-            if (choice.FinishReason == ChatFinishReason.FunctionCall || choice.FinishReason == ChatFinishReason.ToolCalls)
+            ServiceProvider = _services,
+            EventName = "BeforeReceiveLlmStreamMessage",
+            Data = new RoleDialogModel(AgentRole.Assistant, string.Empty)
             {
-                var update = choice.ToolCallUpdates?.FirstOrDefault()?.FunctionArgumentsUpdate?.ToString() ?? string.Empty;
-                _logger.LogInformation(update);
+                CurrentAgentId = agent.Id,
+                MessageId = messageId
+            }
+        });
 
-                await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, update)
+        using var textStream = new RealtimeTextStream();
+        var toolCalls = new List<StreamingChatToolCallUpdate>();
+        ChatTokenUsage? tokenUsage = null;
+
+        var responseMessage = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+        {
+            CurrentAgentId = agent.Id,
+            MessageId = messageId
+        };
+
+        await foreach (var choice in chatClient.CompleteChatStreamingAsync(messages, options))
+        {
+            tokenUsage = choice.Usage;
+
+            if (!choice.ToolCallUpdates.IsNullOrEmpty())
+            {
+                toolCalls.AddRange(choice.ToolCallUpdates);
+            }
+
+            if (!choice.ContentUpdate.IsNullOrEmpty())
+            {
+                var text = choice.ContentUpdate[0]?.Text ?? string.Empty;
+                textStream.Collect(text);
+
+#if DEBUG
+                _logger.LogCritical($"Content update: {text}");
+#endif
+
+                var content = new RoleDialogModel(AgentRole.Assistant, text)
                 {
-                    RenderedInstruction = string.Join("\r\n", renderedInstructions)
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId
+                };
+                hub.Push(new()
                 {
+                    ServiceProvider = _services,
+                    EventName = "OnReceiveLlmStreamMessage",
+                    Data = content
                 });
-                continue;
             }
 
-            if (choice.ContentUpdate.IsNullOrEmpty()) continue;
+            if (choice.FinishReason == ChatFinishReason.ToolCalls || choice.FinishReason == ChatFinishReason.FunctionCall)
+            {
+                var meta = toolCalls.FirstOrDefault(x => !string.IsNullOrEmpty(x.FunctionName));
+                var functionName = meta?.FunctionName;
+                var toolCallId = meta?.ToolCallId;
+                var args = toolCalls.Where(x => x.FunctionArgumentsUpdate != null).Select(x => x.FunctionArgumentsUpdate.ToString()).ToList();
+                var functionArgument = string.Join(string.Empty, args);
+
+#if DEBUG
+                _logger.LogCritical($"Tool Call (id: {toolCallId}) => {functionName}({functionArgument})");
+#endif
+
+                responseMessage = new RoleDialogModel(AgentRole.Function, string.Empty)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    ToolCallId = toolCallId,
+                    FunctionName = functionName,
+                    FunctionArgs = functionArgument
+                };
+            }
+            else if (choice.FinishReason.HasValue)
+            {
+                var allText = textStream.GetText();
+                _logger.LogCritical($"Text Content: {allText}");
+
+                responseMessage = new RoleDialogModel(AgentRole.Assistant, allText)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    IsStreaming = true
+                };
+            }
+        }
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "AfterReceiveLlmStreamMessage",
+            Data = responseMessage
+        });
 
-            _logger.LogInformation(choice.ContentUpdate[0]?.Text);
 
-            await onMessageReceived(new RoleDialogModel(choice.Role?.ToString() ?? ChatMessageRole.Assistant.ToString(), choice.ContentUpdate[0]?.Text ?? string.Empty)
+        var inputTokenDetails = tokenUsage?.InputTokenDetails;
+        // After chat completion hook
+        foreach (var hook in contentHooks)
+        {
+            await hook.AfterGenerated(responseMessage, new TokenStatsModel
             {
-                RenderedInstruction = string.Join("\r\n", renderedInstructions)
+                Prompt = prompt,
+                Provider = Provider,
+                Model = _model,
+                TextInputTokens = (tokenUsage?.InputTokenCount ?? 0) - (inputTokenDetails?.CachedTokenCount ?? 0),
+                CachedTextInputTokens = inputTokenDetails?.CachedTokenCount ?? 0,
+                TextOutputTokens = tokenUsage?.OutputTokenCount ?? 0
             });
         }
 
-        return true;
+        return responseMessage;
     }
 
@@ -412,4 +512,11 @@ public void SetModelName(string model)
     {
         _model = model;
     }
+}
+
+
+class ToolCallData
+{
+    public ChatFinishReason? Reason { get; set; }
+    public List<StreamingChatToolCallUpdate> ToolCalls { get; set; } = [];
 }
\ No newline at end of file
diff --git a/src/Plugins/BotSharp.Plugin.SemanticKernel/SemanticKernelChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.SemanticKernel/SemanticKernelChatCompletionProvider.cs
index 3f742777d..57277399d 100644
--- a/src/Plugins/BotSharp.Plugin.SemanticKernel/SemanticKernelChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.SemanticKernel/SemanticKernelChatCompletionProvider.cs
@@ -94,7 +94,7 @@ public Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel> con
         throw new NotImplementedException();
     }
     /// <inheritdoc />
-    public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         throw new NotImplementedException();
     }
diff --git a/src/Plugins/BotSharp.Plugin.SparkDesk/BotSharp.Plugin.SparkDesk.csproj b/src/Plugins/BotSharp.Plugin.SparkDesk/BotSharp.Plugin.SparkDesk.csproj
index 51f26f872..d49d41df4 100644
--- a/src/Plugins/BotSharp.Plugin.SparkDesk/BotSharp.Plugin.SparkDesk.csproj
+++ b/src/Plugins/BotSharp.Plugin.SparkDesk/BotSharp.Plugin.SparkDesk.csproj
@@ -15,7 +15,7 @@
-
+
diff --git a/src/Plugins/BotSharp.Plugin.SparkDesk/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.SparkDesk/Providers/ChatCompletionProvider.cs
index ba0aa220f..035e03e0c 100644
--- a/src/Plugins/BotSharp.Plugin.SparkDesk/Providers/ChatCompletionProvider.cs
+++ b/src/Plugins/BotSharp.Plugin.SparkDesk/Providers/ChatCompletionProvider.cs
@@ -1,6 +1,10 @@
 using BotSharp.Abstraction.Agents;
 using BotSharp.Abstraction.Agents.Enums;
 using BotSharp.Abstraction.Loggers;
+using BotSharp.Abstraction.Observables.Models;
+using BotSharp.Core.Infrastructures.Streams;
+using BotSharp.Core.Observables.Queues;
+using Microsoft.AspNetCore.SignalR;
 
 namespace BotSharp.Plugin.SparkDesk.Providers;
 
@@ -143,34 +147,77 @@ public async Task<bool> GetChatCompletionsAsync(Agent agent, List<RoleDialogModel
-    public async Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations, Func<RoleDialogModel, Task> onMessageReceived)
+    public async Task<RoleDialogModel> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogModel> conversations)
     {
         var client = new SparkDeskClient(appId: _settings.AppId, apiKey: _settings.ApiKey, apiSecret: _settings.ApiSecret);
         var (prompt, messages, funcall) = PrepareOptions(agent, conversations);
+        var messageId = conversations.LastOrDefault()?.MessageId ?? string.Empty;
+        var hub = _services.GetRequiredService<MessageHub<HubObserveData>>();
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "BeforeReceiveLlmStreamMessage",
+            Data = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+            {
+                CurrentAgentId = agent.Id,
+                MessageId = messageId
+            }
+        });
+
+        var responseMessage = new RoleDialogModel(AgentRole.Assistant, string.Empty)
+        {
+            CurrentAgentId = agent.Id,
+            MessageId = messageId
+        };
+
+        using var textStream = new RealtimeTextStream();
 
         await foreach (StreamedChatResponse response in client.ChatAsStreamAsync(modelVersion: _settings.ModelVersion, messages, functions: funcall.Length == 0 ? null : funcall))
         {
-            if (response.FunctionCall !=null)
+            if (response.FunctionCall != null)
             {
-                await onMessageReceived(new RoleDialogModel(AgentRole.Function, response.Text)
-                {
+                responseMessage = new RoleDialogModel(AgentRole.Function, string.Empty)
+                {
                     CurrentAgentId = agent.Id,
+                    MessageId = messageId,
+                    ToolCallId = response.FunctionCall.Name,
                     FunctionName = response.FunctionCall.Name,
-                    FunctionArgs = response.FunctionCall.Arguments,
-                    RenderedInstruction = string.Join("\r\n", renderedInstructions)
-                });
-                continue;
+                    FunctionArgs = response.FunctionCall.Arguments
+                };
             }
-
-            await onMessageReceived(new RoleDialogModel(AgentRole.Assistant, response.Text)
+            else
             {
-                CurrentAgentId = agent.Id,
-                RenderedInstruction = string.Join("\r\n", renderedInstructions)
-            });
-
-        }
+                textStream.Collect(response.Text);
+                responseMessage = new RoleDialogModel(AgentRole.Assistant, response.Text)
+                {
+                    CurrentAgentId = agent.Id,
+                    MessageId = messageId
+                };
 
-        return true;
+                hub.Push(new()
+                {
+                    ServiceProvider = _services,
+                    EventName = "OnReceiveLlmStreamMessage",
+                    Data = responseMessage
+                });
+            }
+        }
+
+        if (responseMessage.Role == AgentRole.Assistant)
+        {
+            responseMessage.Content = textStream.GetText();
+            responseMessage.IsStreaming = true;
+        }
+
+        hub.Push(new()
+        {
+            ServiceProvider = _services,
+            EventName = "AfterReceiveLlmStreamMessage",
+            Data = responseMessage
+        });
+
+        return responseMessage;
     }
 
     public void SetModelName(string model)
diff --git a/tests/BotSharp.LLM.Tests/ChatCompletionTests.cs b/tests/BotSharp.LLM.Tests/ChatCompletionTests.cs
index ee3e132a0..f6c06dc46 100644
--- a/tests/BotSharp.LLM.Tests/ChatCompletionTests.cs
+++ b/tests/BotSharp.LLM.Tests/ChatCompletionTests.cs
@@ -1,4 +1,4 @@
-using BotSharp.Abstraction.Agents.Enums;
+using BotSharp.Abstraction.Agents.Enums;
 using BotSharp.Abstraction.Agents.Models;
 using BotSharp.Abstraction.Conversations.Models;
 using BotSharp.Abstraction.MLTasks;
@@ -96,13 +96,15 @@ public async Task GetChatCompletionsAsync_Test(IChatCompletion chatCompletion, A
     public async Task GetChatCompletionsStreamingAsync_Test(IChatCompletion chatCompletion, Agent agent, string modelName)
     {
         chatCompletion.SetModelName(modelName);
-        var conversation = new List<RoleDialogModel>([new RoleDialogModel(AgentRole.User, "write a poem about stars")]);
+
         RoleDialogModel reply = null;
-        var result = await chatCompletion.GetChatCompletionsStreamingAsync(agent,conversation, async (received) =>
+        var messages = new List<RoleDialogModel>
         {
-            reply = received;
-        });
-        result.ShouldBeTrue();
+            new RoleDialogModel(AgentRole.User, "write a poem about stars")
+        };
+        var result = await chatCompletion.GetChatCompletionsStreamingAsync(agent, messages);
+
+        result.ShouldNotBeNull();
         reply.ShouldNotBeNull();
         reply.Content.ShouldNotBeNullOrEmpty();
     }