Skip to content

Commit f6a87b2

Browse files
authored
Always elide devices missing required PreKeys
1 parent e8a1854 commit f6a87b2

File tree

12 files changed

+302
-473
lines changed

12 files changed

+302
-473
lines changed

service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import jakarta.ws.rs.WebApplicationException;
3535
import jakarta.ws.rs.core.MediaType;
3636
import jakarta.ws.rs.core.Response;
37-
import java.io.IOException;
3837
import java.nio.ByteBuffer;
3938
import java.security.MessageDigest;
4039
import java.security.NoSuchAlgorithmException;
@@ -57,7 +56,6 @@
5756
import org.whispersystems.textsecuregcm.auth.GroupSendTokenHeader;
5857
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
5958
import org.whispersystems.textsecuregcm.entities.CheckKeysRequest;
60-
import org.whispersystems.textsecuregcm.entities.ECPreKey;
6159
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
6260
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
6361
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
@@ -92,7 +90,6 @@ public class KeysController {
9290
private final ServerSecretParams serverSecretParams;
9391
private final Clock clock;
9492

95-
private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys");
9693
private static final String STORE_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "storeKeys");
9794
private static final String STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME =
9895
MetricsUtil.name(KeysController.class, "storeKeyBundleSize");
@@ -395,51 +392,14 @@ public PreKeyResponse getDeviceKeys(
395392

396393
final List<Device> devices = parseDeviceId(deviceId, target);
397394

398-
final List<PreKeyResponseItem> responseItems = Flux.fromIterable(devices)
399-
.flatMap(device -> Mono.zip(
400-
Mono.just(device),
401-
Mono.fromFuture(keysManager.takeEC(targetIdentifier.uuid(), device.getId())),
402-
Mono.fromFuture(keysManager.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())),
403-
Mono.fromFuture(keysManager.takePQ(targetIdentifier.uuid(), device.getId()))))
404-
.flatMap(deviceAndPreKeys -> {
405-
final Device device = deviceAndPreKeys.getT1();
406-
final KEMSignedPreKey pqPreKey = deviceAndPreKeys.getT4().orElse(null);
407-
final ECPreKey unsignedEcPreKey = deviceAndPreKeys.getT2().orElse(null);
408-
final ECSignedPreKey signedEcPreKey = deviceAndPreKeys.getT3().orElse(null);
409-
final int registrationId = device.getRegistrationId(targetIdentifier.identityType());
410-
411-
Metrics.counter(GET_KEYS_COUNTER_NAME, Tags.of(
412-
UserAgentTagUtil.getPlatformTag(userAgent),
413-
Tag.of(IDENTITY_TYPE_TAG_NAME, targetIdentifier.identityType().name()),
414-
Tag.of("oneTimeEcKeyAvailable", String.valueOf(unsignedEcPreKey != null)),
415-
Tag.of("signedEcKeyAvailable", String.valueOf(signedEcPreKey != null)),
416-
Tag.of("pqKeyAvailable", String.valueOf(pqPreKey != null))))
417-
.increment();
418-
419-
if (pqPreKey == null) {
420-
// The PQ prekey should never be null. This should only happen if the account or device has been
421-
// removed.
422-
return Mono.fromCompletionStage(() -> accounts.getByServiceIdentifierAsync(targetIdentifier))
423-
.flatMap(maybeAccount -> maybeAccount
424-
.flatMap(rereadAccount -> rereadAccount.getDevice(device.getId()))
425-
.filter(rereadDevice ->
426-
registrationId == rereadDevice.getRegistrationId(targetIdentifier.identityType()))
427-
.map(rereadDevice -> {
428-
// The account and device still exist, and the device we originally read matches the current
429-
// registrationId, so the lastResort key should have existed
430-
log.error(
431-
"Target {}, Account {}, DeviceId {}, RegistrationId {} was missing a last resort prekey",
432-
targetIdentifier,
433-
target.getIdentifier(IdentityType.ACI),
434-
rereadDevice.getId(),
435-
rereadDevice.getRegistrationId(targetIdentifier.identityType()));
436-
return Mono.<PreKeyResponseItem>error(new IOException("Device missing last resort prekey"));
437-
})
438-
.orElse(Mono.empty()));
439-
}
440-
return Mono.just(new PreKeyResponseItem(
441-
device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey, pqPreKey));
442-
})
395+
final List<PreKeyResponseItem> responseItems = Flux.fromIterable(devices).flatMap(device -> Mono
396+
.fromCompletionStage(keysManager.takeDevicePreKeys(device.getId(), targetIdentifier, userAgent))
397+
.flatMap(Mono::justOrEmpty)
398+
.map(devicePreKeys -> new PreKeyResponseItem(
399+
device.getId(), device.getRegistrationId(targetIdentifier.identityType()),
400+
devicePreKeys.ecSignedPreKey(),
401+
devicePreKeys.ecPreKey().orElse(null),
402+
devicePreKeys.kemSignedPreKey())))
443403
.collectList()
444404
.block();
445405

