Skip to content

Commit 30f1470

Browse files
committed
remove ThreadStamp
1 parent d46403b commit 30f1470

File tree

1 file changed

+42
-95
lines changed

1 file changed

+42
-95
lines changed

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java

Lines changed: 42 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import java.util.concurrent.BlockingQueue;
3535
import java.util.concurrent.TimeUnit;
3636
import java.util.concurrent.TimeoutException;
37-
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
37+
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
3838
import java.util.concurrent.locks.LockSupport;
3939
import javax.annotation.Nullable;
4040

@@ -57,7 +57,7 @@
5757
* @param <T> Type of items emitted by the {@link Publisher} from which this {@link BlockingIterable} is created.
5858
*/
5959
final class PublisherAsBlockingIterable<T> implements BlockingIterable<T> {
60-
private static final int MAX_OUTSTANDING_DEMAND = min(128, SpscBlockingQueue.ThreadStamp.MAX_DEMAND);
60+
private static final int MAX_OUTSTANDING_DEMAND = 128;
6161
final Publisher<T> original;
6262
private final int queueCapacityHint;
6363

@@ -287,10 +287,17 @@ private static final class SpscBlockingQueue<T> implements BlockingQueue<T> {
287287
getLong("io.servicetalk.concurrent.internal.blockingIterableYieldNs", 1024);
288288

289289
@SuppressWarnings("rawtypes")
290-
private static final AtomicReferenceFieldUpdater<SpscBlockingQueue, ThreadStamp> threadStampUpdater =
291-
AtomicReferenceFieldUpdater.newUpdater(SpscBlockingQueue.class, ThreadStamp.class, "threadStamp");
290+
private static final AtomicLongFieldUpdater<SpscBlockingQueue> pcIndexUpdater =
291+
AtomicLongFieldUpdater.newUpdater(SpscBlockingQueue.class, "pcIndex");
292292
private final Queue<T> spscQueue;
293-
private volatile ThreadStamp threadStamp = new ThreadStamp(null);
293+
@Nullable
294+
private Thread consumerThread;
295+
/**
296+
* high 32 bits == producer index (see {@link #pIndex(long)})
297+
* low 32 bits == consumer index (see {@link #cIndex(long)}}
298+
* @see #computePCIndex(int, int)
299+
*/
300+
private volatile long pcIndex;
294301

295302
SpscBlockingQueue(Queue<T> spscQueue) {
296303
this.spscQueue = requireNonNull(spscQueue);
@@ -505,21 +512,15 @@ public String toString() {
505512
}
506513

507514
private void producerSignalAdded() {
508-
ThreadStamp nextStamp = null;
509515
for (;;) {
510-
final ThreadStamp currStamp = threadStamp;
511-
if (nextStamp == null) {
512-
nextStamp = new ThreadStamp(null, currStamp.incProducerIndex(), currStamp.consumerIndex);
513-
} else {
514-
// nextStamp.thread = null; -> not necessary, this is always the case.
515-
nextStamp.producerIndex = currStamp.incProducerIndex();
516-
nextStamp.consumerIndex = currStamp.consumerIndex;
517-
}
518-
if (threadStampUpdater.compareAndSet(this, currStamp, nextStamp)) {
519-
if (currStamp.thread != null) {
520-
LockSupport.unpark(currStamp.thread);
521-
// Producer should be equal/behind consumer in order for an unpark to be necessary.
522-
assert currStamp.producerIndex - currStamp.consumerIndex <= 0;
516+
final long currIndex = pcIndex;
517+
final int producer = pIndex(currIndex);
518+
final int consumer = cIndex(currIndex);
519+
if (pcIndexUpdater.compareAndSet(this, currIndex, computePCIndex(producer + 1, consumer))) {
520+
if (producer - consumer <= 0 && consumerThread != null) {
521+
final Thread wakeThread = consumerThread;
522+
consumerThread = null;
523+
LockSupport.unpark(wakeThread);
523524
}
524525
break;
525526
}
@@ -528,38 +529,21 @@ private void producerSignalAdded() {
528529

529530
private T take0(BiLongFunction<TimeUnit, T> taker, long timeout, TimeUnit unit) throws InterruptedException {
530531
final Thread currentThread = Thread.currentThread();
531-
ThreadStamp nextStamp = null;
532532
for (;;) {
533-
ThreadStamp currStamp = threadStamp;
534-
if (currStamp.producerIndex == currStamp.consumerIndex) {
535-
if (nextStamp == null) {
536-
nextStamp = new ThreadStamp(currentThread, currStamp.producerIndex,
537-
currStamp.incConsumerIndex());
538-
} else {
539-
// nextStamp.thread = currentThread; -> not necessary, this is always the case.
540-
nextStamp.producerIndex = currStamp.producerIndex;
541-
nextStamp.consumerIndex = currStamp.incConsumerIndex();
542-
}
543-
544-
if (threadStampUpdater.compareAndSet(this, currStamp, nextStamp)) {
533+
long currIndex = pcIndex;
534+
final int producer = pIndex(currIndex);
535+
final int consumer = cIndex(currIndex);
536+
if (producer == consumer) {
537+
// Set consumerThread before pcIndex, to establish happens-before with producer thread.
538+
consumerThread = currentThread;
539+
if (pcIndexUpdater.compareAndSet(this, currIndex, computePCIndex(producer, consumer + 1))) {
545540
return taker.apply(timeout, unit);
546541
}
547542
} else {
548543
final T item = spscQueue.poll();
549544
if (item != null) {
550-
for (;;) {
551-
if (nextStamp == null) {
552-
nextStamp = new ThreadStamp(null, currStamp.producerIndex,
553-
currStamp.incConsumerIndex());
554-
} else {
555-
nextStamp.thread = null;
556-
nextStamp.producerIndex = currStamp.producerIndex;
557-
nextStamp.consumerIndex = currStamp.incConsumerIndex();
558-
}
559-
if (threadStampUpdater.compareAndSet(this, currStamp, nextStamp)) {
560-
break;
561-
}
562-
currStamp = threadStamp;
545+
while (!pcIndexUpdater.compareAndSet(this, currIndex, computePCIndex(producer, consumer + 1))) {
546+
currIndex = pcIndex;
563547
}
564548
return item;
565549
}
@@ -569,21 +553,12 @@ private T take0(BiLongFunction<TimeUnit, T> taker, long timeout, TimeUnit unit)
569553
}
570554
}
571555

572-
private void consumerSignalRemoved(int i) {
573-
ThreadStamp nextStamp = null;
556+
private void consumerSignalRemoved(final int i) {
574557
for (;;) {
575-
final ThreadStamp currStamp = threadStamp;
576-
assert (currStamp.producerIndex - currStamp.consumerIndex) + i < ThreadStamp.MAX_DEMAND;
577-
if (nextStamp == null) {
578-
nextStamp = new ThreadStamp(null, currStamp.producerIndex,
579-
currStamp.incConsumerIndex(i));
580-
} else {
581-
// nextStamp.thread = null; -> not necessary, this is always the case.
582-
nextStamp.producerIndex = currStamp.producerIndex;
583-
nextStamp.consumerIndex = currStamp.incConsumerIndex(i);
584-
}
585-
586-
if (threadStampUpdater.compareAndSet(this, currStamp, nextStamp)) {
558+
final long currIndex = pcIndex;
559+
final int producer = pIndex(currIndex);
560+
final int consumer = cIndex(currIndex);
561+
if (pcIndexUpdater.compareAndSet(this, currIndex, computePCIndex(producer, consumer + i))) {
587562
break;
588563
}
589564
}
@@ -641,44 +616,16 @@ private static void checkInterrupted() throws InterruptedException {
641616
}
642617
}
643618

644-
/**
645-
* The producer thread may produce multiple items before the consumer thread consume the events. If the consumer
646-
* thread changes we need to make sure the new consumer thread observes production events so this object
647-
* contains the thread to wakeup and a count of how many items are in the queue and not yet consumed.
648-
*/
649-
private static final class ThreadStamp {
650-
private static final Short MAX_DEMAND = Short.MAX_VALUE;
651-
@Nullable
652-
Thread thread;
653-
short producerIndex;
654-
short consumerIndex;
655-
656-
ThreadStamp(@Nullable Thread thread) {
657-
this.thread = thread;
658-
}
659-
660-
ThreadStamp(@Nullable Thread thread, short producerIndex, short consumerIndex) {
661-
this.thread = thread;
662-
this.producerIndex = producerIndex;
663-
this.consumerIndex = consumerIndex;
664-
}
665-
666-
short incProducerIndex() {
667-
return (short) (producerIndex + 1);
668-
}
669-
670-
short incConsumerIndex() {
671-
return (short) (consumerIndex + 1);
672-
}
619+
private static long computePCIndex(int producer, int consumer) {
620+
return ((long) producer << 32) | consumer;
621+
}
673622

674-
short incConsumerIndex(int amt) {
675-
return (short) (consumerIndex + amt);
676-
}
623+
private static int cIndex(long pcIndex) {
624+
return (int) pcIndex;
625+
}
677626

678-
@Override
679-
public String toString() {
680-
return "thread: " + thread + " producerIndex: " + producerIndex + " consumerIndex: " + consumerIndex;
681-
}
627+
private static int pIndex(long pcIndex) {
628+
return (int) (pcIndex >>> 32);
682629
}
683630

684631
private interface BiLongFunction<T, R> {

0 commit comments

Comments
 (0)