1818import static org .assertj .core .api .Assertions .assertThat ;
1919
2020import java .util .List ;
21+ import java .util .concurrent .CompletableFuture ;
22+ import java .util .concurrent .ExecutionException ;
23+ import java .util .concurrent .ExecutorService ;
24+ import java .util .concurrent .Executors ;
2125import org .junit .jupiter .api .Test ;
2226import software .amazon .awssdk .auth .credentials .AnonymousCredentialsProvider ;
2327import software .amazon .awssdk .awscore .interceptor .TraceIdExecutionInterceptor ;
@@ -243,7 +247,7 @@ public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttrib
243247 }
244248
245249 @ Test
246- public void traceIdInterceptorWithNewThreadInheritsTraceId () throws Exception {
250+ public void traceIdInterceptorWithNewThreadInheritsTraceId () {
247251 EnvironmentVariableHelper .run (env -> {
248252 env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
249253
@@ -275,4 +279,67 @@ public void traceIdInterceptorWithNewThreadInheritsTraceId() throws Exception {
275279 }
276280 });
277281 }
282+
283+ @ Test
284+ public void traceIdInterceptorWithExecutiveServicePreservesTraceId () {
285+ EnvironmentVariableHelper .run (env -> {
286+ env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
287+
288+ SdkInternalThreadLocal .put ("AWS_LAMBDA_X_TRACE_ID" , "SdkInternalThreadLocal-trace-123" );
289+ ExecutorService executor = Executors .newFixedThreadPool (2 );
290+ try (MockSyncHttpClient mockHttpClient = new MockSyncHttpClient ();
291+ ProtocolRestJsonClient client = ProtocolRestJsonClient .builder ()
292+ .region (Region .US_WEST_2 )
293+ .credentialsProvider (AnonymousCredentialsProvider .create ())
294+ .httpClient (mockHttpClient )
295+ .build ()) {
296+
297+ mockHttpClient .stubNextResponse (HttpExecuteResponse .builder ()
298+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
299+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
300+ .build ());
301+
302+ executor .submit (() -> client .allTypes ()).get ();
303+
304+ List <SdkHttpRequest > requests = mockHttpClient .getRequests ();
305+ assertThat (requests .get (0 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
306+
307+ } catch (InterruptedException | ExecutionException e ) {
308+ throw new RuntimeException (e );
309+ } finally {
310+ SdkInternalThreadLocal .clear ();
311+ }
312+ });
313+ }
314+
315+ @ Test
316+ public void traceIdInterceptorWithRunAsyncDoesNotPreservesTraceId () throws Exception {
317+ EnvironmentVariableHelper .run (env -> {
318+ env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
319+
320+ SdkInternalThreadLocal .put ("AWS_LAMBDA_X_TRACE_ID" , "SdkInternalThreadLocal-trace-123" );
321+ try (MockSyncHttpClient mockHttpClient = new MockSyncHttpClient ();
322+ ProtocolRestJsonClient client = ProtocolRestJsonClient .builder ()
323+ .region (Region .US_WEST_2 )
324+ .credentialsProvider (AnonymousCredentialsProvider .create ())
325+ .httpClient (mockHttpClient )
326+ .build ()) {
327+
328+ mockHttpClient .stubNextResponse (HttpExecuteResponse .builder ()
329+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
330+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
331+ .build ());
332+
333+ CompletableFuture .runAsync (client ::allTypes ).get ();
334+
335+ List <SdkHttpRequest > requests = mockHttpClient .getRequests ();
336+ assertThat (requests .get (0 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).isEmpty ();
337+
338+ } catch (InterruptedException | ExecutionException e ) {
339+ throw new RuntimeException (e );
340+ } finally {
341+ SdkInternalThreadLocal .clear ();
342+ }
343+ });
344+ }
278345}
0 commit comments