service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class PreKeyResponseItem {
1919
private int registrationId;
2020

2121
@JsonProperty
22-
@Schema(description="the signed elliptic-curve prekey for the device, if one has been set")
22+
@Schema(description="the signed elliptic-curve prekey for the device")
2323
private ECSignedPreKey signedPreKey;
2424

2525
@JsonProperty
@@ -28,7 +28,7 @@ public class PreKeyResponseItem {
2828

2929
@JsonProperty
3030
@Schema(description="a signed post-quantum prekey for the device " +
31-
"(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)")
31+
"(a one-time prekey if any remain, otherwise the last-resort prekey)")
3232
private KEMSignedPreKey pqPreKey;
3333

3434
public PreKeyResponseItem() {}

service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public Mono<GetPreKeysResponse> getPreKeys(final GetPreKeysAnonymousRequest requ
5656
groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier);
5757

5858
yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND)
59-
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager));
59+
.flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager));
6060
} catch (final StatusException e) {
6161
yield Mono.error(e);
6262
}
@@ -66,7 +66,7 @@ yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND)
6666
lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED)
6767
.flatMap(targetAccount ->
6868
UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
69-
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)
69+
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager)
7070
: Mono.error(Status.UNAUTHENTICATED.asException()));
7171

7272
default -> Mono.error(Status.INVALID_ARGUMENT.asException());

service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
import org.signal.chat.common.EcSignedPreKey;
1212
import org.signal.chat.common.KemSignedPreKey;
1313
import org.signal.chat.keys.GetPreKeysResponse;
14-
import org.whispersystems.textsecuregcm.entities.ECPreKey;
15-
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
16-
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
17-
import org.whispersystems.textsecuregcm.identity.IdentityType;
14+
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
1815
import org.whispersystems.textsecuregcm.storage.Account;
1916
import org.whispersystems.textsecuregcm.storage.Device;
2017
import org.whispersystems.textsecuregcm.storage.KeysManager;
@@ -28,50 +25,45 @@ class KeysGrpcHelper {
2825
static final byte ALL_DEVICES = 0;
2926

3027
static Mono<GetPreKeysResponse> getPreKeys(final Account targetAccount,
31-
final IdentityType identityType,
28+
final ServiceIdentifier targetServiceIdentifier,
3229
final byte targetDeviceId,
3330
final KeysManager keysManager) {
3431

3532
final Flux<Device> devices = targetDeviceId == ALL_DEVICES
3633
? Flux.fromIterable(targetAccount.getDevices())
3734
: Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId)));
3835

