@@ -416,45 +416,124 @@ public static IAwaitQuery<TResult> AwaitCompletion<T, TTaskResult, TResult>(
416
416
417
417
return
418
418
AwaitQuery . Create (
419
- options => _ ( options . MaxConcurrency ?? int . MaxValue ,
419
+ options => _ ( options . MaxConcurrency ,
420
420
options . Scheduler ?? TaskScheduler . Default ,
421
421
options . PreserveOrder ) ) ;
422
422
423
- IEnumerable < TResult > _ ( int maxConcurrency , TaskScheduler scheduler , bool ordered )
423
+ IEnumerable < TResult > _ ( int ? maxConcurrency , TaskScheduler scheduler , bool ordered )
424
424
{
425
+ // A separate task will enumerate the source and launch tasks.
426
+ // It will post all progress as notices to the collection below.
427
+ // A notice is essentially a discriminated union like:
428
+ //
429
+ // type Notice<'a, 'b> =
430
+ // | End
431
+ // | Result of (int * 'a * Task<'b>)
432
+ // | Error of ExceptionDispatchInfo
433
+ //
434
+ // Note that BlockingCollection.CompleteAdding is never used to
435
+ // to mark the end (which its own notice above) because
436
+ // BlockingCollection.Add throws if called after CompleteAdding
437
+ // and we want to deliberately tolerate the race condition.
438
+
425
439
var notices = new BlockingCollection < ( Notice , ( int , T , Task < TTaskResult > ) , ExceptionDispatchInfo ) > ( ) ;
426
- var cancellationTokenSource = new CancellationTokenSource ( ) ;
427
- var cancellationToken = cancellationTokenSource . Token ;
428
- var completed = false ;
429
440
430
- var enumerator =
431
- source . Index ( )
432
- . Select ( e => ( e . Key , Item : e . Value , Task : evaluator ( e . Value , cancellationToken ) ) )
433
- . GetEnumerator ( ) ;
441
+ var consumerCancellationTokenSource = new CancellationTokenSource ( ) ;
442
+ ( Exception , Exception ) lastCriticalErrors = default ;
443
+
444
+ void PostNotice ( Notice notice ,
445
+ ( int , T , Task < TTaskResult > ) item ,
446
+ Exception error )
447
+ {
448
+ // If a notice fails to post then assume critical error
449
+ // conditions (like low memory), capture the error without
450
+ // further allocation of resources and trip the cancellation
451
+ // token source used by the main loop waiting on notices.
452
+ // Note that only the "last" critical error is reported
453
+ // as maintaining a list would incur allocations. The idea
454
+ // here is to make a best effort attempt to report any of
455
+ // the error conditions that may be occuring, which is still
456
+ // better than nothing.
457
+
458
+ try
459
+ {
460
+ var edi = error != null
461
+ ? ExceptionDispatchInfo . Capture ( error )
462
+ : null ;
463
+ notices . Add ( ( notice , item , edi ) ) ;
464
+ }
465
+ catch ( Exception e )
466
+ {
467
+ // Don't use ExceptionDispatchInfo.Capture here to avoid
468
+ // inducing allocations if already under low memory
469
+ // conditions.
470
+
471
+ lastCriticalErrors = ( e , error ) ;
472
+ consumerCancellationTokenSource . Cancel ( ) ;
473
+ throw ;
474
+ }
475
+ }
476
+
477
+ var completed = false ;
478
+ var cancellationTokenSource = new CancellationTokenSource ( ) ;
434
479
480
+ var enumerator = source . Index ( ) . GetEnumerator ( ) ;
435
481
IDisposable disposable = enumerator ; // disables AccessToDisposedClosure warnings
436
482
437
483
try
438
484
{
485
+ var cancellationToken = cancellationTokenSource . Token ;
486
+
487
+ // Fire-up a parallel loop to iterate through the source and
488
+ // launch tasks, posting a result-notice as each task
489
+ // completes and another, an end-notice, when all tasks have
490
+ // completed.
491
+
439
492
Task . Factory . StartNew (
440
- ( ) =>
441
- CollectToAsync (
442
- enumerator ,
443
- e => e . Task ,
444
- notices ,
445
- ( e , r ) => ( Notice . Result , ( e . Key , e . Item , e . Task ) , default ) ,
446
- ex => ( Notice . Error , default , ExceptionDispatchInfo . Capture ( ex ) ) ,
447
- ( Notice . End , default , default ) ,
448
- maxConcurrency , cancellationTokenSource ) ,
493
+ async ( ) =>
494
+ {
495
+ try
496
+ {
497
+ await enumerator . StartAsync (
498
+ e => evaluator ( e . Value , cancellationToken ) ,
499
+ ( e , r ) => PostNotice ( Notice . Result , ( e . Key , e . Value , r ) , default ) ,
500
+ ( ) => PostNotice ( Notice . End , default , default ) ,
501
+ maxConcurrency , cancellationToken ) ;
502
+ }
503
+ catch ( Exception e )
504
+ {
505
+ PostNotice ( Notice . Error , default , e ) ;
506
+ }
507
+ } ,
449
508
CancellationToken . None ,
450
509
TaskCreationOptions . DenyChildAttach ,
451
510
scheduler ) ;
452
511
512
+ // Remainder here is the main loop that waits for and
513
+ // processes notices.
514
+
453
515
var nextKey = 0 ;
454
516
var holds = ordered ? new List < ( int , T , Task < TTaskResult > ) > ( ) : null ;
455
517
456
- foreach ( var ( kind , result , error ) in notices . GetConsumingEnumerable ( ) )
518
+ using ( var notice = notices . GetConsumingEnumerable ( consumerCancellationTokenSource . Token )
519
+ . GetEnumerator ( ) )
520
+ while ( true )
457
521
{
522
+ try
523
+ {
524
+ if ( ! notice . MoveNext ( ) )
525
+ break ;
526
+ }
527
+ catch ( OperationCanceledException e ) when ( e . CancellationToken == consumerCancellationTokenSource . Token )
528
+ {
529
+ var ( error1 , error2 ) = lastCriticalErrors ;
530
+ throw new Exception ( "One or more critical errors have occurred." ,
531
+ error2 != null ? new AggregateException ( error1 , error2 )
532
+ : new AggregateException ( error1 ) ) ;
533
+ }
534
+
535
+ var ( kind , result , error ) = notice . Current ;
536
+
458
537
if ( kind == Notice . Error )
459
538
error . Throw ( ) ;
460
539
@@ -531,149 +610,76 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
531
610
}
532
611
}
533
612
534
- enum Notice { Result , Error , End }
535
-
536
- static async Task CollectToAsync < T , TResult , TNotice > (
537
- this IEnumerator < T > e ,
538
- Func < T , Task < TResult > > taskSelector ,
539
- BlockingCollection < TNotice > collection ,
540
- Func < T , Task < TResult > , TNotice > completionNoticeSelector ,
541
- Func < Exception , TNotice > errorNoticeSelector ,
542
- TNotice endNotice ,
543
- int maxConcurrency ,
544
- CancellationTokenSource cancellationTokenSource )
613
+ enum Notice { End , Result , Error }
614
+
615
+ static async Task StartAsync < T , TResult > (
616
+ this IEnumerator < T > enumerator ,
617
+ Func < T , Task < TResult > > starter ,
618
+ Action < T , Task < TResult > > onTaskCompletion ,
619
+ Action onEnd ,
620
+ int ? maxConcurrency ,
621
+ CancellationToken cancellationToken )
545
622
{
546
- Reader < T > reader = null ;
623
+ if ( enumerator == null ) throw new ArgumentNullException ( nameof ( enumerator ) ) ;
624
+ if ( starter == null ) throw new ArgumentNullException ( nameof ( starter ) ) ;
625
+ if ( onTaskCompletion == null ) throw new ArgumentNullException ( nameof ( onTaskCompletion ) ) ;
626
+ if ( onEnd == null ) throw new ArgumentNullException ( nameof ( onEnd ) ) ;
627
+ if ( maxConcurrency < 1 ) throw new ArgumentOutOfRangeException ( nameof ( maxConcurrency ) ) ;
547
628
548
- try
629
+ using ( enumerator )
549
630
{
550
- reader = new Reader < T > ( e ) ;
551
-
552
- var cancellationToken = cancellationTokenSource . Token ;
553
- var cancellationTaskSource = new TaskCompletionSource < bool > ( ) ;
554
- cancellationToken . Register ( ( ) => cancellationTaskSource . TrySetResult ( true ) ) ;
631
+ var pendingCount = 1 ; // terminator
555
632
556
- var tasks = new List < ( T Item , Task < TResult > Task ) > ( ) ;
557
-
558
- for ( var i = 0 ; i < maxConcurrency ; i ++ )
633
+ void OnPendingCompleted ( )
559
634
{
560
- if ( ! reader . TryRead ( out var item ) )
561
- break ;
562
- tasks . Add ( ( item , taskSelector ( item ) ) ) ;
635
+ if ( Interlocked . Decrement ( ref pendingCount ) == 0 )
636
+ onEnd ( ) ;
563
637
}
564
638
565
- while ( tasks . Count > 0 )
639
+ var concurrencyGate = maxConcurrency is int count
640
+ ? new ConcurrencyGate ( count )
641
+ : ConcurrencyGate . Unbounded ;
642
+
643
+ while ( enumerator . MoveNext ( ) )
566
644
{
567
- // Task.WaitAny is synchronous and blocking but allows the
568
- // waiting to be cancelled via a CancellationToken.
569
- // Task.WhenAny can be awaited so it is better since the
570
- // thread won't be blocked and can return to the pool.
571
- // However, it doesn't support cancellation so instead a
572
- // task is built on top of the CancellationToken that
573
- // completes when the CancellationToken trips.
574
- //
575
- // Also, Task.WhenAny returns the task (Task) object that
576
- // completed but task objects may not be unique due to
577
- // caching, e.g.:
578
- //
579
- // async Task<bool> Foo() => true;
580
- // async Task<bool> Bar() => true;
581
- // var foo = Foo();
582
- // var bar = Bar();
583
- // var same = foo.Equals(bar); // == true
584
- //
585
- // In this case, the task returned by Task.WhenAny will
586
- // match `foo` and `bar`:
587
- //
588
- // var done = Task.WhenAny(foo, bar);
589
- //
590
- // Logically speaking, the uniqueness of a task does not
591
- // matter but here it does, especially when Await (the main
592
- // user of CollectAsync) needs to return results ordered.
593
- // Fortunately, we compose our own task on top of the
594
- // original that links each item with the task result and as
595
- // a consequence generate new and unique task objects.
596
-
597
- var completedTask = await
598
- Task . WhenAny ( tasks . Select ( it => ( Task ) it . Task ) . Concat ( cancellationTaskSource . Task ) )
599
- . ConfigureAwait ( continueOnCapturedContext : false ) ;
600
-
601
- if ( completedTask == cancellationTaskSource . Task )
645
+ try
602
646
{
603
- // Cancellation during the wait means the enumeration
604
- // has been stopped by the user so the results of the
605
- // remaining tasks are no longer needed. Those tasks
606
- // should cancel as a result of sharing the same
607
- // cancellation token and provided that they passed it
608
- // on to any downstream asynchronous operations. Either
609
- // way, this loop is done so exit hard here.
610
-
611
- return ;
647
+ await concurrencyGate . EnterAsync ( cancellationToken ) ;
612
648
}
613
-
614
- var i = tasks . FindIndex ( it => it . Task . Equals ( completedTask ) ) ;
615
-
649
+ catch ( OperationCanceledException e ) when ( e . CancellationToken == cancellationToken )
616
650
{
617
- var ( item , task ) = tasks [ i ] ;
618
- tasks . RemoveAt ( i ) ;
651
+ return ;
652
+ }
619
653
620
- // Await the task rather than using its result directly
621
- // to avoid having the task's exception bubble up as
622
- // AggregateException if the task failed.
654
+ Interlocked . Increment ( ref pendingCount ) ;
623
655
624
- collection . Add ( completionNoticeSelector ( item , task ) ) ;
625
- }
656
+ var item = enumerator . Current ;
657
+ var task = starter ( item ) ;
626
658
627
- {
628
- if ( reader . TryRead ( out var item ) )
629
- tasks . Add ( ( item , taskSelector ( item ) ) ) ;
630
- }
631
- }
659
+ // Add a continutation that notifies completion of the task,
660
+ // along with the necessary housekeeping, in case it
661
+ // completes before maximum concurrency is reached.
632
662
633
- collection . Add ( endNotice ) ;
634
- }
635
- catch ( Exception ex )
636
- {
637
- cancellationTokenSource . Cancel ( ) ;
638
- collection . Add ( errorNoticeSelector ( ex ) ) ;
639
- }
640
- finally
641
- {
642
- reader ? . Dispose ( ) ;
643
- }
663
+ #pragma warning disable 4014 // https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/compiler-messages/cs4014
644
664
645
- collection . CompleteAdding ( ) ;
646
- }
665
+ task . ContinueWith ( cancellationToken : cancellationToken ,
666
+ continuationOptions : TaskContinuationOptions . ExecuteSynchronously ,
667
+ scheduler : TaskScheduler . Current ,
668
+ continuationAction : t =>
669
+ {
670
+ concurrencyGate . Exit ( ) ;
647
671
648
- sealed class Reader < T > : IDisposable
649
- {
650
- IEnumerator < T > _enumerator ;
672
+ if ( cancellationToken . IsCancellationRequested )
673
+ return ;
651
674
652
- public Reader ( IEnumerator < T > enumerator ) =>
653
- _enumerator = enumerator ;
675
+ onTaskCompletion ( item , t ) ;
676
+ OnPendingCompleted ( ) ;
677
+ } ) ;
654
678
655
- public bool TryRead ( out T item )
656
- {
657
- var ended = false ;
658
- if ( _enumerator == null || ( ended = ! _enumerator . MoveNext ( ) ) )
659
- {
660
- if ( ended )
661
- Dispose ( ) ;
662
- item = default ;
663
- return false ;
679
+ #pragma warning restore 4014
664
680
}
665
681
666
- item = _enumerator . Current ;
667
- return true ;
668
- }
669
-
670
- public void Dispose ( )
671
- {
672
- var e = _enumerator ;
673
- if ( e == null )
674
- return ;
675
- _enumerator = null ;
676
- e . Dispose ( ) ;
682
+ OnPendingCompleted ( ) ;
677
683
}
678
684
}
679
685
@@ -720,6 +726,53 @@ static class TupleComparer<T1, T2, T3>
720
726
public static readonly IComparer < ( T1 , T2 , T3 ) > Item3 =
721
727
Comparer < ( T1 , T2 , T3 ) > . Create ( ( x , y ) => Comparer < T3 > . Default . Compare ( x . Item3 , y . Item3 ) ) ;
722
728
}
729
+
730
+ static class CompletedTask
731
+ {
732
+ #if NET451 || NETSTANDARD1_0
733
+
734
+ public static readonly Task Instance ;
735
+
736
+ static CompletedTask ( )
737
+ {
738
+ var tcs = new TaskCompletionSource < object > ( ) ;
739
+ tcs . SetResult ( null ) ;
740
+ Instance = tcs . Task ;
741
+ }
742
+
743
+ #else
744
+
745
+ public static readonly Task Instance = Task . CompletedTask ;
746
+
747
+ #endif
748
+ }
749
+
750
+ sealed class ConcurrencyGate
751
+ {
752
+ public static readonly ConcurrencyGate Unbounded = new ConcurrencyGate ( ) ;
753
+
754
+ readonly SemaphoreSlim _semaphore ;
755
+
756
+ ConcurrencyGate ( SemaphoreSlim semaphore = null ) =>
757
+ _semaphore = semaphore ;
758
+
759
+ public ConcurrencyGate ( int max ) :
760
+ this ( new SemaphoreSlim ( max , max ) ) { }
761
+
762
+ public Task EnterAsync ( CancellationToken token )
763
+ {
764
+ if ( _semaphore == null )
765
+ {
766
+ token . ThrowIfCancellationRequested ( ) ;
767
+ return CompletedTask . Instance ;
768
+ }
769
+
770
+ return _semaphore . WaitAsync ( token ) ;
771
+ }
772
+
773
+ public void Exit ( ) =>
774
+ _semaphore ? . Release ( ) ;
775
+ }
723
776
}
724
777
}
725
778
0 commit comments