Skip to content

Commit 201dbb3

Browse files
authored
Fix AwaitCompletion to yield results during source iteration
Merge of PR #505 that closes #502
1 parent bad8004 commit 201dbb3

File tree

1 file changed

+194
-141
lines changed

1 file changed

+194
-141
lines changed

MoreLinq/Experimental/Await.cs

Lines changed: 194 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -416,45 +416,124 @@ public static IAwaitQuery<TResult> AwaitCompletion<T, TTaskResult, TResult>(
416416

417417
return
418418
AwaitQuery.Create(
419-
options => _(options.MaxConcurrency ?? int.MaxValue,
419+
options => _(options.MaxConcurrency,
420420
options.Scheduler ?? TaskScheduler.Default,
421421
options.PreserveOrder));
422422

423-
IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered)
423+
IEnumerable<TResult> _(int? maxConcurrency, TaskScheduler scheduler, bool ordered)
424424
{
425+
// A separate task will enumerate the source and launch tasks.
426+
// It will post all progress as notices to the collection below.
427+
// A notice is essentially a discriminated union like:
428+
//
429+
// type Notice<'a, 'b> =
430+
// | End
431+
// | Result of (int * 'a * Task<'b>)
432+
// | Error of ExceptionDispatchInfo
433+
//
434+
// Note that BlockingCollection.CompleteAdding is never used to
435+
// mark the end (which is its own notice above) because
436+
// BlockingCollection.Add throws if called after CompleteAdding
437+
// and we want to deliberately tolerate the race condition.
438+
425439
var notices = new BlockingCollection<(Notice, (int, T, Task<TTaskResult>), ExceptionDispatchInfo)>();
426-
var cancellationTokenSource = new CancellationTokenSource();
427-
var cancellationToken = cancellationTokenSource.Token;
428-
var completed = false;
429440

430-
var enumerator =
431-
source.Index()
432-
.Select(e => (e.Key, Item: e.Value, Task: evaluator(e.Value, cancellationToken)))
433-
.GetEnumerator();
441+
var consumerCancellationTokenSource = new CancellationTokenSource();
442+
(Exception, Exception) lastCriticalErrors = default;
443+
444+
void PostNotice(Notice notice,
445+
(int, T, Task<TTaskResult>) item,
446+
Exception error)
447+
{
448+
// If a notice fails to post then assume critical error
449+
// conditions (like low memory), capture the error without
450+
// further allocation of resources and trip the cancellation
451+
// token source used by the main loop waiting on notices.
452+
// Note that only the "last" critical error is reported
453+
// as maintaining a list would incur allocations. The idea
454+
// here is to make a best effort attempt to report any of
455+
// the error conditions that may be occurring, which is still
456+
// better than nothing.
457+
458+
try
459+
{
460+
var edi = error != null
461+
? ExceptionDispatchInfo.Capture(error)
462+
: null;
463+
notices.Add((notice, item, edi));
464+
}
465+
catch (Exception e)
466+
{
467+
// Don't use ExceptionDispatchInfo.Capture here to avoid
468+
// inducing allocations if already under low memory
469+
// conditions.
470+
471+
lastCriticalErrors = (e, error);
472+
consumerCancellationTokenSource.Cancel();
473+
throw;
474+
}
475+
}
476+
477+
var completed = false;
478+
var cancellationTokenSource = new CancellationTokenSource();
434479

480+
var enumerator = source.Index().GetEnumerator();
435481
IDisposable disposable = enumerator; // disables AccessToDisposedClosure warnings
436482

437483
try
438484
{
485+
var cancellationToken = cancellationTokenSource.Token;
486+
487+
// Fire-up a parallel loop to iterate through the source and
488+
// launch tasks, posting a result-notice as each task
489+
// completes and another, an end-notice, when all tasks have
490+
// completed.
491+
439492
Task.Factory.StartNew(
440-
() =>
441-
CollectToAsync(
442-
enumerator,
443-
e => e.Task,
444-
notices,
445-
(e, r) => (Notice.Result, (e.Key, e.Item, e.Task), default),
446-
ex => (Notice.Error, default, ExceptionDispatchInfo.Capture(ex)),
447-
(Notice.End, default, default),
448-
maxConcurrency, cancellationTokenSource),
493+
async () =>
494+
{
495+
try
496+
{
497+
await enumerator.StartAsync(
498+
e => evaluator(e.Value, cancellationToken),
499+
(e, r) => PostNotice(Notice.Result, (e.Key, e.Value, r), default),
500+
() => PostNotice(Notice.End, default, default),
501+
maxConcurrency, cancellationToken);
502+
}
503+
catch (Exception e)
504+
{
505+
PostNotice(Notice.Error, default, e);
506+
}
507+
},
449508
CancellationToken.None,
450509
TaskCreationOptions.DenyChildAttach,
451510
scheduler);
452511

512+
// Remainder here is the main loop that waits for and
513+
// processes notices.
514+
453515
var nextKey = 0;
454516
var holds = ordered ? new List<(int, T, Task<TTaskResult>)>() : null;
455517

456-
foreach (var (kind, result, error) in notices.GetConsumingEnumerable())
518+
using (var notice = notices.GetConsumingEnumerable(consumerCancellationTokenSource.Token)
519+
.GetEnumerator())
520+
while (true)
457521
{
522+
try
523+
{
524+
if (!notice.MoveNext())
525+
break;
526+
}
527+
catch (OperationCanceledException e) when (e.CancellationToken == consumerCancellationTokenSource.Token)
528+
{
529+
var (error1, error2) = lastCriticalErrors;
530+
throw new Exception("One or more critical errors have occurred.",
531+
error2 != null ? new AggregateException(error1, error2)
532+
: new AggregateException(error1));
533+
}
534+
535+
var (kind, result, error) = notice.Current;
536+
458537
if (kind == Notice.Error)
459538
error.Throw();
460539

@@ -531,149 +610,76 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
531610
}
532611
}
533612

534-
enum Notice { Result, Error, End }
535-
536-
static async Task CollectToAsync<T, TResult, TNotice>(
537-
this IEnumerator<T> e,
538-
Func<T, Task<TResult>> taskSelector,
539-
BlockingCollection<TNotice> collection,
540-
Func<T, Task<TResult>, TNotice> completionNoticeSelector,
541-
Func<Exception, TNotice> errorNoticeSelector,
542-
TNotice endNotice,
543-
int maxConcurrency,
544-
CancellationTokenSource cancellationTokenSource)
613+
enum Notice { End, Result, Error }
614+
615+
static async Task StartAsync<T, TResult>(
616+
this IEnumerator<T> enumerator,
617+
Func<T, Task<TResult>> starter,
618+
Action<T, Task<TResult>> onTaskCompletion,
619+
Action onEnd,
620+
int? maxConcurrency,
621+
CancellationToken cancellationToken)
545622
{
546-
Reader<T> reader = null;
623+
if (enumerator == null) throw new ArgumentNullException(nameof(enumerator));
624+
if (starter == null) throw new ArgumentNullException(nameof(starter));
625+
if (onTaskCompletion == null) throw new ArgumentNullException(nameof(onTaskCompletion));
626+
if (onEnd == null) throw new ArgumentNullException(nameof(onEnd));
627+
if (maxConcurrency < 1) throw new ArgumentOutOfRangeException(nameof(maxConcurrency));
547628

548-
try
629+
using (enumerator)
549630
{
550-
reader = new Reader<T>(e);
551-
552-
var cancellationToken = cancellationTokenSource.Token;
553-
var cancellationTaskSource = new TaskCompletionSource<bool>();
554-
cancellationToken.Register(() => cancellationTaskSource.TrySetResult(true));
631+
var pendingCount = 1; // terminator
555632

556-
var tasks = new List<(T Item, Task<TResult> Task)>();
557-
558-
for (var i = 0; i < maxConcurrency; i++)
633+
void OnPendingCompleted()
559634
{
560-
if (!reader.TryRead(out var item))
561-
break;
562-
tasks.Add((item, taskSelector(item)));
635+
if (Interlocked.Decrement(ref pendingCount) == 0)
636+
onEnd();
563637
}
564638

565-
while (tasks.Count > 0)
639+
var concurrencyGate = maxConcurrency is int count
640+
? new ConcurrencyGate(count)
641+
: ConcurrencyGate.Unbounded;
642+
643+
while (enumerator.MoveNext())
566644
{
567-
// Task.WaitAny is synchronous and blocking but allows the
568-
// waiting to be cancelled via a CancellationToken.
569-
// Task.WhenAny can be awaited so it is better since the
570-
// thread won't be blocked and can return to the pool.
571-
// However, it doesn't support cancellation so instead a
572-
// task is built on top of the CancellationToken that
573-
// completes when the CancellationToken trips.
574-
//
575-
// Also, Task.WhenAny returns the task (Task) object that
576-
// completed but task objects may not be unique due to
577-
// caching, e.g.:
578-
//
579-
// async Task<bool> Foo() => true;
580-
// async Task<bool> Bar() => true;
581-
// var foo = Foo();
582-
// var bar = Bar();
583-
// var same = foo.Equals(bar); // == true
584-
//
585-
// In this case, the task returned by Task.WhenAny will
586-
// match `foo` and `bar`:
587-
//
588-
// var done = Task.WhenAny(foo, bar);
589-
//
590-
// Logically speaking, the uniqueness of a task does not
591-
// matter but here it does, especially when Await (the main
592-
// user of CollectAsync) needs to return results ordered.
593-
// Fortunately, we compose our own task on top of the
594-
// original that links each item with the task result and as
595-
// a consequence generate new and unique task objects.
596-
597-
var completedTask = await
598-
Task.WhenAny(tasks.Select(it => (Task) it.Task).Concat(cancellationTaskSource.Task))
599-
.ConfigureAwait(continueOnCapturedContext: false);
600-
601-
if (completedTask == cancellationTaskSource.Task)
645+
try
602646
{
603-
// Cancellation during the wait means the enumeration
604-
// has been stopped by the user so the results of the
605-
// remaining tasks are no longer needed. Those tasks
606-
// should cancel as a result of sharing the same
607-
// cancellation token and provided that they passed it
608-
// on to any downstream asynchronous operations. Either
609-
// way, this loop is done so exit hard here.
610-
611-
return;
647+
await concurrencyGate.EnterAsync(cancellationToken);
612648
}
613-
614-
var i = tasks.FindIndex(it => it.Task.Equals(completedTask));
615-
649+
catch (OperationCanceledException e) when (e.CancellationToken == cancellationToken)
616650
{
617-
var (item, task) = tasks[i];
618-
tasks.RemoveAt(i);
651+
return;
652+
}
619653

620-
// Await the task rather than using its result directly
621-
// to avoid having the task's exception bubble up as
622-
// AggregateException if the task failed.
654+
Interlocked.Increment(ref pendingCount);
623655

624-
collection.Add(completionNoticeSelector(item, task));
625-
}
656+
var item = enumerator.Current;
657+
var task = starter(item);
626658

627-
{
628-
if (reader.TryRead(out var item))
629-
tasks.Add((item, taskSelector(item)));
630-
}
631-
}
659+
// Add a continuation that notifies completion of the task,
660+
// along with the necessary housekeeping, in case it
661+
// completes before maximum concurrency is reached.
632662

633-
collection.Add(endNotice);
634-
}
635-
catch (Exception ex)
636-
{
637-
cancellationTokenSource.Cancel();
638-
collection.Add(errorNoticeSelector(ex));
639-
}
640-
finally
641-
{
642-
reader?.Dispose();
643-
}
663+
#pragma warning disable 4014 // https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/compiler-messages/cs4014
644664

645-
collection.CompleteAdding();
646-
}
665+
task.ContinueWith(cancellationToken: cancellationToken,
666+
continuationOptions: TaskContinuationOptions.ExecuteSynchronously,
667+
scheduler: TaskScheduler.Current,
668+
continuationAction: t =>
669+
{
670+
concurrencyGate.Exit();
647671

648-
sealed class Reader<T> : IDisposable
649-
{
650-
IEnumerator<T> _enumerator;
672+
if (cancellationToken.IsCancellationRequested)
673+
return;
651674

652-
public Reader(IEnumerator<T> enumerator) =>
653-
_enumerator = enumerator;
675+
onTaskCompletion(item, t);
676+
OnPendingCompleted();
677+
});
654678

655-
public bool TryRead(out T item)
656-
{
657-
var ended = false;
658-
if (_enumerator == null || (ended = !_enumerator.MoveNext()))
659-
{
660-
if (ended)
661-
Dispose();
662-
item = default;
663-
return false;
679+
#pragma warning restore 4014
664680
}
665681

666-
item = _enumerator.Current;
667-
return true;
668-
}
669-
670-
public void Dispose()
671-
{
672-
var e = _enumerator;
673-
if (e == null)
674-
return;
675-
_enumerator = null;
676-
e.Dispose();
682+
OnPendingCompleted();
677683
}
678684
}
679685

@@ -720,6 +726,53 @@ static class TupleComparer<T1, T2, T3>
720726
public static readonly IComparer<(T1, T2, T3)> Item3 =
721727
Comparer<(T1, T2, T3)>.Create((x, y) => Comparer<T3>.Default.Compare(x.Item3, y.Item3));
722728
}
729+
730+
static class CompletedTask
731+
{
732+
#if NET451 || NETSTANDARD1_0
733+
734+
public static readonly Task Instance;
735+
736+
static CompletedTask()
737+
{
738+
var tcs = new TaskCompletionSource<object>();
739+
tcs.SetResult(null);
740+
Instance = tcs.Task;
741+
}
742+
743+
#else
744+
745+
public static readonly Task Instance = Task.CompletedTask;
746+
747+
#endif
748+
}
749+
750+
sealed class ConcurrencyGate
751+
{
752+
public static readonly ConcurrencyGate Unbounded = new ConcurrencyGate();
753+
754+
readonly SemaphoreSlim _semaphore;
755+
756+
ConcurrencyGate(SemaphoreSlim semaphore = null) =>
757+
_semaphore = semaphore;
758+
759+
public ConcurrencyGate(int max) :
760+
this(new SemaphoreSlim(max, max)) {}
761+
762+
public Task EnterAsync(CancellationToken token)
763+
{
764+
if (_semaphore == null)
765+
{
766+
token.ThrowIfCancellationRequested();
767+
return CompletedTask.Instance;
768+
}
769+
770+
return _semaphore.WaitAsync(token);
771+
}
772+
773+
public void Exit() =>
774+
_semaphore?.Release();
775+
}
723776
}
724777
}
725778

0 commit comments

Comments
 (0)