36+
final String userAgent = RequestAttributesUtil.getUserAgent().orElse(null);
3937
return devices
40-
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
41-
.flatMap(device -> Flux.merge(
42-
Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())),
43-
Mono.fromFuture(() -> keysManager.getEcSignedPreKey(targetAccount.getIdentifier(identityType), device.getId())),
44-
Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId())))
38+
.flatMap(device -> Mono
39+
.fromFuture(keysManager.takeDevicePreKeys(device.getId(), targetServiceIdentifier, userAgent))
4540
.flatMap(Mono::justOrEmpty)
46-
.reduce(GetPreKeysResponse.PreKeyBundle.newBuilder(), (builder, preKey) -> {
47-
if (preKey instanceof ECPreKey ecPreKey) {
48-
builder.setEcOneTimePreKey(EcPreKey.newBuilder()
41+
.map(devicePreKeys -> {
42+
final GetPreKeysResponse.PreKeyBundle.Builder builder = GetPreKeysResponse.PreKeyBundle.newBuilder()
43+
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
44+
.setKeyId(devicePreKeys.ecSignedPreKey().keyId())
45+
.setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey()))
46+
.setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature()))
47+
.build())
48+
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
49+
.setKeyId(devicePreKeys.kemSignedPreKey().keyId())
50+
.setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey()))
51+
.setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature()))
52+
.build());
53+
devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder()
4954
.setKeyId(ecPreKey.keyId())
5055
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
51-
.build());
52-
} else if (preKey instanceof ECSignedPreKey ecSignedPreKey) {
53-
builder.setEcSignedPreKey(EcSignedPreKey.newBuilder()
54-
.setKeyId(ecSignedPreKey.keyId())
55-
.setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
56-
.setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
57-
.build());
58-
} else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) {
59-
builder.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
60-
.setKeyId(kemSignedPreKey.keyId())
61-
.setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
62-
.setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
63-
.build());
64-
} else {
65-
throw new AssertionError("Unexpected pre-key type: " + preKey.getClass());
66-
}
67-
68-
return builder;
69-
})
70-
// Cast device IDs to `int` to match data types in the response object’s protobuf definition
71-
.map(builder -> Tuples.of((int) device.getId(), builder.build())))
56+
.build()));
57+
// Cast device IDs to `int` to match data types in the response object’s protobuf definition
58+
return Tuples.of((int) device.getId(), builder.build());
59+
}))
60+
// If there were no devices with valid prekey bundles in the account, the account is gone
61+
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
7262
.collectMap(Tuple2::getT1, Tuple2::getT2)
7363
.map(preKeyBundles -> GetPreKeysResponse.newBuilder()
74-
.setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(identityType).serialize()))
64+
.setIdentityKey(ByteString
65+
.copyFrom(targetAccount.getIdentityKey(targetServiceIdentifier.identityType())
66+
.serialize()))
7567
.putAllPreKeys(preKeyBundles)
7668
.build());
7769
}

service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ public Mono<GetPreKeysResponse> getPreKeys(final GetPreKeysRequest request) {
136136
.flatMap(Mono::justOrEmpty))
137137
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
138138
.flatMap(targetAccount ->
139-
KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier.identityType(), deviceId, keysManager));
139+
KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier, deviceId, keysManager));
140140
}
141141

142142
@Override

service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,31 @@
55

66
package org.whispersystems.textsecuregcm.storage;
77

8+
import com.google.common.annotations.VisibleForTesting;
89
import io.micrometer.core.instrument.Metrics;
910
import java.util.List;
1011
import java.util.Optional;
1112
import java.util.UUID;
1213
import java.util.concurrent.CompletableFuture;
14+
import io.micrometer.core.instrument.Tag;
15+
import io.micrometer.core.instrument.Tags;
16+
import org.whispersystems.textsecuregcm.controllers.KeysController;
1317
import org.whispersystems.textsecuregcm.entities.ECPreKey;
1418
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
1519
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
1620
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
21+
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
1722
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
23+
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
24+
import org.whispersystems.textsecuregcm.util.Futures;
25+
import org.whispersystems.textsecuregcm.util.Optionals;
1826
import reactor.core.publisher.Flux;
1927
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
28+
import javax.annotation.Nullable;
2029

2130
public class KeysManager {
31+
// KeysController for backwards compatibility
32+
private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys");
2233

2334
private final SingleUseECPreKeyStore ecPreKeys;
2435
private final SingleUseKEMPreKeyStore pqPreKeys;
@@ -115,11 +126,13 @@ public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, fin
115126

116127
}
117128

