Skip to content
19 changes: 10 additions & 9 deletions jetcd-core/src/main/java/io/etcd/jetcd/impl/WatchImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import io.etcd.jetcd.ByteSequence;
import io.etcd.jetcd.Watch;
import io.etcd.jetcd.api.VertxWatchGrpc;
import io.etcd.jetcd.api.WatchCancelRequest;
import io.etcd.jetcd.api.WatchCreateRequest;
import io.etcd.jetcd.api.WatchProgressRequest;
import io.etcd.jetcd.api.WatchRequest;
Expand All @@ -41,7 +40,9 @@
import io.etcd.jetcd.support.Errors;
import io.etcd.jetcd.support.Util;
import io.grpc.Status;
import io.grpc.stub.ClientCallStreamObserver;
import io.vertx.core.streams.WriteStream;
import io.vertx.grpc.stub.GrpcWriteStream;

import com.google.common.base.Strings;
import com.google.common.util.concurrent.FutureCallback;
Expand Down Expand Up @@ -201,14 +202,14 @@ public void close() {
synchronized (WatchImpl.this.lock) {
if (closed.compareAndSet(false, true)) {
if (wstream.get() != null) {
if (id != -1) {
final WatchCancelRequest watchCancelRequest = WatchCancelRequest.newBuilder().setWatchId(this.id)
.build();
final WatchRequest request = WatchRequest.newBuilder().setCancelRequest(watchCancelRequest).build();

wstream.get().end(request);
} else {
wstream.get().end();
WriteStream<WatchRequest> ws = wstream.get();
if (ws instanceof GrpcWriteStream<?>) {
GrpcWriteStream<?> gws = (GrpcWriteStream<?>) ws;
var observer = gws.streamObserver();
if (observer instanceof ClientCallStreamObserver<?>) {
ClientCallStreamObserver<?> callObs = (ClientCallStreamObserver<?>) observer;
callObs.cancel("Watcher cancelled", null);
}
}
}

Expand Down
27 changes: 27 additions & 0 deletions jetcd-core/src/test/java/io/etcd/jetcd/impl/WatchUnitTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -430,4 +430,31 @@ public void testWatcherWithRequireLeaderErrsOutOnNoLeader() throws InterruptedEx
assertThat(watcher.isClosed()).isTrue();
}
}

@Test
public void testWatcherCloseSendsRstStream() throws InterruptedException {
CountDownLatch closeLatch = new CountDownLatch(1);
Watch.Listener listener = Watch.listener(
TestUtil::noOpWatchResponseConsumer,
() -> closeLatch.countDown());

Watch.Watcher watcher = watchClient.watch(KEY, listener);
try {
WatchResponse createdResponse = createWatchResponse(0);
responseObserverRef.get().onNext(createdResponse);

watcher.close();

boolean closedCompleted = closeLatch.await(1, TimeUnit.SECONDS);
assertThat(closedCompleted).isTrue();

assertThat(watcher.isClosed()).isTrue();

verify(requestStreamObserverMock, timeout(100).atLeast(1)).onNext(argThat(hasCreateKey(KEY)));
} finally {
if (!watcher.isClosed()) {
watcher.close();
}
}
}
}