34
34
import java .util .concurrent .BlockingQueue ;
35
35
import java .util .concurrent .TimeUnit ;
36
36
import java .util .concurrent .TimeoutException ;
37
- import java .util .concurrent .atomic .AtomicReferenceFieldUpdater ;
37
+ import java .util .concurrent .atomic .AtomicLongFieldUpdater ;
38
38
import java .util .concurrent .locks .LockSupport ;
39
39
import javax .annotation .Nullable ;
40
40
57
57
* @param <T> Type of items emitted by the {@link Publisher} from which this {@link BlockingIterable} is created.
58
58
*/
59
59
final class PublisherAsBlockingIterable <T > implements BlockingIterable <T > {
60
- private static final int MAX_OUTSTANDING_DEMAND = min ( 128 , SpscBlockingQueue . ThreadStamp . MAX_DEMAND ) ;
60
+ private static final int MAX_OUTSTANDING_DEMAND = 128 ;
61
61
final Publisher <T > original ;
62
62
private final int queueCapacityHint ;
63
63
@@ -287,10 +287,17 @@ private static final class SpscBlockingQueue<T> implements BlockingQueue<T> {
287
287
getLong ("io.servicetalk.concurrent.internal.blockingIterableYieldNs" , 1024 );
288
288
289
289
@ SuppressWarnings ("rawtypes" )
290
- private static final AtomicReferenceFieldUpdater <SpscBlockingQueue , ThreadStamp > threadStampUpdater =
291
- AtomicReferenceFieldUpdater .newUpdater (SpscBlockingQueue .class , ThreadStamp . class , "threadStamp " );
290
+ private static final AtomicLongFieldUpdater <SpscBlockingQueue > pcIndexUpdater =
291
+ AtomicLongFieldUpdater .newUpdater (SpscBlockingQueue .class , "pcIndex " );
292
292
private final Queue <T > spscQueue ;
293
- private volatile ThreadStamp threadStamp = new ThreadStamp (null );
293
+ @ Nullable
294
+ private Thread consumerThread ;
295
+ /**
296
+ * high 32 bits == producer index (see {@link #pIndex(long)})
297
+ * low 32 bits == consumer index (see {@link #cIndex(long)}}
298
+ * @see #computePCIndex(int, int)
299
+ */
300
+ private volatile long pcIndex ;
294
301
295
302
SpscBlockingQueue (Queue <T > spscQueue ) {
296
303
this .spscQueue = requireNonNull (spscQueue );
@@ -505,21 +512,15 @@ public String toString() {
505
512
}
506
513
507
514
private void producerSignalAdded () {
508
- ThreadStamp nextStamp = null ;
509
515
for (;;) {
510
- final ThreadStamp currStamp = threadStamp ;
511
- if (nextStamp == null ) {
512
- nextStamp = new ThreadStamp (null , currStamp .incProducerIndex (), currStamp .consumerIndex );
513
- } else {
514
- // nextStamp.thread = null; -> not necessary, this is always the case.
515
- nextStamp .producerIndex = currStamp .incProducerIndex ();
516
- nextStamp .consumerIndex = currStamp .consumerIndex ;
517
- }
518
- if (threadStampUpdater .compareAndSet (this , currStamp , nextStamp )) {
519
- if (currStamp .thread != null ) {
520
- LockSupport .unpark (currStamp .thread );
521
- // Producer should be equal/behind consumer in order for an unpark to be necessary.
522
- assert currStamp .producerIndex - currStamp .consumerIndex <= 0 ;
516
+ final long currIndex = pcIndex ;
517
+ final int producer = pIndex (currIndex );
518
+ final int consumer = cIndex (currIndex );
519
+ if (pcIndexUpdater .compareAndSet (this , currIndex , computePCIndex (producer + 1 , consumer ))) {
520
+ if (producer - consumer <= 0 && consumerThread != null ) {
521
+ final Thread wakeThread = consumerThread ;
522
+ consumerThread = null ;
523
+ LockSupport .unpark (wakeThread );
523
524
}
524
525
break ;
525
526
}
@@ -528,38 +529,21 @@ private void producerSignalAdded() {
528
529
529
530
private T take0 (BiLongFunction <TimeUnit , T > taker , long timeout , TimeUnit unit ) throws InterruptedException {
530
531
final Thread currentThread = Thread .currentThread ();
531
- ThreadStamp nextStamp = null ;
532
532
for (;;) {
533
- ThreadStamp currStamp = threadStamp ;
534
- if (currStamp .producerIndex == currStamp .consumerIndex ) {
535
- if (nextStamp == null ) {
536
- nextStamp = new ThreadStamp (currentThread , currStamp .producerIndex ,
537
- currStamp .incConsumerIndex ());
538
- } else {
539
- // nextStamp.thread = currentThread; -> not necessary, this is always the case.
540
- nextStamp .producerIndex = currStamp .producerIndex ;
541
- nextStamp .consumerIndex = currStamp .incConsumerIndex ();
542
- }
543
-
544
- if (threadStampUpdater .compareAndSet (this , currStamp , nextStamp )) {
533
+ long currIndex = pcIndex ;
534
+ final int producer = pIndex (currIndex );
535
+ final int consumer = cIndex (currIndex );
536
+ if (producer == consumer ) {
537
+ // Set consumerThread before pcIndex, to establish happens-before with producer thread.
538
+ consumerThread = currentThread ;
539
+ if (pcIndexUpdater .compareAndSet (this , currIndex , computePCIndex (producer , consumer + 1 ))) {
545
540
return taker .apply (timeout , unit );
546
541
}
547
542
} else {
548
543
final T item = spscQueue .poll ();
549
544
if (item != null ) {
550
- for (;;) {
551
- if (nextStamp == null ) {
552
- nextStamp = new ThreadStamp (null , currStamp .producerIndex ,
553
- currStamp .incConsumerIndex ());
554
- } else {
555
- nextStamp .thread = null ;
556
- nextStamp .producerIndex = currStamp .producerIndex ;
557
- nextStamp .consumerIndex = currStamp .incConsumerIndex ();
558
- }
559
- if (threadStampUpdater .compareAndSet (this , currStamp , nextStamp )) {
560
- break ;
561
- }
562
- currStamp = threadStamp ;
545
+ while (!pcIndexUpdater .compareAndSet (this , currIndex , computePCIndex (producer , consumer + 1 ))) {
546
+ currIndex = pcIndex ;
563
547
}
564
548
return item ;
565
549
}
@@ -569,21 +553,12 @@ private T take0(BiLongFunction<TimeUnit, T> taker, long timeout, TimeUnit unit)
569
553
}
570
554
}
571
555
572
- private void consumerSignalRemoved (int i ) {
573
- ThreadStamp nextStamp = null ;
556
+ private void consumerSignalRemoved (final int i ) {
574
557
for (;;) {
575
- final ThreadStamp currStamp = threadStamp ;
576
- assert (currStamp .producerIndex - currStamp .consumerIndex ) + i < ThreadStamp .MAX_DEMAND ;
577
- if (nextStamp == null ) {
578
- nextStamp = new ThreadStamp (null , currStamp .producerIndex ,
579
- currStamp .incConsumerIndex (i ));
580
- } else {
581
- // nextStamp.thread = null; -> not necessary, this is always the case.
582
- nextStamp .producerIndex = currStamp .producerIndex ;
583
- nextStamp .consumerIndex = currStamp .incConsumerIndex (i );
584
- }
585
-
586
- if (threadStampUpdater .compareAndSet (this , currStamp , nextStamp )) {
558
+ final long currIndex = pcIndex ;
559
+ final int producer = pIndex (currIndex );
560
+ final int consumer = cIndex (currIndex );
561
+ if (pcIndexUpdater .compareAndSet (this , currIndex , computePCIndex (producer , consumer + i ))) {
587
562
break ;
588
563
}
589
564
}
@@ -641,44 +616,16 @@ private static void checkInterrupted() throws InterruptedException {
641
616
}
642
617
}
643
618
644
- /**
645
- * The producer thread may produce multiple items before the consumer thread consume the events. If the consumer
646
- * thread changes we need to make sure the new consumer thread observes production events so this object
647
- * contains the thread to wakeup and a count of how many items are in the queue and not yet consumed.
648
- */
649
- private static final class ThreadStamp {
650
- private static final Short MAX_DEMAND = Short .MAX_VALUE ;
651
- @ Nullable
652
- Thread thread ;
653
- short producerIndex ;
654
- short consumerIndex ;
655
-
656
- ThreadStamp (@ Nullable Thread thread ) {
657
- this .thread = thread ;
658
- }
659
-
660
- ThreadStamp (@ Nullable Thread thread , short producerIndex , short consumerIndex ) {
661
- this .thread = thread ;
662
- this .producerIndex = producerIndex ;
663
- this .consumerIndex = consumerIndex ;
664
- }
665
-
666
- short incProducerIndex () {
667
- return (short ) (producerIndex + 1 );
668
- }
669
-
670
- short incConsumerIndex () {
671
- return (short ) (consumerIndex + 1 );
672
- }
619
+ private static long computePCIndex (int producer , int consumer ) {
620
+ return ((long ) producer << 32 ) | consumer ;
621
+ }
673
622
674
- short incConsumerIndex ( int amt ) {
675
- return (short ) ( consumerIndex + amt ) ;
676
- }
623
+ private static int cIndex ( long pcIndex ) {
624
+ return (int ) pcIndex ;
625
+ }
677
626
678
- @ Override
679
- public String toString () {
680
- return "thread: " + thread + " producerIndex: " + producerIndex + " consumerIndex: " + consumerIndex ;
681
- }
627
+ private static int pIndex (long pcIndex ) {
628
+ return (int ) (pcIndex >>> 32 );
682
629
}
683
630
684
631
private interface BiLongFunction <T , R > {
0 commit comments