118-
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) {
129+
@VisibleForTesting
130+
CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) {
119131
return ecPreKeys.take(identifier, deviceId);
120132
}
121133

122-
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) {
134+
@VisibleForTesting
135+
CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) {
123136
final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
124137
return tagTakePQ(pagedPqPreKeys.take(identifier, deviceId), PQSource.PAGE, enrolledInPagedKeys)
125138
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
@@ -209,4 +222,36 @@ public Flux<DeviceKEMPreKeyPages> listStoredKEMPreKeyPages(int lookupConcurrency
209222
public CompletableFuture<Void> pruneDeadPage(final UUID identifier, final byte deviceId, final UUID pageId) {
210223
return pagedPqPreKeys.deleteBundleFromS3(identifier, deviceId, pageId);
211224
}
225+
226+
public record DevicePreKeys(
227+
ECSignedPreKey ecSignedPreKey,
228+
Optional<ECPreKey> ecPreKey,
229+
KEMSignedPreKey kemSignedPreKey) {}
230+
231+
public CompletableFuture<Optional<DevicePreKeys>> takeDevicePreKeys(
232+
final byte deviceId,
233+
final ServiceIdentifier serviceIdentifier,
234+
final @Nullable String userAgent) {
235+
final UUID uuid = serviceIdentifier.uuid();
236+
return Futures.zipWith(
237+
this.takeEC(uuid, deviceId),
238+
this.getEcSignedPreKey(uuid, deviceId),
239+
this.takePQ(uuid, deviceId),
240+
(maybeUnsignedEcPreKey, maybeSignedEcPreKey, maybePqPreKey) -> {
241+
242+
Metrics.counter(GET_KEYS_COUNTER_NAME, Tags.of(
243+
UserAgentTagUtil.getPlatformTag(userAgent),
244+
Tag.of("identityType", serviceIdentifier.identityType().name()),
245+
Tag.of("oneTimeEcKeyAvailable", String.valueOf(maybeUnsignedEcPreKey.isPresent())),
246+
Tag.of("signedEcKeyAvailable", String.valueOf(maybeSignedEcPreKey.isPresent())),
247+
Tag.of("pqKeyAvailable", String.valueOf(maybePqPreKey.isPresent()))))
248+
.increment();
249+
250+
// The pq prekey and signed EC prekey should never be null for an existing account. This should only happen
251+
// if the account or device has been removed and the read was split, so we can return empty in those cases.
252+
return Optionals.zipWith(maybeSignedEcPreKey, maybePqPreKey, (signedEcPreKey, pqPreKey) ->
253+
new DevicePreKeys(signedEcPreKey, maybeUnsignedEcPreKey, pqPreKey));
254+
})
255+
.toCompletableFuture();
256+
}
212257
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright 2025 Signal Messenger, LLC
3+
* SPDX-License-Identifier: AGPL-3.0-only
4+
*/
5+
6+
package org.whispersystems.textsecuregcm.util;
7+
8+
import java.util.concurrent.CompletionStage;
9+
import org.apache.commons.lang3.function.TriFunction;
10+
11+
public class Futures {
12+
13+
public static <T, U, V, R> CompletionStage<R> zipWith(
14+
CompletionStage<T> futureT,
15+
CompletionStage<U> futureU,
16+
CompletionStage<V> futureV,
17+
TriFunction<T, U, V, R> fun) {
18+
19+
return futureT.thenCompose(t -> futureU.thenCombine(futureV, (u, v) -> fun.apply(t, u, v)));
20+
}
21+
}

service/src/main/proto/org/signal/chat/keys.proto

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ message GetPreKeysResponse {
195195

196196
/**
197197
* A one-time KEM pre-key (or a last-resort KEM pre-key) for the targeted
198-
* account/device/identity. May not be set if the targeted device has not
199-
* yet uploaded any KEM pre-keys.
198+
* account/device/identity.
200199
*/
201200
common.KemSignedPreKey kem_one_time_pre_key = 3;
202201
}

0 commit comments

Comments
 (0)