|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package org.opensearch.ml.engine.algorithms.agent.tracing; |
| 7 | + |
| 8 | +import static org.junit.Assert.assertEquals; |
| 9 | +import static org.junit.Assert.assertNotNull; |
| 10 | +import static org.junit.Assert.assertNotSame; |
| 11 | +import static org.junit.Assert.assertTrue; |
| 12 | +import static org.mockito.ArgumentMatchers.any; |
| 13 | +import static org.mockito.Mockito.mock; |
| 14 | +import static org.mockito.Mockito.when; |
| 15 | + |
| 16 | +import java.util.HashMap; |
| 17 | +import java.util.Map; |
| 18 | +import java.util.concurrent.CountDownLatch; |
| 19 | +import java.util.concurrent.ExecutorService; |
| 20 | +import java.util.concurrent.Executors; |
| 21 | +import java.util.concurrent.TimeUnit; |
| 22 | +import java.util.concurrent.atomic.AtomicInteger; |
| 23 | + |
| 24 | +import org.junit.Before; |
| 25 | +import org.junit.Test; |
| 26 | +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; |
| 27 | +import org.opensearch.telemetry.tracing.Span; |
| 28 | +import org.opensearch.telemetry.tracing.Tracer; |
| 29 | + |
| 30 | +/** |
| 31 | + * Tests for multi-span scenarios, span identification, and race conditions in MLAgentTracer. |
| 32 | + * These tests ensure that: |
| 33 | + * 1. Each span has unique identification |
| 34 | + * 2. Parent-child relationships are properly maintained |
| 35 | + * 3. No race conditions occur in concurrent usage |
| 36 | + * 4. Span attributes help identify which span is problematic |
| 37 | + */ |
| 38 | +public class MLAgentTracerMultiSpanTests { |
| 39 | + private MLFeatureEnabledSetting mockFeatureSetting; |
| 40 | + private Tracer mockTracer; |
| 41 | + |
| 42 | + @Before |
| 43 | + public void setup() { |
| 44 | + mockFeatureSetting = mock(MLFeatureEnabledSetting.class); |
| 45 | + mockTracer = mock(Tracer.class); |
| 46 | + MLAgentTracer.resetForTest(); |
| 47 | + } |
| 48 | + |
| 49 | + private Span createMockSpan(String traceId, String spanId, String spanName) { |
| 50 | + Span mockSpan = mock(Span.class); |
| 51 | + when(mockSpan.getTraceId()).thenReturn(traceId); |
| 52 | + when(mockSpan.getSpanId()).thenReturn(spanId); |
| 53 | + when(mockSpan.getSpanName()).thenReturn(spanName); |
| 54 | + return mockSpan; |
| 55 | + } |
| 56 | + |
| 57 | + @Test |
| 58 | + public void testMultiSpanHierarchyWithUniqueIdentification() { |
| 59 | + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); |
| 60 | + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); |
| 61 | + |
| 62 | + // Set up mock spans with unique IDs |
| 63 | + Span mockRootSpan = createMockSpan("trace-123", "span-root", MLAgentTracer.AGENT_TASK_SPAN); |
| 64 | + Span mockChildSpan = createMockSpan("trace-123", "span-child", MLAgentTracer.AGENT_LLM_CALL_SPAN); |
| 65 | + Span mockGrandchildSpan = createMockSpan("trace-123", "span-grandchild", MLAgentTracer.AGENT_TOOL_CALL_SPAN); |
| 66 | + |
| 67 | + when(mockTracer.startSpan(any())).thenReturn(mockRootSpan, mockChildSpan, mockGrandchildSpan); |
| 68 | + |
| 69 | + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); |
| 70 | + MLAgentTracer tracer = MLAgentTracer.getInstance(); |
| 71 | + |
| 72 | + // Create root span |
| 73 | + Map<String, String> rootAttributes = new HashMap<>(); |
| 74 | + rootAttributes.put("operation", "agent_task"); |
| 75 | + rootAttributes.put("request_id", "req-123"); |
| 76 | + Span rootSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, rootAttributes, null); |
| 77 | + assertNotNull("Root span should not be null", rootSpan); |
| 78 | + assertNotNull("Root span should have trace ID", rootSpan.getTraceId()); |
| 79 | + assertNotNull("Root span should have span ID", rootSpan.getSpanId()); |
| 80 | + |
| 81 | + // Create child span |
| 82 | + Map<String, String> childAttributes = new HashMap<>(); |
| 83 | + childAttributes.put("operation", "llm_call"); |
| 84 | + childAttributes.put("model", "gpt-4"); |
| 85 | + Span childSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, childAttributes, rootSpan); |
| 86 | + assertNotNull("Child span should not be null", childSpan); |
| 87 | + assertNotNull("Child span should have trace ID", childSpan.getTraceId()); |
| 88 | + assertNotNull("Child span should have span ID", childSpan.getSpanId()); |
| 89 | + |
| 90 | + // Verify parent-child relationship |
| 91 | + assertEquals("Child should have same trace ID as parent", rootSpan.getTraceId(), childSpan.getTraceId()); |
| 92 | + assertNotSame("Child should have different span ID than parent", rootSpan.getSpanId(), childSpan.getSpanId()); |
| 93 | + |
| 94 | + // Create grandchild span |
| 95 | + Map<String, String> grandchildAttributes = new HashMap<>(); |
| 96 | + grandchildAttributes.put("operation", "tool_call"); |
| 97 | + grandchildAttributes.put("tool_name", "search_index"); |
| 98 | + Span grandchildSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, grandchildAttributes, childSpan); |
| 99 | + assertNotNull("Grandchild span should not be null", grandchildSpan); |
| 100 | + assertEquals("Grandchild should have same trace ID as root", rootSpan.getTraceId(), grandchildSpan.getTraceId()); |
| 101 | + assertNotSame("Grandchild should have different span ID than child", childSpan.getSpanId(), grandchildSpan.getSpanId()); |
| 102 | + |
| 103 | + // End spans in reverse order |
| 104 | + tracer.endSpan(grandchildSpan); |
| 105 | + tracer.endSpan(childSpan); |
| 106 | + tracer.endSpan(rootSpan); |
| 107 | + } |
| 108 | + |
| 109 | + @Test |
| 110 | + public void testSpanIdentificationWithAttributes() { |
| 111 | + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); |
| 112 | + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); |
| 113 | + |
| 114 | + // Set up mock spans with unique IDs but same trace ID |
| 115 | + Span mockTaskSpan = createMockSpan("trace-456", "span-task", MLAgentTracer.AGENT_TASK_SPAN); |
| 116 | + Span mockLlmSpan = createMockSpan("trace-456", "span-llm", MLAgentTracer.AGENT_LLM_CALL_SPAN); |
| 117 | + Span mockToolSpan = createMockSpan("trace-456", "span-tool", MLAgentTracer.AGENT_TOOL_CALL_SPAN); |
| 118 | + |
| 119 | + when(mockTracer.startSpan(any())).thenReturn(mockTaskSpan, mockLlmSpan, mockToolSpan); |
| 120 | + |
| 121 | + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); |
| 122 | + MLAgentTracer tracer = MLAgentTracer.getInstance(); |
| 123 | + |
| 124 | + // Create spans with identifying attributes |
| 125 | + Map<String, String> taskAttributes = new HashMap<>(); |
| 126 | + taskAttributes.put("operation", "agent_task"); |
| 127 | + taskAttributes.put("request_id", "req-456"); |
| 128 | + taskAttributes.put("user_id", "user-789"); |
| 129 | + taskAttributes.put("session_id", "session-abc"); |
| 130 | + |
| 131 | + Span taskSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, taskAttributes, null); |
| 132 | + assertNotNull("Task span should not be null", taskSpan); |
| 133 | + |
| 134 | + // Create multiple child spans with different operations |
| 135 | + Map<String, String> llmAttributes = new HashMap<>(); |
| 136 | + llmAttributes.put("operation", "llm_call"); |
| 137 | + llmAttributes.put("model", "gpt-4"); |
| 138 | + llmAttributes.put("request_id", "req-456"); // Same request ID for correlation |
| 139 | + |
| 140 | + Span llmSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, llmAttributes, taskSpan); |
| 141 | + assertNotNull("LLM span should not be null", llmSpan); |
| 142 | + |
| 143 | + Map<String, String> toolAttributes = new HashMap<>(); |
| 144 | + toolAttributes.put("operation", "tool_call"); |
| 145 | + toolAttributes.put("tool_name", "search_index"); |
| 146 | + toolAttributes.put("request_id", "req-456"); // Same request ID for correlation |
| 147 | + |
| 148 | + Span toolSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, toolAttributes, taskSpan); |
| 149 | + assertNotNull("Tool span should not be null", toolSpan); |
| 150 | + |
| 151 | + // Verify all spans have unique IDs but same trace ID |
| 152 | + assertEquals("All spans should have same trace ID", taskSpan.getTraceId(), llmSpan.getTraceId()); |
| 153 | + assertEquals("All spans should have same trace ID", taskSpan.getTraceId(), toolSpan.getTraceId()); |
| 154 | + assertNotSame("Spans should have different span IDs", taskSpan.getSpanId(), llmSpan.getSpanId()); |
| 155 | + assertNotSame("Spans should have different span IDs", taskSpan.getSpanId(), toolSpan.getSpanId()); |
| 156 | + assertNotSame("Spans should have different span IDs", llmSpan.getSpanId(), toolSpan.getSpanId()); |
| 157 | + |
| 158 | + // End spans |
| 159 | + tracer.endSpan(toolSpan); |
| 160 | + tracer.endSpan(llmSpan); |
| 161 | + tracer.endSpan(taskSpan); |
| 162 | + } |
| 163 | + |
| 164 | + @Test |
| 165 | + public void testConcurrentSpanCreationNoRaceConditions() throws InterruptedException { |
| 166 | + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); |
| 167 | + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); |
| 168 | + |
| 169 | + // Set up mock spans for concurrent testing |
| 170 | + Span mockSpan = createMockSpan("trace-concurrent", "span-concurrent", MLAgentTracer.AGENT_TASK_SPAN); |
| 171 | + when(mockTracer.startSpan(any())).thenReturn(mockSpan); |
| 172 | + |
| 173 | + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); |
| 174 | + MLAgentTracer tracer = MLAgentTracer.getInstance(); |
| 175 | + |
| 176 | + int numThreads = 10; |
| 177 | + int spansPerThread = 5; |
| 178 | + ExecutorService executor = Executors.newFixedThreadPool(numThreads); |
| 179 | + CountDownLatch startLatch = new CountDownLatch(1); |
| 180 | + CountDownLatch endLatch = new CountDownLatch(numThreads); |
| 181 | + AtomicInteger totalSpansCreated = new AtomicInteger(0); |
| 182 | + |
| 183 | + // Create multiple threads that create spans concurrently |
| 184 | + for (int i = 0; i < numThreads; i++) { |
| 185 | + final int threadId = i; |
| 186 | + executor.submit(() -> { |
| 187 | + try { |
| 188 | + startLatch.await(); // Wait for all threads to start together |
| 189 | + |
| 190 | + for (int j = 0; j < spansPerThread; j++) { |
| 191 | + Map<String, String> attributes = new HashMap<>(); |
| 192 | + attributes.put("thread_id", String.valueOf(threadId)); |
| 193 | + attributes.put("span_index", String.valueOf(j)); |
| 194 | + |
| 195 | + Span span = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, attributes, null); |
| 196 | + assertNotNull("Span should not be null", span); |
| 197 | + assertNotNull("Span should have trace ID", span.getTraceId()); |
| 198 | + assertNotNull("Span should have span ID", span.getSpanId()); |
| 199 | + |
| 200 | + totalSpansCreated.incrementAndGet(); |
| 201 | + tracer.endSpan(span); |
| 202 | + } |
| 203 | + } catch (InterruptedException e) { |
| 204 | + Thread.currentThread().interrupt(); |
| 205 | + } finally { |
| 206 | + endLatch.countDown(); |
| 207 | + } |
| 208 | + }); |
| 209 | + } |
| 210 | + |
| 211 | + // Start all threads simultaneously |
| 212 | + startLatch.countDown(); |
| 213 | + |
| 214 | + // Wait for all threads to complete |
| 215 | + boolean completed = endLatch.await(10, TimeUnit.SECONDS); |
| 216 | + assertTrue("All threads should complete within timeout", completed); |
| 217 | + assertEquals("Should create expected number of spans", numThreads * spansPerThread, totalSpansCreated.get()); |
| 218 | + |
| 219 | + executor.shutdown(); |
| 220 | + assertTrue("Executor should shutdown cleanly", executor.awaitTermination(5, TimeUnit.SECONDS)); |
| 221 | + } |
| 222 | + |
| 223 | + @Test |
| 224 | + public void testConcurrentParentChildSpanCreation() throws InterruptedException { |
| 225 | + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); |
| 226 | + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); |
| 227 | + |
| 228 | + // Set up mock spans for parent-child testing |
| 229 | + Span mockParentSpan = createMockSpan("trace-parent", "span-parent", MLAgentTracer.AGENT_TASK_SPAN); |
| 230 | + Span mockChildSpan = createMockSpan("trace-parent", "span-child", MLAgentTracer.AGENT_LLM_CALL_SPAN); |
| 231 | + when(mockTracer.startSpan(any())).thenReturn(mockParentSpan, mockChildSpan); |
| 232 | + |
| 233 | + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); |
| 234 | + MLAgentTracer tracer = MLAgentTracer.getInstance(); |
| 235 | + |
| 236 | + int numThreads = 5; |
| 237 | + ExecutorService executor = Executors.newFixedThreadPool(numThreads); |
| 238 | + CountDownLatch startLatch = new CountDownLatch(1); |
| 239 | + CountDownLatch endLatch = new CountDownLatch(numThreads); |
| 240 | + |
| 241 | + // Create multiple threads that create parent-child span hierarchies concurrently |
| 242 | + for (int i = 0; i < numThreads; i++) { |
| 243 | + final int threadId = i; |
| 244 | + executor.submit(() -> { |
| 245 | + try { |
| 246 | + startLatch.await(); // Wait for all threads to start together |
| 247 | + |
| 248 | + // Create parent span |
| 249 | + Map<String, String> parentAttributes = new HashMap<>(); |
| 250 | + parentAttributes.put("thread_id", String.valueOf(threadId)); |
| 251 | + |
| 252 | + Span parentSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, parentAttributes, null); |
| 253 | + assertNotNull("Parent span should not be null", parentSpan); |
| 254 | + String parentTraceId = parentSpan.getTraceId(); |
| 255 | + String parentSpanId = parentSpan.getSpanId(); |
| 256 | + |
| 257 | + // Create multiple child spans |
| 258 | + for (int j = 0; j < 3; j++) { |
| 259 | + Map<String, String> childAttributes = new HashMap<>(); |
| 260 | + childAttributes.put("thread_id", String.valueOf(threadId)); |
| 261 | + childAttributes.put("child_index", String.valueOf(j)); |
| 262 | + |
| 263 | + Span childSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, childAttributes, parentSpan); |
| 264 | + assertNotNull("Child span should not be null", childSpan); |
| 265 | + |
| 266 | + // Verify parent-child relationship |
| 267 | + assertEquals("Child should have same trace ID as parent", parentTraceId, childSpan.getTraceId()); |
| 268 | + assertNotSame("Child should have different span ID than parent", parentSpanId, childSpan.getSpanId()); |
| 269 | + |
| 270 | + tracer.endSpan(childSpan); |
| 271 | + } |
| 272 | + |
| 273 | + tracer.endSpan(parentSpan); |
| 274 | + } catch (InterruptedException e) { |
| 275 | + Thread.currentThread().interrupt(); |
| 276 | + } finally { |
| 277 | + endLatch.countDown(); |
| 278 | + } |
| 279 | + }); |
| 280 | + } |
| 281 | + |
| 282 | + // Start all threads simultaneously |
| 283 | + startLatch.countDown(); |
| 284 | + |
| 285 | + // Wait for all threads to complete |
| 286 | + boolean completed = endLatch.await(10, TimeUnit.SECONDS); |
| 287 | + assertTrue("All threads should complete within timeout", completed); |
| 288 | + |
| 289 | + executor.shutdown(); |
| 290 | + assertTrue("Executor should shutdown cleanly", executor.awaitTermination(5, TimeUnit.SECONDS)); |
| 291 | + } |
| 292 | + |
| 293 | + @Test |
| 294 | + public void testSpanIdentificationWithErrorScenarios() { |
| 295 | + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); |
| 296 | + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); |
| 297 | + |
| 298 | + // Set up mock spans for error scenario testing |
| 299 | + Span mockTaskSpan = createMockSpan("trace-error", "span-task", MLAgentTracer.AGENT_TASK_SPAN); |
| 300 | + Span mockErrorSpan = createMockSpan("trace-error", "span-error", MLAgentTracer.AGENT_LLM_CALL_SPAN); |
| 301 | + Span mockSuccessSpan = createMockSpan("trace-error", "span-success", MLAgentTracer.AGENT_TOOL_CALL_SPAN); |
| 302 | + when(mockTracer.startSpan(any())).thenReturn(mockTaskSpan, mockErrorSpan, mockSuccessSpan); |
| 303 | + |
| 304 | + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); |
| 305 | + MLAgentTracer tracer = MLAgentTracer.getInstance(); |
| 306 | + |
| 307 | + // Create a span hierarchy where one span fails |
| 308 | + Map<String, String> taskAttributes = new HashMap<>(); |
| 309 | + taskAttributes.put("operation", "agent_task"); |
| 310 | + taskAttributes.put("request_id", "req-error-test"); |
| 311 | + taskAttributes.put("user_id", "user-error"); |
| 312 | + |
| 313 | + Span taskSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, taskAttributes, null); |
| 314 | + assertNotNull("Task span should not be null", taskSpan); |
| 315 | + |
| 316 | + // Create child span that will "fail" |
| 317 | + Map<String, String> errorAttributes = new HashMap<>(); |
| 318 | + errorAttributes.put("operation", "llm_call"); |
| 319 | + errorAttributes.put("model", "gpt-4"); |
| 320 | + errorAttributes.put("request_id", "req-error-test"); |
| 321 | + |
| 322 | + Span errorSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, errorAttributes, taskSpan); |
| 323 | + assertNotNull("Error span should not be null", errorSpan); |
| 324 | + |
| 325 | + // Simulate an error in the child span |
| 326 | + try { |
| 327 | + // Simulate some operation that might fail |
| 328 | + throw new RuntimeException("Simulated LLM call failure"); |
| 329 | + } catch (Exception e) { |
| 330 | + // Mark the span as having an error |
| 331 | + errorSpan.setError(e); |
| 332 | + } |
| 333 | + |
| 334 | + // Create another child span that succeeds |
| 335 | + Map<String, String> successAttributes = new HashMap<>(); |
| 336 | + successAttributes.put("operation", "tool_call"); |
| 337 | + successAttributes.put("tool_name", "search_index"); |
| 338 | + successAttributes.put("request_id", "req-error-test"); |
| 339 | + |
| 340 | + Span successSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, successAttributes, taskSpan); |
| 341 | + assertNotNull("Success span should not be null", successSpan); |
| 342 | + |
| 343 | + // Verify all spans have proper identification for debugging |
| 344 | + assertEquals("All spans should have same trace ID for correlation", taskSpan.getTraceId(), errorSpan.getTraceId()); |
| 345 | + assertEquals("All spans should have same trace ID for correlation", taskSpan.getTraceId(), successSpan.getTraceId()); |
| 346 | + assertNotSame("Each span should have unique span ID", taskSpan.getSpanId(), errorSpan.getSpanId()); |
| 347 | + assertNotSame("Each span should have unique span ID", taskSpan.getSpanId(), successSpan.getSpanId()); |
| 348 | + assertNotSame("Each span should have unique span ID", errorSpan.getSpanId(), successSpan.getSpanId()); |
| 349 | + |
| 350 | + // End spans |
| 351 | + tracer.endSpan(successSpan); |
| 352 | + tracer.endSpan(errorSpan); |
| 353 | + tracer.endSpan(taskSpan); |
| 354 | + } |
| 355 | + |
| 356 | + @Test |
| 357 | + public void testNoopTracerMultiSpanBehavior() { |
| 358 | + when(mockFeatureSetting.isTracingEnabled()).thenReturn(false); |
| 359 | + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); |
| 360 | + MLAgentTracer tracer = MLAgentTracer.getInstance(); |
| 361 | + |
| 362 | + // Test that noop tracer still provides proper span identification |
| 363 | + Map<String, String> attributes = new HashMap<>(); |
| 364 | + attributes.put("operation", "noop_test"); |
| 365 | + attributes.put("request_id", "req-noop"); |
| 366 | + |
| 367 | + Span span = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, attributes, null); |
| 368 | + assertNotNull("Noop span should not be null", span); |
| 369 | + |
| 370 | + // Noop spans should still have identification for debugging |
| 371 | + assertNotNull("Noop span should have trace ID", span.getTraceId()); |
| 372 | + assertNotNull("Noop span should have span ID", span.getSpanId()); |
| 373 | + |
| 374 | + tracer.endSpan(span); |
| 375 | + } |
| 376 | +} |
0 commit comments