diff --git a/src/Worker/Core/DurableTaskWorkerOptions.cs b/src/Worker/Core/DurableTaskWorkerOptions.cs index bc0c838d..bbf19584 100644 --- a/src/Worker/Core/DurableTaskWorkerOptions.cs +++ b/src/Worker/Core/DurableTaskWorkerOptions.cs @@ -128,7 +128,7 @@ public DataConverter DataConverter /// or entity) can be processed concurrently by the worker. It is recommended to set these values based on the /// expected workload and the resources available on the machine running the worker. /// - public ConcurrencyOptions Concurrency { get; } = new(); + public ConcurrencyOptions Concurrency { get; init; } = new(); /// /// Gets or sets the versioning options for the Durable Task worker. diff --git a/src/Worker/Grpc/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs b/src/Worker/Grpc/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs index 715d2177..e0e14d71 100644 --- a/src/Worker/Grpc/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs +++ b/src/Worker/Grpc/DependencyInjection/DurableTaskWorkerBuilderExtensions.cs @@ -20,7 +20,9 @@ public static class DurableTaskWorkerBuilderExtensions /// Note: only 1 instance of gRPC worker is supported per sidecar. /// public static IDurableTaskWorkerBuilder UseGrpc(this IDurableTaskWorkerBuilder builder) - => builder.UseGrpc(opt => { }); + => builder.UseGrpc(opt => + { + }); /// /// Configures the to be a gRPC client. diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs index fe9c3df5..33f6c478 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Globalization; using System.Text; +using Dapr.DurableTask.Abstractions; +using Dapr.DurableTask.Entities; +using Dapr.DurableTask.Worker.Shims; using DurableTask.Core; using DurableTask.Core.Entities; using DurableTask.Core.Entities.OperationFormat; using DurableTask.Core.History; -using Dapr.DurableTask.Abstractions; -using Dapr.DurableTask.Entities; -using Dapr.DurableTask.Worker.Shims; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using static Dapr.DurableTask.Protobuf.TaskHubSidecarService; -using DTCore = DurableTask.Core; using P = Dapr.DurableTask.Protobuf; namespace Dapr.DurableTask.Worker.Grpc; @@ -20,103 +20,153 @@ namespace Dapr.DurableTask.Worker.Grpc; /// /// The gRPC Durable Task worker. /// -sealed partial class GrpcDurableTaskWorker +partial class GrpcDurableTaskWorker { - class Processor + class Processor(GrpcDurableTaskWorker worker, TaskHubSidecarServiceClient client) { - static readonly Google.Protobuf.WellKnownTypes.Empty EmptyMessage = new(); - - readonly GrpcDurableTaskWorker worker; - readonly TaskHubSidecarServiceClient client; - readonly DurableTaskShimFactory shimFactory; - readonly GrpcDurableTaskWorkerOptions.InternalOptions internalOptions; - - public Processor(GrpcDurableTaskWorker worker, TaskHubSidecarServiceClient client) - { - this.worker = worker; - this.client = client; - this.shimFactory = new DurableTaskShimFactory(this.worker.grpcOptions, this.worker.loggerFactory); - this.internalOptions = this.worker.grpcOptions.Internal; - } + readonly DurableTaskShimFactory shimFactory = new(worker.grpcOptions, worker.loggerFactory); + readonly GrpcDurableTaskWorkerOptions.InternalOptions internalOptions = worker.grpcOptions.Internal; - ILogger Logger => this.worker.logger; + ILogger Logger => worker.logger; public async Task ExecuteAsync(CancellationToken cancellation) { + int reconnectAttempt = 0; + var connectionStartTime = DateTime.UtcNow; + while (!cancellation.IsCancellationRequested) { + // Use the worker's CreateCallOptions method to ensure consistent settings + // This ensures no deadline is set for unlimited connection time + var callOptions = worker.CreateCallOptions(cancellation); + this.Logger.ConfiguringGrpcCallOptions(); + try { - AsyncServerStreamingCall stream = await this.ConnectAsync(cancellation); - await this.ProcessWorkItemsAsync(stream, cancellation); - } - catch (RpcException) when (cancellation.IsCancellationRequested) - { - // Worker is shutting down - let the method exit gracefully - break; + if (reconnectAttempt > 0) + { + this.Logger.StartingReconnectAttempt(reconnectAttempt); + } + + connectionStartTime = DateTime.UtcNow; + var workerConcurrencyOptions = worker.workerOptions.Concurrency; + + // Establish connection once and let gRPC handle reconnections + this.Logger.OpeningTaskStream(); + using var stream = client.GetWorkItems( + new P.GetWorkItemsRequest + { + MaxConcurrentActivityWorkItems = + workerConcurrencyOptions.MaximumConcurrentActivityWorkItems, + MaxConcurrentEntityWorkItems = workerConcurrencyOptions.MaximumConcurrentEntityWorkItems, + MaxConcurrentOrchestrationWorkItems = + workerConcurrencyOptions.MaximumConcurrentOrchestrationWorkItems, + Capabilities = { P.WorkerCapability.HistoryStreaming, }, + }, + callOptions); + + this.Logger.EstablishedWorkItemConnection(); + var lastActivityCheck = DateTime.UtcNow; + int workItemsProcessed = 0; + + // Process work items as they arrive + await foreach (var workItem in stream.ResponseStream.ReadAllAsync(cancellation)) + { + workItemsProcessed++; + DateTime lastActivityTimestamp = DateTime.UtcNow; + this.Logger.ReceivedWorkItem(workItem.RequestCase.ToString(), lastActivityTimestamp); + + // Each work item is processed in its own background task + await this.ProcessWorkItemAsync(workItem, cancellation); + + // Periodically log connection stats for long-running connections + var timeSinceLastCheck = DateTime.UtcNow - lastActivityCheck; + if (timeSinceLastCheck > TimeSpan.FromMinutes(5)) + { + var now = DateTime.UtcNow; + string connectionDuration = + (now - connectionStartTime).ToString(@"hh\:mm\:ss", CultureInfo.InvariantCulture); + string timeSinceLastActivity = + (now - lastActivityTimestamp).ToString(@"hh\:mm\:ss", CultureInfo.InvariantCulture); + this.Logger.ConnectionStats(connectionDuration, timeSinceLastActivity, workItemsProcessed); + lastActivityCheck = DateTime.UtcNow; + } + } + + // Stream ended without error - this is unusual but not necessarily an error + this.Logger.StreamEndedGracefully( + (DateTime.UtcNow - connectionStartTime).ToString(@"hh\:mm\:ss", CultureInfo.InvariantCulture)); + + // Reset reconnect attempt counter on clean exit + reconnectAttempt = 0; + + // Brief pause before reconnecting + await Task.Delay(TimeSpan.FromSeconds(1), cancellation); } - catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled) + catch (OperationCanceledException ex) when (cancellation.IsCancellationRequested) { - // Sidecar is shutting down - retry - this.Logger.SidecarDisconnected(); + // Normal shutdown - exit peacefully + this.Logger.CancellationRequested(ex.Message); + throw; } catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) { - // Sidecar is down - keep retrying - this.Logger.SidecarUnavailable(); - } - catch (RpcException ex) when (ex.StatusCode == StatusCode.NotFound) - { - // We retry on a NotFound for several reasons: - // 1. It was the existing behavior through the UnexpectedError path. - // 2. A 404 can be returned for a missing task hub or authentication failure. Authentication takes - // time to propagate so we should retry instead of making the user restart the application. - // 3. In some cases, a task hub can be created separately from the scheduler. If a worker is deployed - // between the scheduler and task hub, it would need to be restarted to function. - this.Logger.TaskHubNotFound(); + // Attempt to reconnect with an exponential backoff + reconnectAttempt++; + string connectionDuration = + (DateTime.UtcNow - connectionStartTime).ToString(@"hh\:mm\:ss", CultureInfo.InvariantCulture); + + this.Logger.SidecarUnavailableWithDetails(connectionDuration, ex.Status, ex.StatusCode, ex.Message); + + // Add a backoff delay that increase with reocnnect attempts + int delaySeconds = Math.Min(30, (int)Math.Pow(2, Math.Min(reconnectAttempt, 5))); + this.Logger.ReconnectionDelay(delaySeconds, reconnectAttempt + 1); + + await Task.Delay(TimeSpan.FromSeconds(delaySeconds), cancellation); } - catch (OperationCanceledException) when (cancellation.IsCancellationRequested) + catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled) { - // Shutting down, lets exit gracefully. - break; + reconnectAttempt++; + string connectionDuration = + (DateTime.UtcNow - connectionStartTime).ToString(@"hh\:mm\:ss", CultureInfo.InvariantCulture); + + this.Logger.GrpcCallCancelled(connectionDuration, ex.Status, ex.StatusCode, ex.Message); + + // Add a brief delay before reconnecting + await Task.Delay(TimeSpan.FromSeconds(1), cancellation); } catch (Exception ex) { - // Unknown failure - retry? - this.Logger.UnexpectedError(ex, string.Empty); - } + reconnectAttempt++; + string connectionDuration = + (DateTime.UtcNow - connectionStartTime).ToString(@"hh\:mm\:ss", CultureInfo.InvariantCulture); - try - { - // CONSIDER: Exponential backoff - await Task.Delay(TimeSpan.FromSeconds(5), cancellation); + this.Logger.GrpcCallUnexpectedError(connectionDuration, ex.GetType().Name, ex.Message, ex); + + // Add a brief delay before reconnecting + await Task.Delay(TimeSpan.FromSeconds(1), cancellation); } - catch (OperationCanceledException) when (cancellation.IsCancellationRequested) + finally { - // Worker is shutting down - let the method exit gracefully - break; + if (cancellation.IsCancellationRequested) + { + this.Logger.CancellationRequested($"Cancellation handled at {nameof(this.ExecuteAsync)} in {nameof(Processor)}"); + } } } + + this.Logger.CancellationRequested($"Cancellation requested at {nameof(this.ExecuteAsync)} within {nameof(Processor)}"); } - static string GetActionsListForLogging(IReadOnlyList actions) + static string GetActionsListForLogging(IReadOnlyList actions) => actions.Count switch { - if (actions.Count == 0) - { - return string.Empty; - } - else if (actions.Count == 1) - { - return actions[0].OrchestratorActionTypeCase.ToString(); - } - else - { - // Returns something like "ScheduleTask x5, CreateTimer x1,..." - return string.Join(", ", actions - .GroupBy(a => a.OrchestratorActionTypeCase) - .Select(group => $"{group.Key} x{group.Count()}")); - } - } + 0 => string.Empty, + 1 => actions[0].OrchestratorActionTypeCase.ToString(), + _ => string.Join( + ", ", + actions.GroupBy(a => a.OrchestratorActionTypeCase) + .Select(group => $"{group.Key} x{group.Count()}")), + }; static P.TaskFailureDetails? EvaluateOrchestrationVersioning(DurableTaskWorkerOptions.VersioningOptions? versioning, string orchestrationVersion, out bool versionCheckFailed) { @@ -174,6 +224,57 @@ static string GetActionsListForLogging(IReadOnlyList actio return failureDetails; } + /// + /// Process work items with simpler error handling. + /// + /// The work item to process. + /// Cancellation token. + async Task ProcessWorkItemAsync(P.WorkItem workItem, CancellationToken cancellationToken) + { + // Handle different work item types with straightforward logic + switch (workItem.RequestCase) + { + case P.WorkItem.RequestOneofCase.OrchestratorRequest: + await this.OnRunOrchestratorAsync( + workItem.OrchestratorRequest, + workItem.CompletionToken, + cancellationToken); + break; + case P.WorkItem.RequestOneofCase.ActivityRequest: + this.RunBackgroundTask( + workItem, + () => this.OnRunActivityAsync( + workItem.ActivityRequest, + workItem.CompletionToken, + cancellationToken)); + break; + case P.WorkItem.RequestOneofCase.EntityRequest: + this.RunBackgroundTask( + workItem, + () => this.OnRunEntityBatchAsync( + workItem.EntityRequest.ToEntityBatchRequest(), + cancellationToken)); + break; + case P.WorkItem.RequestOneofCase.EntityRequestV2: + workItem.EntityRequestV2.ToEntityBatchRequest( + out EntityBatchRequest batchRequest, + out List operationInfos); + + this.RunBackgroundTask( + workItem, + () => this.OnRunEntityBatchAsync( + batchRequest, + cancellationToken, + workItem.CompletionToken, + operationInfos)); + break; + case P.WorkItem.RequestOneofCase.HealthPing: + default: + this.Logger.UnexpectedWorkItemType(workItem.RequestCase.ToString()); + break; + } + } + async ValueTask BuildRuntimeStateAsync( P.OrchestratorRequest orchestratorRequest, ProtoUtils.EntityConversionState? entityConversionState, @@ -195,7 +296,7 @@ async ValueTask BuildRuntimeStateAsync( }; using AsyncServerStreamingCall streamResponse = - this.client.StreamInstanceHistory(streamRequest, cancellationToken: cancellation); + client.StreamInstanceHistory(streamRequest, cancellationToken: cancellation); await foreach (P.HistoryChunk chunk in streamResponse.ResponseStream.ReadAllAsync(cancellation)) { @@ -227,101 +328,6 @@ async ValueTask BuildRuntimeStateAsync( return runtimeState; } - async Task> ConnectAsync(CancellationToken cancellation) - { - await this.client!.HelloAsync(EmptyMessage, cancellationToken: cancellation); - this.Logger.EstablishedWorkItemConnection(); - - DurableTaskWorkerOptions workerOptions = this.worker.workerOptions; - - // Get the stream for receiving work-items - return this.client!.GetWorkItems( - new P.GetWorkItemsRequest - { - MaxConcurrentActivityWorkItems = - workerOptions.Concurrency.MaximumConcurrentActivityWorkItems, - MaxConcurrentOrchestrationWorkItems = - workerOptions.Concurrency.MaximumConcurrentOrchestrationWorkItems, - MaxConcurrentEntityWorkItems = - workerOptions.Concurrency.MaximumConcurrentEntityWorkItems, - Capabilities = { P.WorkerCapability.HistoryStreaming }, - }, - cancellationToken: cancellation); - } - - async Task ProcessWorkItemsAsync(AsyncServerStreamingCall stream, CancellationToken cancellation) - { - // Create a new token source for timing out and a final token source that keys off of them both. - // The timeout token is used to detect when we are no longer getting any messages, including health checks. - // If this is the case, it signifies the connection has been dropped silently and we need to reconnect. - using var timeoutSource = new CancellationTokenSource(); - timeoutSource.CancelAfter(TimeSpan.FromSeconds(60)); - using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellation, timeoutSource.Token); - - await foreach (P.WorkItem workItem in stream.ResponseStream.ReadAllAsync(cancellationToken: cancellation)) - { - timeoutSource.CancelAfter(TimeSpan.FromSeconds(60)); - if (workItem.RequestCase == P.WorkItem.RequestOneofCase.OrchestratorRequest) - { - this.RunBackgroundTask( - workItem, - () => this.OnRunOrchestratorAsync( - workItem.OrchestratorRequest, - workItem.CompletionToken, - cancellation)); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.ActivityRequest) - { - this.RunBackgroundTask( - workItem, - () => this.OnRunActivityAsync( - workItem.ActivityRequest, - workItem.CompletionToken, - cancellation)); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequest) - { - this.RunBackgroundTask( - workItem, - () => this.OnRunEntityBatchAsync(workItem.EntityRequest.ToEntityBatchRequest(), cancellation)); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequestV2) - { - workItem.EntityRequestV2.ToEntityBatchRequest( - out EntityBatchRequest batchRequest, - out List operationInfos); - - this.RunBackgroundTask( - workItem, - () => this.OnRunEntityBatchAsync( - batchRequest, - cancellation, - workItem.CompletionToken, - operationInfos)); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.HealthPing) - { - // No-op - } - else - { - this.Logger.UnexpectedWorkItemType(workItem.RequestCase.ToString()); - } - } - - if (tokenSource.IsCancellationRequested || tokenSource.Token.IsCancellationRequested) - { - // The token has cancelled, this means either: - // 1. The broader 'cancellation' was triggered, return here to start a graceful shutdown. - // 2. The timeoutSource was triggered, return here to trigger a reconnect to the backend. - if (!cancellation.IsCancellationRequested) - { - // Since the cancellation came from the timeout, log a warning. - this.Logger.ConnectionTimeout(); - } - } - } - void RunBackgroundTask(P.WorkItem? workItem, Func handler) { // TODO: is Task.Run appropriate here? Should we have finer control over the tasks and their threads? @@ -352,7 +358,7 @@ async Task OnRunOrchestratorAsync( CancellationToken cancellationToken) { OrchestratorExecutionResult? result = null; - P.TaskFailureDetails? failureDetails = null; + P.TaskFailureDetails? failureDetails; TaskName name = new("(unknown)"); ProtoUtils.EntityConversionState? entityConversionState = @@ -360,7 +366,7 @@ async Task OnRunOrchestratorAsync( ? new(this.internalOptions.InsertEntityUnlocksOnCompletion) : null; - DurableTaskWorkerOptions.VersioningOptions? versioning = this.worker.workerOptions.Versioning; + DurableTaskWorkerOptions.VersioningOptions? versioning = worker.workerOptions.Versioning; bool versionFailure = false; try { @@ -383,15 +389,15 @@ async Task OnRunOrchestratorAsync( runtimeState.PastEvents.Count, runtimeState.NewEvents.Count); - await using AsyncServiceScope scope = this.worker.services.CreateAsyncScope(); - if (this.worker.Factory.TryCreateOrchestrator( + await using AsyncServiceScope scope = worker.services.CreateAsyncScope(); + if (worker.Factory.TryCreateOrchestrator( name, scope.ServiceProvider, out ITaskOrchestrator? orchestrator)) { // Both the factory invocation and the ExecuteAsync could involve user code and need to be handled // as part of try/catch. ParentOrchestrationInstance? parent = runtimeState.ParentInstance switch { - ParentInstance p => new(new(p.Name), p.OrchestrationInstance.InstanceId), + { } p => new(new(p.Name), p.OrchestrationInstance.InstanceId), _ => null, }; @@ -457,7 +463,7 @@ async Task OnRunOrchestratorAsync( else { this.Logger.AbandoningOrchestrationDueToVersioning(request.InstanceId, completionToken); - await this.client.AbandonTaskOrchestratorWorkItemAsync( + await client.AbandonTaskOrchestratorWorkItemAsync( new P.AbandonOrchestrationTaskRequest { CompletionToken = completionToken, @@ -494,7 +500,7 @@ await this.client.AbandonTaskOrchestratorWorkItemAsync( response.Actions.Count, GetActionsListForLogging(response.Actions)); - await this.client.CompleteOrchestratorTaskAsync(response, cancellationToken: cancellationToken); + await client.CompleteOrchestratorTaskAsync(response, cancellationToken: cancellationToken); } async Task OnRunActivityAsync(P.ActivityRequest request, string completionToken, CancellationToken cancellation) @@ -511,8 +517,8 @@ async Task OnRunActivityAsync(P.ActivityRequest request, string completionToken, P.TaskFailureDetails? failureDetails = null; try { - await using AsyncServiceScope scope = this.worker.services.CreateAsyncScope(); - if (this.worker.Factory.TryCreateActivity(name, scope.ServiceProvider, out ITaskActivity? activity)) + await using AsyncServiceScope scope = worker.services.CreateAsyncScope(); + if (worker.Factory.TryCreateActivity(name, scope.ServiceProvider, out ITaskActivity? activity)) { // Both the factory invocation and the RunAsync could involve user code and need to be handled as // part of try/catch. @@ -557,7 +563,7 @@ async Task OnRunActivityAsync(P.ActivityRequest request, string completionToken, CompletionToken = completionToken, }; - await this.client.CompleteActivityTaskAsync(response, cancellationToken: cancellation); + await client.CompleteActivityTaskAsync(response, cancellationToken: cancellation); } async Task OnRunEntityBatchAsync( @@ -566,7 +572,7 @@ async Task OnRunEntityBatchAsync( string? completionToken = null, List? operationInfos = null) { - var coreEntityId = DTCore.Entities.EntityId.FromString(batchRequest.InstanceId!); + var coreEntityId = EntityId.FromString(batchRequest.InstanceId!); EntityId entityId = new(coreEntityId.Name, coreEntityId.Key); TaskName name = new(entityId.Name); @@ -575,8 +581,8 @@ async Task OnRunEntityBatchAsync( try { - await using AsyncServiceScope scope = this.worker.services.CreateAsyncScope(); - IDurableTaskFactory2 factory = (IDurableTaskFactory2)this.worker.Factory; + await using AsyncServiceScope scope = worker.services.CreateAsyncScope(); + IDurableTaskFactory2 factory = (IDurableTaskFactory2)worker.Factory; if (factory.TryCreateEntity(name, scope.ServiceProvider, out ITaskEntity? entity)) { @@ -623,7 +629,7 @@ async Task OnRunEntityBatchAsync( completionToken, operationInfos?.Take(batchResult.Results?.Count ?? 0)); - await this.client.CompleteEntityTaskAsync(response, cancellationToken: cancellation); + await client.CompleteEntityTaskAsync(response, cancellationToken: cancellation); } } } diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.cs index 768ed0ab..1c37f0a5 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorker.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorker.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Diagnostics; using Dapr.DurableTask.Worker.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -10,13 +11,14 @@ namespace Dapr.DurableTask.Worker.Grpc; /// /// The gRPC Durable Task worker. /// -sealed partial class GrpcDurableTaskWorker : DurableTaskWorker +partial class GrpcDurableTaskWorker : DurableTaskWorker { readonly GrpcDurableTaskWorkerOptions grpcOptions; readonly DurableTaskWorkerOptions workerOptions; readonly IServiceProvider services; readonly ILoggerFactory loggerFactory; readonly ILogger logger; + int reconnectAttempts; /// /// Initializes a new instance of the class. @@ -43,57 +45,135 @@ public GrpcDurableTaskWorker( this.logger = loggerFactory.CreateLogger("Dapr.DurableTask"); } - /// - protected override async Task ExecuteAsync(CancellationToken stoppingToken) + /// + /// Creates call options with appropriate settings for long-running connections. + /// + /// The cancellation token. + /// CallOptions configured for long-running connections. + internal CallOptions CreateCallOptions(CancellationToken cancellationToken) { - await using AsyncDisposable disposable = this.GetCallInvoker(out CallInvoker callInvoker, out string address); - this.logger.StartingTaskHubWorker(address); - await new Processor(this, new(callInvoker)).ExecuteAsync(stoppingToken); + // Create call options with NO deadline to ensure unlimited connection time + // This aligns with our channel settings for long-running connections + var options = new CallOptions(cancellationToken: cancellationToken); + + // By not setting a Deadline property, we ensure the connection can + // stay open indefinitely, which matches our channel settings + this.logger.ConfiguringGrpcCallOptions(); + + return options; } -#if NET6_0_OR_GREATER - static GrpcChannel GetChannel(string? address) + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - if (string.IsNullOrEmpty(address)) + while (!stoppingToken.IsCancellationRequested) { - address = "http://localhost:4001"; + try + { + // Reset reconnect counter when we start a new attempt + if (this.reconnectAttempts > 0) + { + this.logger.StartingReconnectAttempt(this.reconnectAttempts); + } + + await using AsyncDisposable disposable = + this.GetCallInvoker(out CallInvoker callInvoker, out string address); + this.logger.StartingTaskHubWorker(address); + + var stopwatch = Stopwatch.StartNew(); + await new Processor(this, new(callInvoker)).ExecuteAsync(stoppingToken); + stopwatch.Stop(); + + this.logger.TaskHubWorkerExited(stopwatch.ElapsedMilliseconds); + + // If we got here without an exception, break out of the retry loop + break; + } + catch (Exception ex) when (!stoppingToken.IsCancellationRequested) + { + this.reconnectAttempts++; + + // Log exception with detailed context + this.logger.TaskHubWorkerError(this.reconnectAttempts, ex.GetType().Name, ex.Message, ex); + + // Add a brief delay before retrying to avoid tight CPU-bound loops + await Task.Delay( + TimeSpan.FromSeconds(Math.Min(30, Math.Pow(2, Math.Min(this.reconnectAttempts, 5)))), + stoppingToken); + } + catch (Exception ex) + { + this.logger.UnexpectedError(ex, nameof(GrpcDurableTaskWorker)); + throw; + } } - return GrpcChannel.ForAddress(address); + this.logger.CancellationRequested($"Cancellation handled at {nameof(this.ExecuteAsync)} in {nameof(GrpcDurableTaskWorker)}"); } -#endif -#if NETSTANDARD2_0 static GrpcChannel GetChannel(string? address) { if (string.IsNullOrEmpty(address)) { - address = "localhost:4001"; + address = "http://localhost:4001"; } - return new(address, ChannelCredentials.Insecure); + // Create and configure the gRPC channel options for long-lived connections + var channelOptions = new GrpcChannelOptions + { + // No message size limit + MaxReceiveMessageSize = null, + + // Configure keep-alive settings to maintain long-lived connections + HttpHandler = new SocketsHttpHandler + { + // Enable keep-alive + KeepAlivePingPolicy = HttpKeepAlivePingPolicy.Always, + KeepAlivePingDelay = TimeSpan.FromSeconds(30), + KeepAlivePingTimeout = TimeSpan.FromSeconds(30), + + // Pooled connections are reused and won't time out from inactivity + EnableMultipleHttp2Connections = true, + + // Set a very long connection lifetime - this allows a controlled connection refresh strategy + PooledConnectionLifetime = TimeSpan.FromDays(1), + + // Disable idle timeout entirely + PooledConnectionIdleTimeout = Timeout.InfiniteTimeSpan, + }, + + DisposeHttpClient = true, + }; + + return GrpcChannel.ForAddress(address, channelOptions); } -#endif AsyncDisposable GetCallInvoker(out CallInvoker callInvoker, out string address) { - if (this.grpcOptions.Channel is GrpcChannel c) + if (this.grpcOptions.Channel is { } c) { + this.logger.GrpcChannelTarget(c.Target); callInvoker = c.CreateCallInvoker(); address = c.Target; return default; } - if (this.grpcOptions.CallInvoker is CallInvoker invoker) + if (this.grpcOptions.CallInvoker is { } invoker) { + this.logger.SelectGrpcCallInvoker(); callInvoker = invoker; address = "(unspecified)"; return default; } + this.logger.CreatingGrpcChannelForAddress(this.grpcOptions.Address); c = GetChannel(this.grpcOptions.Address); callInvoker = c.CreateCallInvoker(); address = c.Target; - return new AsyncDisposable(() => new(c.ShutdownAsync())); + return new AsyncDisposable(() => + { + this.logger.ShuttingDownGrpcChannel(c.Target); + return new(c.ShutdownAsync()); + }); } } diff --git a/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs b/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs deleted file mode 100644 index 2349edfc..00000000 --- a/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Dapr.DurableTask.Worker.Grpc.Internal; - -/// -/// Provides access to configuring internal options for the gRPC worker. -/// -public static class InternalOptionsExtensions -{ - /// - /// Configure the worker to use the default settings for connecting to the Azure Managed Durable Task service. - /// - /// The gRPC worker options. - /// - /// This is an internal API that supports the DurableTask infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new DurableTask release. - /// - public static void ConfigureForAzureManaged(this GrpcDurableTaskWorkerOptions options) - { - options.Internal.ConvertOrchestrationEntityEvents = true; - options.Internal.InsertEntityUnlocksOnCompletion = true; - } -} diff --git a/src/Worker/Grpc/Logs.cs b/src/Worker/Grpc/Logs.cs index 047f8ede..31dce7e8 100644 --- a/src/Worker/Grpc/Logs.cs +++ b/src/Worker/Grpc/Logs.cs @@ -16,8 +16,8 @@ static partial class Logs [LoggerMessage(EventId = 2, Level = LogLevel.Information, Message = "Durable Task gRPC worker has disconnected from gRPC server.")] public static partial void SidecarDisconnected(this ILogger logger); - [LoggerMessage(EventId = 3, Level = LogLevel.Information, Message = "The gRPC server for Durable Task gRPC worker is unavailable. Will continue retrying.")] - public static partial void SidecarUnavailable(this ILogger logger); + [LoggerMessage(EventId = 3, Level = LogLevel.Information, Message = "Sidecar unavailable after {connectionDuration}: {status} {statusCode} {message}")] + public static partial void SidecarUnavailableWithDetails(this ILogger logger, string connectionDuration, Status status, StatusCode statusCode, string message); [LoggerMessage(EventId = 4, Level = LogLevel.Information, Message = "Sidecar work-item streaming connection established.")] public static partial void EstablishedWorkItemConnection(this ILogger logger); @@ -57,5 +57,53 @@ static partial class Logs [LoggerMessage(EventId = 58, Level = LogLevel.Information, Message = "Abandoning orchestration. InstanceId = '{instanceId}'. Completion token = '{completionToken}'")] public static partial void AbandoningOrchestrationDueToVersioning(this ILogger logger, string instanceId, string completionToken); + + [LoggerMessage(EventId = 59, Level = LogLevel.Debug, Message = "Cancellation requested. Message: '{message}'")] + public static partial void CancellationRequested(this ILogger logger, string message); + + [LoggerMessage(EventId = 60, Level = LogLevel.Debug, Message = "Starting reconnection attempt #{attemptCount}")] + public static partial void StartingReconnectAttempt(this ILogger logger, int attemptCount); + + [LoggerMessage(EventId = 61, Level = LogLevel.Debug, Message = "Task hub worker exited after {elapsedTimeMs} ms")] + public static partial void TaskHubWorkerExited(this ILogger logger, long elapsedTimeMs); + + [LoggerMessage(EventId = 62, Level = LogLevel.Debug, Message = "Error in task hub worker, attempt #{reconnectionAttempts}: {exceptionType}: {exceptionMessage}")] + public static partial void TaskHubWorkerError(this ILogger logger, int reconnectionAttempts, string exceptionType, string exceptionMessage, Exception ex); + + [LoggerMessage(EventId = 63, Level = LogLevel.Debug, Message = "Using provided gRPC channel with target '{target}'")] + public static partial void GrpcChannelTarget(this ILogger logger, string target); + + [LoggerMessage(EventId = 64, Level = LogLevel.Debug, Message = "Using provided CallInvoker")] + public static partial void SelectGrpcCallInvoker(this ILogger logger); + + [LoggerMessage(EventId = 65, Level = LogLevel.Debug, Message = "Creating new gRPC channel for address '{address}'")] + public static partial void CreatingGrpcChannelForAddress(this ILogger logger, string address); + + [LoggerMessage(EventId = 66, Level = LogLevel.Debug, Message = "Shutting down gRPC channel for address '{address}'")] + public static partial void ShuttingDownGrpcChannel(this ILogger logger, string address); + + [LoggerMessage(EventId = 67, Level = LogLevel.Debug, Message = "Configuring gRPC call with no deadline constraint")] + public static partial void ConfiguringGrpcCallOptions(this ILogger logger); + + [LoggerMessage(EventId = 68, Level = LogLevel.Debug, Message = "Opening stream connection to get work items")] + public static partial void OpeningTaskStream(this ILogger logger); + + [LoggerMessage(EventId = 69, Level = LogLevel.Debug, Message = "Received work item of type '{workItemType}' at '{lastActivityTimestamp}'")] + public static partial void ReceivedWorkItem(this ILogger logger, string workItemType, DateTime lastActivityTimestamp); + + [LoggerMessage(EventId = 70, Level = LogLevel.Debug, Message = "Connection stats: Duration={connectionDuration}, LastActivity={timeSinceLastActivity}, WorkItemsProcessed={workItemsProcessed}")] + public static partial void ConnectionStats(this ILogger logger, string connectionDuration, string timeSinceLastActivity, int workItemsProcessed); + + [LoggerMessage(EventId = 71, Level = LogLevel.Warning, Message = "Work item stream ended gracefully after {connectionDuration}. This is unusual but not necessarily an error.")] + public static partial void StreamEndedGracefully(this ILogger logger, string connectionDuration); + + [LoggerMessage(EventId = 72, Level = LogLevel.Warning, Message = "gRPC call cancelled after {connectionDuration}: {status} {statusCode} {message}")] + public static partial void GrpcCallCancelled(this ILogger logger, string connectionDuration, Status status, StatusCode statusCode, string message); + + [LoggerMessage(EventId = 73, Level = LogLevel.Warning, Message = "Unexpected error in gRPC worker after {connectionDuration}: {exceptionType}: {exceptionMessage}")] + public static partial void GrpcCallUnexpectedError(this ILogger logger, string connectionDuration, string exceptionType, string exceptionMessage, Exception ex); + + [LoggerMessage(EventId = 74, Level = LogLevel.Debug, Message = "Waiting {delaySeconds} seconds before reconnection attempt #{reconnectAttempt}")] + public static partial void ReconnectionDelay(this ILogger logger, int delaySeconds, int reconnectAttempt); } } diff --git a/test/Worker/Core.Tests/Shims/DurableTaskShimFactoryTests.cs b/test/Worker/Core.Tests/Shims/DurableTaskShimFactoryTests.cs new file mode 100644 index 00000000..e67116e1 --- /dev/null +++ b/test/Worker/Core.Tests/Shims/DurableTaskShimFactoryTests.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core.Entities; +using Dapr.DurableTask.Entities; +using Dapr.DurableTask.Worker.Shims; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Dapr.DurableTask.Worker.Tests.Shims; + +public class DurableTaskShimFactoryTests +{ + [Fact] + public void Constructor_WithNullParameters_UsesDefaultValues() + { + // Act + var factory = new DurableTaskShimFactory(null, null); + + // Assert - No exception means success + Assert.NotNull(factory); + } + + [Fact] + public void Default_Property_ReturnsNonNullInstance() + { + // Act + var factory = DurableTaskShimFactory.Default; + + // Assert + Assert.NotNull(factory); + } + + [Fact] + public void CreateActivity_WithValidParameters_ReturnsTaskActivity() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var factory = new DurableTaskShimFactory(options, loggerFactory); + var taskName = new TaskName("TestActivity"); + var mockActivity = new Mock(); + + // Act + var result = factory.CreateActivity(taskName, mockActivity.Object); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + } + + [Fact] + public void CreateActivity_WithDefaultName_ThrowsArgumentException() + { + // Arrange + var factory = DurableTaskShimFactory.Default; + var taskName = default(TaskName); + var mockActivity = new Mock(); + + // Act & Assert + Assert.Throws(() => factory.CreateActivity(taskName, mockActivity.Object)); + } + + [Fact] + public void CreateActivity_WithNullActivity_ThrowsArgumentNullException() + { + // Arrange + var factory = DurableTaskShimFactory.Default; + var taskName = new TaskName("TestActivity"); + ITaskActivity activity = null!; + + // Act & Assert + Assert.Throws(() => factory.CreateActivity(taskName, activity)); + } + + [Fact] + public void CreateActivity_WithGenericDelegate_ReturnsTaskActivity() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var factory = new DurableTaskShimFactory(options, loggerFactory); + var taskName = new TaskName("TestActivity"); + Func> implementation = (ctx, input) => Task.FromResult(input); + + // Act + var result = factory.CreateActivity(taskName, implementation); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + } + + [Fact] + public void CreateOrchestration_WithValidParameters_ReturnsTaskOrchestration() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var factory = new DurableTaskShimFactory(options, loggerFactory); + var taskName = new TaskName("TestOrchestration"); + var mockOrchestrator = new Mock(); + + // Act + var result = factory.CreateOrchestration(taskName, mockOrchestrator.Object); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + } + + [Fact] + public void CreateOrchestration_WithProperties_ReturnsTaskOrchestration() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var factory = new DurableTaskShimFactory(options, loggerFactory); + var taskName = new TaskName("TestOrchestration"); + var mockOrchestrator = new Mock(); + var properties = new Dictionary { { "key", "value" } }; + + // Act + var result = factory.CreateOrchestration(taskName, mockOrchestrator.Object, properties); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + } + + [Fact] + public void CreateOrchestration_WithNullProperties_ThrowsArgumentNullException() + { + // Arrange + var factory = DurableTaskShimFactory.Default; + var taskName = new TaskName("TestOrchestration"); + var mockOrchestrator = new Mock(); + IReadOnlyDictionary properties = null!; + + // Act & Assert + Assert.Throws(() => factory.CreateOrchestration(taskName, mockOrchestrator.Object, properties)); + } + + [Fact] + public void CreateOrchestration_WithGenericDelegate_ReturnsTaskOrchestration() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var factory = new DurableTaskShimFactory(options, loggerFactory); + var taskName = new TaskName("TestOrchestration"); + Func> implementation = (ctx, input) => Task.FromResult(input); + + // Act + var result = factory.CreateOrchestration(taskName, implementation); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + } + + [Fact] + public void CreateEntity_WithValidParameters_ReturnsTaskEntity() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var factory = new DurableTaskShimFactory(options, loggerFactory); + var taskName = new TaskName("TestEntity"); + var mockEntity = new Mock(); + var entityId = new EntityId("TestEntityType", "TestEntityKey"); + + // Act + var result = factory.CreateEntity(taskName, mockEntity.Object, entityId); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + } + + [Fact] + public void CreateEntity_WithNullEntity_ThrowsArgumentNullException() + { + // Arrange + var factory = DurableTaskShimFactory.Default; + var taskName = new TaskName("TestEntity"); + ITaskEntity entity = null!; + var entityId = new EntityId("TestEntityType", "TestEntityKey"); + + // Act & Assert + Assert.Throws(() => factory.CreateEntity(taskName, entity, entityId)); + } +} diff --git a/test/Worker/Core.Tests/Shims/JsonDataConverterShim.cs b/test/Worker/Core.Tests/Shims/JsonDataConverterShim.cs new file mode 100644 index 00000000..2c3b2f82 --- /dev/null +++ b/test/Worker/Core.Tests/Shims/JsonDataConverterShim.cs @@ -0,0 +1,105 @@ +namespace Dapr.DurableTask.Worker.Shims.Tests; + +public class JsonDataConverterShimTests +{ + [Fact] + public void Constructor_WithNullConverter_ThrowsArgumentNullException() + { + // Arrange & Act & Assert + Assert.Throws(() => new JsonDataConverterShim(null)); + } + + [Fact] + public void Serialize_ForwardsCallToInnerConverter() + { + // Arrange + var testObject = new { Name = "Test", Value = 123 }; + var expectedResult = "{\"Name\":\"Test\",\"Value\":123}"; + + var mockConverter = new Mock(); + mockConverter + .Setup(c => c.Serialize(testObject)) + .Returns(expectedResult); + + var shim = new JsonDataConverterShim(mockConverter.Object); + + // Act + var result = shim.Serialize(testObject); + + // Assert + Assert.Equal(expectedResult, result); + mockConverter.Verify(c => c.Serialize(testObject), Times.Once); + } + + [Fact] + public void SerializeWithFormatting_ForwardsCallToInnerConverter_IgnoresFormattingParameter() + { + // Arrange + var testObject = new { Name = "Test", Value = 123 }; + var expectedResult = "{\"Name\":\"Test\",\"Value\":123}"; + + var mockConverter = new Mock(); + mockConverter + .Setup(c => c.Serialize(testObject)) + .Returns(expectedResult); + + var shim = new JsonDataConverterShim(mockConverter.Object); + + // Act + var result = shim.Serialize(testObject, true); + + // Assert + Assert.Equal(expectedResult, result); + mockConverter.Verify(c => c.Serialize(testObject), Times.Once); + } + + [Fact] + public void Deserialize_ForwardsCallToInnerConverter() + { + // Arrange + var jsonData = "{\"Name\":\"Test\",\"Value\":123}"; + var expectedObject = new TestClass { Name = "Test", Value = 123 }; + + var mockConverter = new Mock(); + mockConverter + .Setup(c => c.Deserialize(jsonData, typeof(TestClass))) + .Returns(expectedObject); + + var shim = new JsonDataConverterShim(mockConverter.Object); + + // Act + var result = shim.Deserialize(jsonData, typeof(TestClass)); + + // Assert + Assert.Same(expectedObject, result); + mockConverter.Verify(c => c.Deserialize(jsonData, typeof(TestClass)), Times.Once); + } + + [Fact] + public void Deserialize_WithNullData_ForwardsCallToInnerConverter() + { + // Arrange + string jsonData = null; + TestClass? expectedObject = null; + + var mockConverter = new Mock(); + mockConverter + .Setup(c => c.Deserialize(jsonData, typeof(TestClass))) + .Returns(expectedObject); + + var shim = new JsonDataConverterShim(mockConverter.Object); + + // Act + var result = shim.Deserialize(jsonData, typeof(TestClass)); + + // Assert + Assert.Null(result); + mockConverter.Verify(c => c.Deserialize(jsonData, typeof(TestClass)), Times.Once); + } + + class TestClass + { + public string Name { get; set; } + public int Value { get; set; } + } +} \ No newline at end of file diff --git a/test/Worker/Core.Tests/Shims/TaskEntityShimTests.cs b/test/Worker/Core.Tests/Shims/TaskEntityShimTests.cs new file mode 100644 index 00000000..01f73aae --- /dev/null +++ b/test/Worker/Core.Tests/Shims/TaskEntityShimTests.cs @@ -0,0 +1,298 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core.Entities; +using DurableTask.Core.Entities.OperationFormat; +using Dapr.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Dapr.DurableTask.Worker.Shims; + +namespace Dapr.DurableTask.Worker.Tests.Shims; + +public class TaskEntityShimTests +{ + readonly Mock mockDataConverter; + readonly Mock mockTaskEntity; + readonly EntityId entityId; + readonly Mock mockLogger; + readonly TaskEntityShim shim; + + public TaskEntityShimTests() + { + // Setup common test dependencies + mockDataConverter = new Mock(); + mockTaskEntity = new Mock(); + entityId = new EntityId("TestEntity", "TestKey"); + mockLogger = new Mock(); + + // Create the shim with mocked dependencies + shim = new TaskEntityShim( + mockDataConverter.Object, + mockTaskEntity.Object, + entityId, + mockLogger.Object); + } + + [Fact] + public void Constructor_WithValidParameters_InitializesCorrectly() + { + // Arrange + var dataConverter = new Mock().Object; + var taskEntity = new Mock().Object; + var entityId = new EntityId("TestEntity", "TestKey"); + var logger = new Mock().Object; + + // Act - Create the shim + var shim = new TaskEntityShim(dataConverter, taskEntity, entityId, logger); + + // Assert - No exception means success + Assert.NotNull(shim); + } + + [Fact] + public void Constructor_WithNullDataConverter_ThrowsArgumentNullException() + { + // Arrange + DataConverter dataConverter = null!; + var taskEntity = new Mock().Object; + var entityId = new EntityId("TestEntity", "TestKey"); + var logger = new Mock().Object; + + // Act & Assert + Assert.Throws(() => + new TaskEntityShim(dataConverter, taskEntity, entityId, logger)); + } + + [Fact] + public void Constructor_WithNullTaskEntity_ThrowsArgumentNullException() + { + // Arrange + var dataConverter = new Mock().Object; + ITaskEntity taskEntity = null!; + var entityId = new EntityId("TestEntity", "TestKey"); + var logger = new Mock().Object; + + // Act & Assert + Assert.Throws(() => + new TaskEntityShim(dataConverter, taskEntity, entityId, logger)); + } + + [Fact] + public async Task ExecuteOperationBatchAsync_WithSuccessfulOperation_ReturnsCorrectResults() + { + // Arrange + var operations = new EntityBatchRequest + { + EntityState = "initialState", + Operations = new List + { + new OperationRequest { Operation = "TestOperation", Input = "testInput" } + } + }; + + object operationResult = "testResult"; + string serializedResult = "serializedResult"; + + mockDataConverter + .Setup(dc => dc.Serialize(operationResult)) + .Returns(serializedResult); + + mockTaskEntity + .Setup(te => te.RunAsync(It.IsAny())) + .ReturnsAsync(operationResult); + + // Act + var result = await shim.ExecuteOperationBatchAsync(operations); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Results); + Assert.Single(result.Results); + Assert.Equal(serializedResult, result.Results[0].Result); + Assert.Null(result.Results[0].FailureDetails); + + mockTaskEntity.Verify( + te => te.RunAsync(It.IsAny()), + Times.Once); + } + + [Fact] + public async Task ExecuteOperationBatchAsync_WithFailingOperation_CapturesExceptionInResult() + { + // Arrange + var operations = new EntityBatchRequest + { + EntityState = "initialState", + Operations = new List + { + new OperationRequest { Operation = "TestOperation", Input = "testInput" } + } + }; + + var expectedException = new InvalidOperationException("Test exception"); + + mockTaskEntity + .Setup(te => te.RunAsync(It.IsAny())) + .ThrowsAsync(expectedException); + + // Act + var result = await shim.ExecuteOperationBatchAsync(operations); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Results); + Assert.Single(result.Results); + Assert.Null(result.Results[0].Result); + Assert.NotNull(result.Results[0].FailureDetails); + Assert.Equal(expectedException.Message, result.Results[0].FailureDetails.ErrorMessage); + } + + [Fact] + public async Task ExecuteOperationBatchAsync_WithMultipleOperations_ProcessesAllOperations() + { + // Arrange + var operations = new EntityBatchRequest + { + EntityState = "initialState", + Operations = new List + { + new OperationRequest { Operation = "Operation1", Input = "input1" }, + new OperationRequest { Operation = "Operation2", Input = "input2" }, + new OperationRequest { Operation = "Operation3", Input = "input3" } + } + }; + + mockTaskEntity + .Setup(te => te.RunAsync(It.IsAny())) + .ReturnsAsync("result"); + + mockDataConverter + .Setup(dc => dc.Serialize(It.IsAny())) + .Returns("serializedResult"); + + // Act + var result = await shim.ExecuteOperationBatchAsync(operations); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Results); + Assert.Equal(3, result.Results.Count); + + mockTaskEntity.Verify( + te => te.RunAsync(It.IsAny()), + Times.Exactly(3)); + } + + [Fact] + public async Task ExecuteOperationBatchAsync_WithMixedResults_HandlesSuccessAndFailure() + { + // Arrange + var operations = new EntityBatchRequest + { + EntityState = "initialState", + Operations = new List + { + new OperationRequest { Operation = "SuccessOperation", Input = "input1" }, + new OperationRequest { Operation = "FailingOperation", Input = "input2" }, + new OperationRequest { Operation = "SuccessOperation", Input = "input3" } + } + }; + + var expectedException = new InvalidOperationException("Test exception"); + + mockTaskEntity + .SetupSequence(te => te.RunAsync(It.IsAny())) + .ReturnsAsync("result1") + .ThrowsAsync(expectedException) + .ReturnsAsync("result3"); + + mockDataConverter + .Setup(dc => dc.Serialize(It.IsAny())) + .Returns("serializedResult"); + + // Act + var result = await shim.ExecuteOperationBatchAsync(operations); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Results); + Assert.Equal(3, result.Results.Count); + + Assert.NotNull(result.Results[0].Result); + Assert.Null(result.Results[0].FailureDetails); + + Assert.Null(result.Results[1].Result); + Assert.NotNull(result.Results[1].FailureDetails); + Assert.Equal(expectedException.Message, result.Results[1].FailureDetails.ErrorMessage); + + Assert.NotNull(result.Results[2].Result); + Assert.Null(result.Results[2].FailureDetails); + } + + [Fact] + public async Task ExecuteOperationBatchAsync_CommitsStateOnSuccessAndRollsBackOnFailure() + { + // Arrange + var operations = new EntityBatchRequest + { + EntityState = "initialState", + Operations = new List + { + new OperationRequest { Operation = "SuccessOperation", Input = "input1" }, + new OperationRequest { Operation = "FailingOperation", Input = "input2" } + } + }; + + var expectedException = new InvalidOperationException("Test exception"); + + mockTaskEntity + .SetupSequence(te => te.RunAsync(It.IsAny())) + .ReturnsAsync("result1") + .ThrowsAsync(expectedException); + + mockDataConverter + .Setup(dc => dc.Serialize(It.IsAny())) + .Returns("serializedResult"); + + // We need to verify state commits/rollbacks occur correctly + // This requires inspecting the internal state which is challenging in tests + // Instead, we'll verify the correct operations are called and ensure result is as expected + + // Act + var result = await shim.ExecuteOperationBatchAsync(operations); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Results); + Assert.Equal(2, result.Results.Count); + + Assert.NotNull(result.Results[0].Result); + Assert.Null(result.Results[0].FailureDetails); + + Assert.Null(result.Results[1].Result); + Assert.NotNull(result.Results[1].FailureDetails); + } + + [Fact] + public async Task ExecuteOperationBatchAsync_WithEmptyOperations_ReturnsEmptyResults() + { + // Arrange + var operations = new EntityBatchRequest + { + EntityState = "initialState", + Operations = new List() + }; + + // Act + var result = await shim.ExecuteOperationBatchAsync(operations); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Results); + Assert.Empty(result.Results); + + mockTaskEntity.Verify( + te => te.RunAsync(It.IsAny()), + Times.Never); + } +} \ No newline at end of file diff --git a/test/Worker/Core.Tests/Shims/TestOrchestrationShimTests.cs b/test/Worker/Core.Tests/Shims/TestOrchestrationShimTests.cs new file mode 100644 index 00000000..accf13cc --- /dev/null +++ b/test/Worker/Core.Tests/Shims/TestOrchestrationShimTests.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using Microsoft.Extensions.Logging.Abstractions; +using Dapr.DurableTask.Worker.Shims; + +namespace Dapr.DurableTask.Worker.Tests.Shims; + +public class TaskOrchestrationShimTests +{ + [Fact] + public void Constructor_WithValidParameters_InitializesCorrectly() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var context = new OrchestrationInvocationContext( + new TaskName("TestOrchestration"), + options, + loggerFactory); + var mockOrchestrator = new Mock(); + + // Act - Create the shim + var shim = new TaskOrchestrationShim(context, mockOrchestrator.Object); + + // Assert - No exception means success + Assert.NotNull(shim); + } + + [Fact] + public void Constructor_WithProperties_InitializesCorrectly() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var context = new OrchestrationInvocationContext( + new TaskName("TestOrchestration"), + options, + loggerFactory); + var mockOrchestrator = new Mock(); + var properties = new Dictionary { { "key", "value" } }; + + // Act - Create the shim with properties + var shim = new TaskOrchestrationShim(context, mockOrchestrator.Object, properties); + + // Assert - No exception means success + Assert.NotNull(shim); + } + + [Fact] + public void Constructor_WithNullInvocationContext_ThrowsArgumentNullException() + { + // Arrange + OrchestrationInvocationContext context = null!; + var mockOrchestrator = new Mock(); + + // Act & Assert + Assert.Throws(() => new TaskOrchestrationShim(context, mockOrchestrator.Object)); + } + + [Fact] + public void Constructor_WithNullImplementation_ThrowsArgumentNullException() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var context = new OrchestrationInvocationContext( + new TaskName("TestOrchestration"), + options, + loggerFactory); + ITaskOrchestrator implementation = null!; + + // Act & Assert + Assert.Throws(() => new TaskOrchestrationShim(context, implementation)); + } + + [Fact] + public void Constructor_WithNullProperties_ThrowsArgumentNullException() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var context = new OrchestrationInvocationContext( + new TaskName("TestOrchestration"), + options, + loggerFactory); + var mockOrchestrator = new Mock(); + IReadOnlyDictionary properties = null!; + + // Act & Assert + Assert.Throws(() => + new TaskOrchestrationShim(context, mockOrchestrator.Object, properties)); + } + + [Fact] + public void GetStatus_ReturnsNull_WhenContextNotInitialized() + { + // Arrange + var options = new DurableTaskWorkerOptions(); + var loggerFactory = new NullLoggerFactory(); + var context = new OrchestrationInvocationContext( + new TaskName("TestOrchestration"), + options, + loggerFactory); + var mockOrchestrator = new Mock(); + + var shim = new TaskOrchestrationShim(context, mockOrchestrator.Object); + + // Act + string? status = shim.GetStatus(); + + // Assert + Assert.Null(status); + } +} \ No newline at end of file diff --git a/test/Worker/Grpc.Tests/GrpcChannelConfigurationTests.cs b/test/Worker/Grpc.Tests/GrpcChannelConfigurationTests.cs new file mode 100644 index 00000000..f7e8988f --- /dev/null +++ b/test/Worker/Grpc.Tests/GrpcChannelConfigurationTests.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net.Http; +using System.Reflection; +using Dapr.DurableTask.Worker.Grpc; +using Grpc.Core; +using Grpc.Net.Client; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Dapr.DurableTask.Worker.Grpc.Tests; + +public class GrpcChannelConfigurationTests +{ + [Fact] + public void GetChannel_ConfiguresSocketsHttpHandler_WithLongRunningConnectionSettings() + { + // Arrange + // Access the private static GetChannel method via reflection + var methodInfo = typeof(GrpcDurableTaskWorker).GetMethod("GetChannel", + BindingFlags.NonPublic | BindingFlags.Static); + + // Act + var channel = (GrpcChannel)methodInfo.Invoke(null, ["http://localhost:4001"]); + + // Get the HTTP handler via reflection (no public API to access it) + var handlerField = channel.GetType().GetField("_handler", BindingFlags.NonPublic | BindingFlags.Instance); + var handler = handlerField?.GetValue(channel) as HttpMessageHandler; + + // If we can't get to the actual handler through reflection, the test can't proceed + if (handler == null) + { + // This is not ideal, but the GrpcChannel class doesn't expose its handler publicly + return; + } + + // Try to get to the SocketsHttpHandler + var socketsHandler = GetSocketsHttpHandler(handler); + + // Assert + socketsHandler.Should().NotBeNull("channel should use SocketsHttpHandler"); + + if (socketsHandler is not null) + { + socketsHandler.KeepAlivePingPolicy.Should().Be(HttpKeepAlivePingPolicy.Always, + "keep-alive pings should be enabled"); + + socketsHandler.PooledConnectionIdleTimeout.Should().Be(Timeout.InfiniteTimeSpan, + "connections should never time out from inactivity"); + + socketsHandler.PooledConnectionLifetime.Should().Be(TimeSpan.FromDays(1), + "connections should have a controlled lifetime of 1 day"); + } + } + + // Helper method to get to the SocketsHttpHandler through potentially nested handlers + private static SocketsHttpHandler GetSocketsHttpHandler(HttpMessageHandler handler) + { + while (handler != null) + { + if (handler is SocketsHttpHandler socketsHandler) + { + return socketsHandler; + } + + // Try to get the inner handler if this is a delegating handler + var delegatingHandler = handler as DelegatingHandler; + handler = delegatingHandler?.InnerHandler; + } + + return null; + } +} diff --git a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs new file mode 100644 index 00000000..18784fe2 --- /dev/null +++ b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Dapr.DurableTask.Worker.Grpc.Tests; + +public class GrpcDurableTaskWorkerTests +{ + [Fact] + public async Task ExecuteAsync_CancellationRequested_ExitsCleanly() + { + // Arrange + var mockFactory = new Mock(); + var mockGrpcOptions = new Mock>(); + var mockWorkerOptions = new Mock>(); + var mockServiceProvider = new Mock(); + var mockLoggerFactory = new Mock(); + var mockLogger = new Mock(); + + mockLoggerFactory.Setup(f => f.CreateLogger(It.IsAny())).Returns(mockLogger.Object); + mockGrpcOptions.Setup(o => o.Get(It.IsAny())).Returns(new GrpcDurableTaskWorkerOptions()); + mockWorkerOptions.Setup(o => o.Get(It.IsAny())).Returns(new DurableTaskWorkerOptions()); + + var cancellationTokenSource = new CancellationTokenSource(); + + // Create a test worker with the mocked dependencies + var worker = new GrpcDurableTaskWorker( + "TestWorker", + mockFactory.Object, + mockGrpcOptions.Object, + mockWorkerOptions.Object, + mockServiceProvider.Object, + mockLoggerFactory.Object); + + // Act + // Start the worker + var workerTask = worker.StartAsync(cancellationTokenSource.Token); + + // Immediately request cancellation + await cancellationTokenSource.CancelAsync(); + + // Wait for the worker to exit (with a timeout to prevent test hangs) + await Task.WhenAny(workerTask, Task.Delay(5000, cancellationTokenSource.Token)); + + // Assert + // The worker should have exited cleanly without throwing + workerTask.IsCompleted.Should().BeTrue("worker should exit after cancellation"); + await TestExtensions.Invoking(() => workerTask).Should().NotThrowAsync(); + } + + [Fact] + public void CreateCallOptions_HasNoDeadline() + { + // Arrange + var mockFactory = new Mock(); + var mockGrpcOptions = new Mock>(); + var mockWorkerOptions = new Mock>(); + var mockServiceProvider = new Mock(); + var mockLoggerFactory = new Mock(); + var mockLogger = new Mock(); + + mockLoggerFactory.Setup(f => f.CreateLogger(It.IsAny())).Returns(mockLogger.Object); + mockGrpcOptions.Setup(o => o.Get(It.IsAny())).Returns(new GrpcDurableTaskWorkerOptions()); + mockWorkerOptions.Setup(o => o.Get(It.IsAny())).Returns(new DurableTaskWorkerOptions()); + + var worker = new GrpcDurableTaskWorker( + "TestWorker", + mockFactory.Object, + mockGrpcOptions.Object, + mockWorkerOptions.Object, + mockServiceProvider.Object, + mockLoggerFactory.Object); + + // Act + var options = worker.CreateCallOptions(CancellationToken.None); + + // Assert + // The CallOptions should have a null deadline + options.Deadline.Should().BeNull("Deadline should be null to allow unlimited connection time"); + } +} diff --git a/test/Worker/Grpc.Tests/TestExtensions.cs b/test/Worker/Grpc.Tests/TestExtensions.cs new file mode 100644 index 00000000..1a782c0d --- /dev/null +++ b/test/Worker/Grpc.Tests/TestExtensions.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Dapr.DurableTask.Worker.Grpc.Tests; + +public static class TestExtensions +{ + /// + /// Helper method for fluent assertions with async methods. + /// + public static Func Invoking(Func action) => action; +}