Skip to content

Commit 8598ede

Browse files
committed
MLAgentTracer Class (opensearch-project#3946)
* adding agent tracing to mlplugin Signed-off-by: chrislai <[email protected]> * add settings and clean code Signed-off-by: chrislai <[email protected]> * add tests Signed-off-by: chrislai <[email protected]> * cr fixes Signed-off-by: chrislai <[email protected]> * tests message fix and spotlessApply Signed-off-by: chrislai <[email protected]> * spotlessApply Signed-off-by: chrislai <[email protected]> * add tests Signed-off-by: chrislai <[email protected]> * fix comments Signed-off-by: chrislai <[email protected]> * add visiblefortesting Signed-off-by: chrislai <[email protected]> * address comments Signed-off-by: chrislai <[email protected]> * javadoc Signed-off-by: chrislai <[email protected]> * more tests Signed-off-by: chrislai <[email protected]> * more class fixes Signed-off-by: chrislai <[email protected]> * overloaded method Signed-off-by: chrislai <[email protected]> --------- Signed-off-by: chrislai <[email protected]>
1 parent 85403d2 commit 8598ede

File tree

1 file changed

+376
-0
lines changed

1 file changed

+376
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.algorithms.agent.tracing;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
import static org.junit.Assert.assertNotSame;
11+
import static org.junit.Assert.assertTrue;
12+
import static org.mockito.ArgumentMatchers.any;
13+
import static org.mockito.Mockito.mock;
14+
import static org.mockito.Mockito.when;
15+
16+
import java.util.HashMap;
17+
import java.util.Map;
18+
import java.util.concurrent.CountDownLatch;
19+
import java.util.concurrent.ExecutorService;
20+
import java.util.concurrent.Executors;
21+
import java.util.concurrent.TimeUnit;
22+
import java.util.concurrent.atomic.AtomicInteger;
23+
24+
import org.junit.Before;
25+
import org.junit.Test;
26+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
27+
import org.opensearch.telemetry.tracing.Span;
28+
import org.opensearch.telemetry.tracing.Tracer;
29+
30+
/**
31+
* Tests for multi-span scenarios, span identification, and race conditions in MLAgentTracer.
32+
* These tests ensure that:
33+
* 1. Each span has unique identification
34+
* 2. Parent-child relationships are properly maintained
35+
* 3. No race conditions occur in concurrent usage
36+
* 4. Span attributes help identify which span is problematic
37+
*/
38+
public class MLAgentTracerMultiSpanTests {
39+
private MLFeatureEnabledSetting mockFeatureSetting;
40+
private Tracer mockTracer;
41+
42+
@Before
43+
public void setup() {
44+
mockFeatureSetting = mock(MLFeatureEnabledSetting.class);
45+
mockTracer = mock(Tracer.class);
46+
MLAgentTracer.resetForTest();
47+
}
48+
49+
private Span createMockSpan(String traceId, String spanId, String spanName) {
50+
Span mockSpan = mock(Span.class);
51+
when(mockSpan.getTraceId()).thenReturn(traceId);
52+
when(mockSpan.getSpanId()).thenReturn(spanId);
53+
when(mockSpan.getSpanName()).thenReturn(spanName);
54+
return mockSpan;
55+
}
56+
57+
@Test
58+
public void testMultiSpanHierarchyWithUniqueIdentification() {
59+
when(mockFeatureSetting.isTracingEnabled()).thenReturn(true);
60+
when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true);
61+
62+
// Set up mock spans with unique IDs
63+
Span mockRootSpan = createMockSpan("trace-123", "span-root", MLAgentTracer.AGENT_TASK_SPAN);
64+
Span mockChildSpan = createMockSpan("trace-123", "span-child", MLAgentTracer.AGENT_LLM_CALL_SPAN);
65+
Span mockGrandchildSpan = createMockSpan("trace-123", "span-grandchild", MLAgentTracer.AGENT_TOOL_CALL_SPAN);
66+
67+
when(mockTracer.startSpan(any())).thenReturn(mockRootSpan, mockChildSpan, mockGrandchildSpan);
68+
69+
MLAgentTracer.initialize(mockTracer, mockFeatureSetting);
70+
MLAgentTracer tracer = MLAgentTracer.getInstance();
71+
72+
// Create root span
73+
Map<String, String> rootAttributes = new HashMap<>();
74+
rootAttributes.put("operation", "agent_task");
75+
rootAttributes.put("request_id", "req-123");
76+
Span rootSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, rootAttributes, null);
77+
assertNotNull("Root span should not be null", rootSpan);
78+
assertNotNull("Root span should have trace ID", rootSpan.getTraceId());
79+
assertNotNull("Root span should have span ID", rootSpan.getSpanId());
80+
81+
// Create child span
82+
Map<String, String> childAttributes = new HashMap<>();
83+
childAttributes.put("operation", "llm_call");
84+
childAttributes.put("model", "gpt-4");
85+
Span childSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, childAttributes, rootSpan);
86+
assertNotNull("Child span should not be null", childSpan);
87+
assertNotNull("Child span should have trace ID", childSpan.getTraceId());
88+
assertNotNull("Child span should have span ID", childSpan.getSpanId());
89+
90+
// Verify parent-child relationship
91+
assertEquals("Child should have same trace ID as parent", rootSpan.getTraceId(), childSpan.getTraceId());
92+
assertNotSame("Child should have different span ID than parent", rootSpan.getSpanId(), childSpan.getSpanId());
93+
94+
// Create grandchild span
95+
Map<String, String> grandchildAttributes = new HashMap<>();
96+
grandchildAttributes.put("operation", "tool_call");
97+
grandchildAttributes.put("tool_name", "search_index");
98+
Span grandchildSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, grandchildAttributes, childSpan);
99+
assertNotNull("Grandchild span should not be null", grandchildSpan);
100+
assertEquals("Grandchild should have same trace ID as root", rootSpan.getTraceId(), grandchildSpan.getTraceId());
101+
assertNotSame("Grandchild should have different span ID than child", childSpan.getSpanId(), grandchildSpan.getSpanId());
102+
103+
// End spans in reverse order
104+
tracer.endSpan(grandchildSpan);
105+
tracer.endSpan(childSpan);
106+
tracer.endSpan(rootSpan);
107+
}
108+
109+
@Test
110+
public void testSpanIdentificationWithAttributes() {
111+
when(mockFeatureSetting.isTracingEnabled()).thenReturn(true);
112+
when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true);
113+
114+
// Set up mock spans with unique IDs but same trace ID
115+
Span mockTaskSpan = createMockSpan("trace-456", "span-task", MLAgentTracer.AGENT_TASK_SPAN);
116+
Span mockLlmSpan = createMockSpan("trace-456", "span-llm", MLAgentTracer.AGENT_LLM_CALL_SPAN);
117+
Span mockToolSpan = createMockSpan("trace-456", "span-tool", MLAgentTracer.AGENT_TOOL_CALL_SPAN);
118+
119+
when(mockTracer.startSpan(any())).thenReturn(mockTaskSpan, mockLlmSpan, mockToolSpan);
120+
121+
MLAgentTracer.initialize(mockTracer, mockFeatureSetting);
122+
MLAgentTracer tracer = MLAgentTracer.getInstance();
123+
124+
// Create spans with identifying attributes
125+
Map<String, String> taskAttributes = new HashMap<>();
126+
taskAttributes.put("operation", "agent_task");
127+
taskAttributes.put("request_id", "req-456");
128+
taskAttributes.put("user_id", "user-789");
129+
taskAttributes.put("session_id", "session-abc");
130+
131+
Span taskSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, taskAttributes, null);
132+
assertNotNull("Task span should not be null", taskSpan);
133+
134+
// Create multiple child spans with different operations
135+
Map<String, String> llmAttributes = new HashMap<>();
136+
llmAttributes.put("operation", "llm_call");
137+
llmAttributes.put("model", "gpt-4");
138+
llmAttributes.put("request_id", "req-456"); // Same request ID for correlation
139+
140+
Span llmSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, llmAttributes, taskSpan);
141+
assertNotNull("LLM span should not be null", llmSpan);
142+
143+
Map<String, String> toolAttributes = new HashMap<>();
144+
toolAttributes.put("operation", "tool_call");
145+
toolAttributes.put("tool_name", "search_index");
146+
toolAttributes.put("request_id", "req-456"); // Same request ID for correlation
147+
148+
Span toolSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, toolAttributes, taskSpan);
149+
assertNotNull("Tool span should not be null", toolSpan);
150+
151+
// Verify all spans have unique IDs but same trace ID
152+
assertEquals("All spans should have same trace ID", taskSpan.getTraceId(), llmSpan.getTraceId());
153+
assertEquals("All spans should have same trace ID", taskSpan.getTraceId(), toolSpan.getTraceId());
154+
assertNotSame("Spans should have different span IDs", taskSpan.getSpanId(), llmSpan.getSpanId());
155+
assertNotSame("Spans should have different span IDs", taskSpan.getSpanId(), toolSpan.getSpanId());
156+
assertNotSame("Spans should have different span IDs", llmSpan.getSpanId(), toolSpan.getSpanId());
157+
158+
// End spans
159+
tracer.endSpan(toolSpan);
160+
tracer.endSpan(llmSpan);
161+
tracer.endSpan(taskSpan);
162+
}
163+
164+
@Test
165+
public void testConcurrentSpanCreationNoRaceConditions() throws InterruptedException {
166+
when(mockFeatureSetting.isTracingEnabled()).thenReturn(true);
167+
when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true);
168+
169+
// Set up mock spans for concurrent testing
170+
Span mockSpan = createMockSpan("trace-concurrent", "span-concurrent", MLAgentTracer.AGENT_TASK_SPAN);
171+
when(mockTracer.startSpan(any())).thenReturn(mockSpan);
172+
173+
MLAgentTracer.initialize(mockTracer, mockFeatureSetting);
174+
MLAgentTracer tracer = MLAgentTracer.getInstance();
175+
176+
int numThreads = 10;
177+
int spansPerThread = 5;
178+
ExecutorService executor = Executors.newFixedThreadPool(numThreads);
179+
CountDownLatch startLatch = new CountDownLatch(1);
180+
CountDownLatch endLatch = new CountDownLatch(numThreads);
181+
AtomicInteger totalSpansCreated = new AtomicInteger(0);
182+
183+
// Create multiple threads that create spans concurrently
184+
for (int i = 0; i < numThreads; i++) {
185+
final int threadId = i;
186+
executor.submit(() -> {
187+
try {
188+
startLatch.await(); // Wait for all threads to start together
189+
190+
for (int j = 0; j < spansPerThread; j++) {
191+
Map<String, String> attributes = new HashMap<>();
192+
attributes.put("thread_id", String.valueOf(threadId));
193+
attributes.put("span_index", String.valueOf(j));
194+
195+
Span span = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, attributes, null);
196+
assertNotNull("Span should not be null", span);
197+
assertNotNull("Span should have trace ID", span.getTraceId());
198+
assertNotNull("Span should have span ID", span.getSpanId());
199+
200+
totalSpansCreated.incrementAndGet();
201+
tracer.endSpan(span);
202+
}
203+
} catch (InterruptedException e) {
204+
Thread.currentThread().interrupt();
205+
} finally {
206+
endLatch.countDown();
207+
}
208+
});
209+
}
210+
211+
// Start all threads simultaneously
212+
startLatch.countDown();
213+
214+
// Wait for all threads to complete
215+
boolean completed = endLatch.await(10, TimeUnit.SECONDS);
216+
assertTrue("All threads should complete within timeout", completed);
217+
assertEquals("Should create expected number of spans", numThreads * spansPerThread, totalSpansCreated.get());
218+
219+
executor.shutdown();
220+
assertTrue("Executor should shutdown cleanly", executor.awaitTermination(5, TimeUnit.SECONDS));
221+
}
222+
223+
@Test
224+
public void testConcurrentParentChildSpanCreation() throws InterruptedException {
225+
when(mockFeatureSetting.isTracingEnabled()).thenReturn(true);
226+
when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true);
227+
228+
// Set up mock spans for parent-child testing
229+
Span mockParentSpan = createMockSpan("trace-parent", "span-parent", MLAgentTracer.AGENT_TASK_SPAN);
230+
Span mockChildSpan = createMockSpan("trace-parent", "span-child", MLAgentTracer.AGENT_LLM_CALL_SPAN);
231+
when(mockTracer.startSpan(any())).thenReturn(mockParentSpan, mockChildSpan);
232+
233+
MLAgentTracer.initialize(mockTracer, mockFeatureSetting);
234+
MLAgentTracer tracer = MLAgentTracer.getInstance();
235+
236+
int numThreads = 5;
237+
ExecutorService executor = Executors.newFixedThreadPool(numThreads);
238+
CountDownLatch startLatch = new CountDownLatch(1);
239+
CountDownLatch endLatch = new CountDownLatch(numThreads);
240+
241+
// Create multiple threads that create parent-child span hierarchies concurrently
242+
for (int i = 0; i < numThreads; i++) {
243+
final int threadId = i;
244+
executor.submit(() -> {
245+
try {
246+
startLatch.await(); // Wait for all threads to start together
247+
248+
// Create parent span
249+
Map<String, String> parentAttributes = new HashMap<>();
250+
parentAttributes.put("thread_id", String.valueOf(threadId));
251+
252+
Span parentSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, parentAttributes, null);
253+
assertNotNull("Parent span should not be null", parentSpan);
254+
String parentTraceId = parentSpan.getTraceId();
255+
String parentSpanId = parentSpan.getSpanId();
256+
257+
// Create multiple child spans
258+
for (int j = 0; j < 3; j++) {
259+
Map<String, String> childAttributes = new HashMap<>();
260+
childAttributes.put("thread_id", String.valueOf(threadId));
261+
childAttributes.put("child_index", String.valueOf(j));
262+
263+
Span childSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, childAttributes, parentSpan);
264+
assertNotNull("Child span should not be null", childSpan);
265+
266+
// Verify parent-child relationship
267+
assertEquals("Child should have same trace ID as parent", parentTraceId, childSpan.getTraceId());
268+
assertNotSame("Child should have different span ID than parent", parentSpanId, childSpan.getSpanId());
269+
270+
tracer.endSpan(childSpan);
271+
}
272+
273+
tracer.endSpan(parentSpan);
274+
} catch (InterruptedException e) {
275+
Thread.currentThread().interrupt();
276+
} finally {
277+
endLatch.countDown();
278+
}
279+
});
280+
}
281+
282+
// Start all threads simultaneously
283+
startLatch.countDown();
284+
285+
// Wait for all threads to complete
286+
boolean completed = endLatch.await(10, TimeUnit.SECONDS);
287+
assertTrue("All threads should complete within timeout", completed);
288+
289+
executor.shutdown();
290+
assertTrue("Executor should shutdown cleanly", executor.awaitTermination(5, TimeUnit.SECONDS));
291+
}
292+
293+
@Test
294+
public void testSpanIdentificationWithErrorScenarios() {
295+
when(mockFeatureSetting.isTracingEnabled()).thenReturn(true);
296+
when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true);
297+
298+
// Set up mock spans for error scenario testing
299+
Span mockTaskSpan = createMockSpan("trace-error", "span-task", MLAgentTracer.AGENT_TASK_SPAN);
300+
Span mockErrorSpan = createMockSpan("trace-error", "span-error", MLAgentTracer.AGENT_LLM_CALL_SPAN);
301+
Span mockSuccessSpan = createMockSpan("trace-error", "span-success", MLAgentTracer.AGENT_TOOL_CALL_SPAN);
302+
when(mockTracer.startSpan(any())).thenReturn(mockTaskSpan, mockErrorSpan, mockSuccessSpan);
303+
304+
MLAgentTracer.initialize(mockTracer, mockFeatureSetting);
305+
MLAgentTracer tracer = MLAgentTracer.getInstance();
306+
307+
// Create a span hierarchy where one span fails
308+
Map<String, String> taskAttributes = new HashMap<>();
309+
taskAttributes.put("operation", "agent_task");
310+
taskAttributes.put("request_id", "req-error-test");
311+
taskAttributes.put("user_id", "user-error");
312+
313+
Span taskSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, taskAttributes, null);
314+
assertNotNull("Task span should not be null", taskSpan);
315+
316+
// Create child span that will "fail"
317+
Map<String, String> errorAttributes = new HashMap<>();
318+
errorAttributes.put("operation", "llm_call");
319+
errorAttributes.put("model", "gpt-4");
320+
errorAttributes.put("request_id", "req-error-test");
321+
322+
Span errorSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, errorAttributes, taskSpan);
323+
assertNotNull("Error span should not be null", errorSpan);
324+
325+
// Simulate an error in the child span
326+
try {
327+
// Simulate some operation that might fail
328+
throw new RuntimeException("Simulated LLM call failure");
329+
} catch (Exception e) {
330+
// Mark the span as having an error
331+
errorSpan.setError(e);
332+
}
333+
334+
// Create another child span that succeeds
335+
Map<String, String> successAttributes = new HashMap<>();
336+
successAttributes.put("operation", "tool_call");
337+
successAttributes.put("tool_name", "search_index");
338+
successAttributes.put("request_id", "req-error-test");
339+
340+
Span successSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, successAttributes, taskSpan);
341+
assertNotNull("Success span should not be null", successSpan);
342+
343+
// Verify all spans have proper identification for debugging
344+
assertEquals("All spans should have same trace ID for correlation", taskSpan.getTraceId(), errorSpan.getTraceId());
345+
assertEquals("All spans should have same trace ID for correlation", taskSpan.getTraceId(), successSpan.getTraceId());
346+
assertNotSame("Each span should have unique span ID", taskSpan.getSpanId(), errorSpan.getSpanId());
347+
assertNotSame("Each span should have unique span ID", taskSpan.getSpanId(), successSpan.getSpanId());
348+
assertNotSame("Each span should have unique span ID", errorSpan.getSpanId(), successSpan.getSpanId());
349+
350+
// End spans
351+
tracer.endSpan(successSpan);
352+
tracer.endSpan(errorSpan);
353+
tracer.endSpan(taskSpan);
354+
}
355+
356+
@Test
357+
public void testNoopTracerMultiSpanBehavior() {
358+
when(mockFeatureSetting.isTracingEnabled()).thenReturn(false);
359+
MLAgentTracer.initialize(mockTracer, mockFeatureSetting);
360+
MLAgentTracer tracer = MLAgentTracer.getInstance();
361+
362+
// Test that noop tracer still provides proper span identification
363+
Map<String, String> attributes = new HashMap<>();
364+
attributes.put("operation", "noop_test");
365+
attributes.put("request_id", "req-noop");
366+
367+
Span span = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, attributes, null);
368+
assertNotNull("Noop span should not be null", span);
369+
370+
// Noop spans should still have identification for debugging
371+
assertNotNull("Noop span should have trace ID", span.getTraceId());
372+
assertNotNull("Noop span should have span ID", span.getSpanId());
373+
374+
tracer.endSpan(span);
375+
}
376+
}

0 commit comments

Comments
 (0)