Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
import org.bson.Document;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

import java.io.IOException;
import java.lang.reflect.Field;
Expand Down Expand Up @@ -79,7 +79,6 @@
import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY;
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.testing.MongoAssertions.assertCause;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -242,11 +241,13 @@ void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
assertEquals(1, callback1.getInvocations());
long elapsed = msElapsedSince(start);

assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)),

assertFalse(elapsed > minTimeout(timeoutMs, serverSelectionTimeoutMS),
format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. "
+ "This indicates that the callback was not called with the expected timeout.",
min(serverSelectionTimeoutMS, timeoutMs),
elapsed));
elapsed,
minTimeout(timeoutMs, serverSelectionTimeoutMS)));

}
}

Expand All @@ -260,6 +261,10 @@ private static Stream<Arguments> testValidCallbackInputsTimeoutWhenTimeoutMsIsSe
500, // timeoutMS
1000, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
Arguments.of("timeoutMS honored for oidc callback if serverSelectionTimeoutMS is infinite",
500, // timeoutMS
-1, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold,
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0",
0, // infinite timeoutMS
500, // serverSelectionTimeoutMS
Expand All @@ -268,14 +273,17 @@ private static Stream<Arguments> testValidCallbackInputsTimeoutWhenTimeoutMsIsSe
}

// Not a prose test
@ParameterizedTest(name = "test callback timeout when server selection timeout is "
+ "infinite and timeoutMs is set to {0}")
@ValueSource(ints = {0, 100})
void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final int timeoutMs) {
@Test
@DisplayName("test callback timeout when serverSelectionTimeoutMS and timeoutMS are infinite")
void testCallbackTimeoutWhenServerSelectionTimeoutMsIsInfiniteTimeoutMsIsSet() {
TestCallback callback1 = createCallback();
Duration expectedTimeout = ChronoUnit.FOREVER.getDuration();

OidcCallback callback2 = (context) -> {
assertEquals(context.getTimeout(), ChronoUnit.FOREVER.getDuration());
assertEquals(expectedTimeout, context.getTimeout(),
format("Expected timeout to be infinite (%s), but was %s",
expectedTimeout, context.getTimeout()));

return callback1.onRequest(context);
};

Expand All @@ -284,7 +292,7 @@ void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final
builder.serverSelectionTimeout(
-1, // -1 means infinite
TimeUnit.MILLISECONDS))
.timeout(timeoutMs, TimeUnit.MILLISECONDS)
.timeout(0, TimeUnit.MILLISECONDS)
.build();

try (MongoClient mongoClient = createMongoClient(clientSettings)) {
Expand Down Expand Up @@ -1242,4 +1250,10 @@ public TestCallback createHumanCallback() {
private long msElapsedSince(final long timeOfStart) {
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart);
}

private static long minTimeout(final int timeoutMs, final int serverSelectionTimeoutMS) {
long timeoutMsEffective = timeoutMs != 0 ? timeoutMs : Long.MAX_VALUE;
long serverSelectionTimeoutMSEffective = serverSelectionTimeoutMS != -1 ? serverSelectionTimeoutMS : Long.MAX_VALUE;
return Math.min(timeoutMsEffective, serverSelectionTimeoutMSEffective);
}
}