diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index 7d631ac2f6..cd569659ce 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -350,4 +350,18 @@ private MLCommonsSettings() {} // Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor public static final Setting ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting .boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Final); + + // Feature flag for enabling telemetry tracer + // This setting is Final because it controls the core tracing infrastructure initialization. + // Once the tracer is initialized, changing this setting would require a node restart + // to properly reinitialize the tracing components. + public static final Setting ML_COMMONS_TRACING_ENABLED = Setting + .boolSetting("plugins.ml_commons.tracing_enabled", false, Setting.Property.NodeScope, Setting.Property.Final); + + // Feature flag for enabling telemetry agent tracing + // This setting is Dynamic because agent tracing can be enabled/disabled at runtime + // without requiring a node restart. The MLAgentTracer singleton can be updated + // to switch between real tracer and NoopTracer based on this setting. 
+ public static final Setting ML_COMMONS_AGENT_TRACING_ENABLED = Setting + .boolSetting("plugins.ml_commons.agent_tracing_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java index 786af9e29c..fb44acffce 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java @@ -8,6 +8,7 @@ package org.opensearch.ml.common.settings; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; @@ -19,6 +20,7 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_TRACING_ENABLED; import java.util.ArrayList; import java.util.List; @@ -51,6 +53,9 @@ public class MLFeatureEnabledSetting { private volatile Boolean isMetricCollectionEnabled; private volatile Boolean isStaticMetricCollectionEnabled; + private volatile Boolean isTracingEnabled; + private volatile Boolean isAgentTracingEnabled; + private final List listeners = new ArrayList<>(); public 
MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { @@ -66,6 +71,8 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) isRagSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings); isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings); isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings); + isTracingEnabled = ML_COMMONS_TRACING_ENABLED.get(settings); + isAgentTracingEnabled = ML_COMMONS_AGENT_TRACING_ENABLED.get(settings); clusterService .getClusterSettings() @@ -88,6 +95,9 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED, it -> isAgentTracingEnabled = it); } /** @@ -178,6 +188,14 @@ public boolean isStaticMetricCollectionEnabled() { return isStaticMetricCollectionEnabled; } + public boolean isTracingEnabled() { + return isTracingEnabled; + } + + public boolean isAgentTracingEnabled() { + return isAgentTracingEnabled; + } + @VisibleForTesting public void notifyMultiTenancyListeners(boolean isEnabled) { for (SettingsChangeListener listener : listeners) { diff --git a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java index e1dc2b2030..485d203d63 100644 --- a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java @@ -43,7 +43,9 @@ public void setUp() { MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, 
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED, - MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED + MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, + MLCommonsSettings.ML_COMMONS_TRACING_ENABLED, + MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED ) ); when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings); @@ -65,6 +67,8 @@ public void testDefaults_allFeaturesEnabled() { .put("plugins.ml_commons.rag_pipeline_feature_enabled", true) .put("plugins.ml_commons.metrics_collection_enabled", true) .put("plugins.ml_commons.metrics_static_collection_enabled", true) + .put("plugins.ml_commons.tracing_enabled", true) + .put("plugins.ml_commons.agent_tracing_enabled", true) .build(); MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); @@ -81,6 +85,8 @@ public void testDefaults_allFeaturesEnabled() { assertTrue(setting.isRagSearchPipelineEnabled()); assertTrue(setting.isMetricCollectionEnabled()); assertTrue(setting.isStaticMetricCollectionEnabled()); + assertTrue(setting.isTracingEnabled()); + assertTrue(setting.isAgentTracingEnabled()); } @Test @@ -99,6 +105,8 @@ public void testDefaults_someFeaturesDisabled() { .put("plugins.ml_commons.rag_pipeline_feature_enabled", false) .put("plugins.ml_commons.metrics_collection_enabled", false) .put("plugins.ml_commons.metrics_static_collection_enabled", false) + .put("plugins.ml_commons.tracing_enabled", false) + .put("plugins.ml_commons.agent_tracing_enabled", false) .build(); MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); @@ -115,6 +123,8 @@ public void testDefaults_someFeaturesDisabled() { assertFalse(setting.isRagSearchPipelineEnabled()); assertFalse(setting.isMetricCollectionEnabled()); assertFalse(setting.isStaticMetricCollectionEnabled()); + assertFalse(setting.isTracingEnabled()); + assertFalse(setting.isAgentTracingEnabled()); } @Test diff --git 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/tracing/AbstractMLTracer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/tracing/AbstractMLTracer.java
new file mode 100644
index 0000000000..3d878c1cac
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/tracing/AbstractMLTracer.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.engine.algorithms.agent.tracing;
+
+import java.util.Map;
+
+import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
+import org.opensearch.telemetry.tracing.Span;
+import org.opensearch.telemetry.tracing.Tracer;
+
+/**
+ * Abstract base class for tracing implementations in ML Commons.
+ *
+ * This class defines the common interface and shared state for all ML tracing logic,
+ * such as starting and ending spans. Concrete subclasses (such as {@link MLAgentTracer})
+ * implement tracing for specific ML components or workflows.
+ *
+ * The intention is to allow for future extension: additional tracers can be created
+ * for other ML features (e.g., connector tracing) by extending this class.
+ *
+ * Each call to {@link #startSpan(String, Map, Span)} returns a {@link Span} object,
+ * which acts as a handle to the started span. The {@link Span} object typically contains
+ * a unique identifier (span ID) that can be used for logging and debugging. When ending
+ * a span, always pass the same {@link Span} object to {@link #endSpan(Span)}.
+ */
+public abstract class AbstractMLTracer {
+    protected final Tracer tracer;
+    protected final MLFeatureEnabledSetting mlFeatureEnabledSetting;
+
+    /**
+     * Constructs a new AbstractMLTracer with the specified tracer and feature settings.
+     *
+     * @param tracer The underlying tracer implementation to use for span operations.
+     *               This may be a real tracer or a no-op tracer depending on configuration.
+     * @param mlFeatureEnabledSetting The ML feature settings that control tracing behavior.
+     *                                Used to determine if tracing is enabled and which features
+     *                                should be traced.
+     */
+    protected AbstractMLTracer(Tracer tracer, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
+        this.tracer = tracer;
+        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
+    }
+
+    /**
+     * Starts a new span with the specified name and attributes, and no parent span.
+     * <p>
+     * This is a convenience overload for starting a root span. It is equivalent to calling
+     * {@link #startSpan(String, Map, Span)} with {@code parentSpan} set to {@code null}.
+     * <p>
+     * The returned span should be passed to {@link #endSpan(Span)} when the operation completes.
+     *
+     * @param name The name of the span. Should follow the naming convention defined by
+     *             the span constants (e.g., AGENT_TASK_SPAN, AGENT_TOOL_CALL_SPAN).
+     * @param attributes A map of key-value pairs to associate with the span. These
+     *                   provide additional context about the operation being traced.
+     *                   May be null or empty if no attributes are needed.
+     * @return A Span object representing the started root span; the value for a disabled tracer is implementation-defined.
+     */
+    public abstract Span startSpan(String name, Map<String, String> attributes);
+
+    /**
+     * Starts a new span for tracing ML operations.
+     *
+     * This method creates a new span with the specified name and attributes. The span
+     * can be either a root span (when parentSpan is null) or a child span of the
+     * specified parent span. The returned Span object should be passed to
+     * {@link #endSpan(Span)} when the operation completes.
+     *
+     * @param name The name of the span.
+     * @param attributes A map of key-value pairs to associate with the span. These attributes
+     *                   provide additional context about the operation being traced. May be null
+     *                   or empty if no attributes are needed.
+     * @param parentSpan The parent span, or null if this should be a root span. Child spans
+     *                   are nested under their parent spans in the trace hierarchy.
+     * @return A Span object representing the started span; the value for a disabled tracer is
+     *         implementation-defined. The returned span should be passed to {@link #endSpan(Span)}
+     *         when the operation completes.
+     */
+    public abstract Span startSpan(String name, Map<String, String> attributes, Span parentSpan);
+
+    /**
+     * Ends a previously started span.
+     *
+     * This method marks the completion of a span and finalizes its timing information.
+     * The span will be recorded in the trace with its start time, end time, and any
+     * attributes that were set during its lifetime.
+     *
+     * @param span The span to end. This should be the same Span object that was returned
+     *             by a previous call to {@link #startSpan(String, Map, Span)}. Null handling is
+     *             implementation-defined; {@link MLAgentTracer} throws IllegalArgumentException.
+     */
+    public abstract void endSpan(Span span);
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracer.java
new file mode 100644
index 0000000000..318683e440
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracer.java
@@ -0,0 +1,349 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.engine.algorithms.agent.tracing;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.BiConsumer;
+
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.ml.common.settings.MLCommonsSettings;
+import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
+import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
+import org.opensearch.telemetry.tracing.Span;
+import org.opensearch.telemetry.tracing.SpanContext;
+import org.opensearch.telemetry.tracing.SpanCreationContext;
+import org.opensearch.telemetry.tracing.Tracer;
+import org.opensearch.telemetry.tracing.attributes.Attributes;
+import org.opensearch.telemetry.tracing.noop.NoopTracer;
+
+import lombok.extern.log4j.Log4j2;
+
+/**
+ * MLAgentTracer is a concrete implementation of AbstractMLTracer for agent tracing in ML Commons.
+ * It manages the lifecycle of agent-related spans, including creation, context propagation, and completion.
+ *
+ * This class is implemented as a singleton to ensure that only one tracer is active
+ * for agent tracing at any time. This design provides consistent management of tracing state and configuration,
+ * and avoids issues with multiple tracers being active at once.
+ * The singleton can be dynamically enabled or disabled based on cluster settings.
+ *
+ * This class is thread-safe: multiple threads can use the singleton instance to start and end spans concurrently.
+ * Each call to {@link #startSpan(String, Map, Span)} creates a new, independent span.
+ */
+@Log4j2
+public class MLAgentTracer extends AbstractMLTracer {
+    public static final String AGENT_TASK_SPAN = "agent.task";
+    public static final String AGENT_CONV_TASK_SPAN = "agent.conv_task";
+    public static final String AGENT_LLM_CALL_SPAN = "agent.llm_call";
+    public static final String AGENT_TOOL_CALL_SPAN = "agent.tool_call";
+    public static final String AGENT_PLAN_SPAN = "agent.plan";
+    public static final String AGENT_EXECUTE_STEP_SPAN = "agent.execute_step";
+    public static final String AGENT_REFLECT_STEP_SPAN = "agent.reflect_step";
+    public static final String AGENT_TASK_PER_SPAN = "agent.task_per";
+    public static final String AGENT_TASK_CONV_SPAN = "agent.task_conv";
+    public static final String AGENT_TASK_CONV_FLOW_SPAN = "agent.task_convflow";
+    public static final String AGENT_TASK_FLOW_SPAN = "agent.task_flow";
+
+    // Swapped by the settings-update consumer registered in initialize(); volatile so that
+    // threads calling getInstance() observe the replacement.
+    private static volatile MLAgentTracer instance;
+
+    /**
+     * Private constructor for MLAgentTracer.
+     * @param tracer The tracer implementation to use (may be a real tracer or NoopTracer).
+     * @param mlFeatureEnabledSetting The ML feature settings.
+     */
+    private MLAgentTracer(Tracer tracer, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
+        super(tracer, mlFeatureEnabledSetting);
+    }
+
+    /**
+     * Initializes the singleton MLAgentTracer instance with the given tracer and settings.
+     * This is a convenience method that calls the full initialize method with a null ClusterService.
+     *
+     * @param tracer The tracer implementation to use. If null or if tracing is disabled,
+     *               a NoopTracer will be used instead.
+     * @param mlFeatureEnabledSetting The ML feature settings that control tracing behavior.
+     *                                If null, tracing will be disabled.
+     */
+    public static synchronized void initialize(Tracer tracer, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
+        initialize(tracer, mlFeatureEnabledSetting, null);
+    }
+
+    /**
+     * Initializes the singleton MLAgentTracer instance with the given tracer and settings.
+     * If agent tracing is disabled, a NoopTracer is used. When a ClusterService is supplied, a
+     * settings-update consumer is registered that rebuilds the singleton whenever the agent
+     * tracing setting changes.
+     *
+     * @param tracer The tracer implementation to use. If null or if tracing is disabled,
+     *               a NoopTracer will be used instead.
+     * @param mlFeatureEnabledSetting The ML feature settings that control tracing behavior.
+     *                                If null, tracing will be disabled.
+     * @param clusterService The cluster service used to subscribe to setting updates. May be null,
+     *                       in which case no update consumer is registered.
+     */
+    public static synchronized void initialize(
+        Tracer tracer,
+        MLFeatureEnabledSetting mlFeatureEnabledSetting,
+        ClusterService clusterService
+    ) {
+        boolean agentTracingEnabled = mlFeatureEnabledSetting != null && mlFeatureEnabledSetting.isAgentTracingEnabled();
+        Tracer tracerToUse = selectTracer(tracer, mlFeatureEnabledSetting, agentTracingEnabled);
+
+        instance = new MLAgentTracer(tracerToUse, mlFeatureEnabledSetting);
+        log.info("MLAgentTracer initialized with {}", tracerToUse.getClass().getSimpleName());
+
+        if (clusterService != null) {
+            clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED, enabled -> {
+                Tracer newTracerToUse = selectTracer(tracer, mlFeatureEnabledSetting, enabled);
+                instance = new MLAgentTracer(newTracerToUse, mlFeatureEnabledSetting);
+                log.info("MLAgentTracer re-initialized with {} due to setting change", newTracerToUse.getClass().getSimpleName());
+            });
+        }
+    }
+
+    /**
+     * Picks the tracer backing the singleton: the supplied tracer only when it is non-null and both
+     * the global tracing flag and the agent tracing flag are on; otherwise {@link NoopTracer#INSTANCE}.
+     */
+    private static Tracer selectTracer(Tracer tracer, MLFeatureEnabledSetting mlFeatureEnabledSetting, boolean agentTracingEnabled) {
+        boolean useRealTracer = mlFeatureEnabledSetting != null
+            && mlFeatureEnabledSetting.isTracingEnabled()
+            && agentTracingEnabled
+            && tracer != null;
+        return useRealTracer ? tracer : NoopTracer.INSTANCE;
+    }
+
+    /**
+     * Returns the singleton MLAgentTracer instance.
+     * @return The MLAgentTracer instance.
+     * @throws IllegalStateException if the tracer is not initialized.
+     */
+    public static MLAgentTracer getInstance() {
+        // Single volatile read so the null check and the return see the same object.
+        MLAgentTracer current = instance;
+        if (current == null) {
+            throw new IllegalStateException("MLAgentTracer is not initialized. Call initialize() first before using getInstance().");
+        }
+        return current;
+    }
+
+    /**
+     * Starts a new span for agent tracing.
+     *
+     * This method creates a new span with the specified name and attributes. For agent.task*
+     * spans, this method attempts to create them as root spans to ensure proper trace
+     * grouping. If the reflection-based root span creation fails, it falls back to
+     * normal span creation which might result in ghost parent span.
+     *
+     * The method handles both real tracers and NoopTracer instances. When using a real
+     * tracer, spans are created with proper parent-child relationships. When using
+     * NoopTracer, the spans are no-ops but still maintain the expected interface.
+     *
+     * @param name The name of the span. Should follow the naming convention defined by
+     *             the span constants (e.g., AGENT_TASK_SPAN, AGENT_TOOL_CALL_SPAN).
+     * @param attributes A map of key-value pairs to associate with the span. These
+     *                   provide additional context about the operation being traced.
+     *                   May be null or empty if no attributes are needed. Entries with a
+     *                   null key or null value are skipped.
+     * @param parentSpan The parent span, or null if this should be a root span.
+     *                   For agent.task* spans, this parameter is ignored when using
+     *                   real tracers as they are forced to be root spans.
+     * @return A Span object representing the started span, as produced by the underlying tracer.
+     *         The returned span should be passed to {@link #endSpan(Span)} when the
+     *         operation completes.
+     */
+    @Override
+    public Span startSpan(String name, Map<String, String> attributes, Span parentSpan) {
+        Attributes attrBuilder = Attributes.create();
+        if (attributes != null && !attributes.isEmpty()) {
+            for (Map.Entry<String, String> entry : attributes.entrySet()) {
+                String key = entry.getKey();
+                String value = entry.getValue();
+                if (key != null && value != null) {
+                    attrBuilder.addAttribute(key, value);
+                }
+            }
+        }
+        SpanCreationContext context = SpanCreationContext.server().name(name).attributes(attrBuilder);
+
+        boolean forceRoot = name != null && name.startsWith(AGENT_TASK_SPAN) && !(tracer instanceof NoopTracer);
+        if (forceRoot) {
+            // Force agent.task spans to be root span
+            try {
+                Object tracingTelemetry = resolveTracingTelemetry();
+                Method createSpanMethod = tracingTelemetry.getClass().getMethod("createSpan", SpanCreationContext.class, Span.class);
+                createSpanMethod.setAccessible(true);
+
+                Span rootSpan = (Span) createSpanMethod.invoke(tracingTelemetry, context, null);
+                rootSpan.addAttribute("thread.name", Thread.currentThread().getName());
+                return rootSpan;
+            } catch (Exception e) {
+                log.warn("Failed to create root span for agent.task*, falling back to normal span creation", e);
+            }
+        }
+
+        if (parentSpan != null) {
+            context = context.parent(new SpanContext(parentSpan));
+        }
+        return tracer.startSpan(context);
+    }
+
+    /**
+     * Starts a new span for agent tracing with the specified name and attributes, and no parent span.
+     * <p>
+     * This is a convenience overload for starting a root span. It is equivalent to calling
+     * {@link #startSpan(String, Map, Span)} with {@code parentSpan} set to {@code null}.
+     * <p>
+     * The returned span should be passed to {@link #endSpan(Span)} when the operation completes.
+     *
+     * @param name The name of the span. Should follow the naming convention defined by
+     *             the span constants (e.g., AGENT_TASK_SPAN, AGENT_TOOL_CALL_SPAN).
+     * @param attributes A map of key-value pairs to associate with the span. These
+     *                   provide additional context about the operation being traced.
+     *                   May be null or empty if no attributes are needed.
+     * @return A Span object representing the started root span, as produced by the underlying tracer.
+     */
+    @Override
+    public Span startSpan(String name, Map<String, String> attributes) {
+        return startSpan(name, attributes, null);
+    }
+
+    /**
+     * Ends the given span.
+     *
+     * This method marks the completion of a span and finalizes its timing information.
+     * The span will be recorded in the trace with its start time, end time, and any
+     * attributes that were set during its lifetime.
+     *
+     * @param span The span to end. This should be the same Span object that was returned
+     *             by a previous call to {@link #startSpan(String, Map, Span)}. If null,
+     *             an IllegalArgumentException is thrown.
+     * @throws IllegalArgumentException if the span parameter is null.
+     */
+    @Override
+    public void endSpan(Span span) {
+        if (span == null) {
+            throw new IllegalArgumentException("Span cannot be null");
+        }
+        span.endSpan();
+    }
+
+    /**
+     * Returns the underlying tracer implementation.
+     *
+     * This method provides access to the tracer instance that is currently being used
+     * by this MLAgentTracer. The returned tracer may be either a real tracer implementation
+     * or a NoopTracer, depending on the current configuration and feature settings.
+     *
+     * @return The tracer instance currently in use. This may be a real tracer or
+     *         NoopTracer.INSTANCE if tracing is disabled.
+     */
+    public Tracer getTracer() {
+        return tracer;
+    }
+
+    /**
+     * Resets the singleton instance for testing purposes.
+     */
+    @VisibleForTesting
+    static void resetForTest() {
+        instance = null;
+    }
+
+    /**
+     * Injects the span context into a carrier map using the TracingContextPropagator.
+     *
+     * This method serializes the span context into a map that can be transmitted
+     * across process boundaries (e.g., in HTTP headers, message queues, etc.).
+     * The injected context can later be extracted using {@link #extractSpanContext(Map)}
+     * to continue the trace in another process or thread.
+     *
+     * The method uses reflection to access the underlying tracing telemetry components,
+     * as the OpenSearch tracing API doesn't provide direct access to context propagation.
+     * If the reflection fails, the method logs a warning but doesn't throw an exception.
+     *
+     * @param span The span whose context to inject. If null, this method is a no-op.
+     * @param carrier The map to inject context into. The span context will be added
+     *                as key-value pairs to this map. Must not be null.
+     */
+    public void injectSpanContext(Span span, Map<String, String> carrier) {
+        // Honor the documented no-op for a null span instead of letting the reflective call fail and log.
+        if (span == null || tracer instanceof NoopTracer) {
+            return;
+        }
+
+        try {
+            Object tracingTelemetry = resolveTracingTelemetry();
+            Method getContextPropagatorMethod = tracingTelemetry.getClass().getMethod("getContextPropagator");
+            Object propagator = getContextPropagatorMethod.invoke(tracingTelemetry);
+
+            Method injectMethod = propagator.getClass().getMethod("inject", Span.class, BiConsumer.class);
+            injectMethod.invoke(propagator, span, (BiConsumer<String, String>) carrier::put);
+        } catch (Exception e) {
+            log.warn("Failed to inject span context", e);
+        }
+    }
+
+    /**
+     * Extracts a parent span from a carrier map using the TracingContextPropagator.
+     *
+     * This method deserializes a span context from a map that was previously created
+     * by {@link #injectSpanContext(Span, Map)}. The extracted context can be used
+     * as a parent span to continue a trace across process boundaries.
+     *
+     * The method uses reflection to access the underlying tracing telemetry components,
+     * as the OpenSearch tracing API doesn't provide direct access to context propagation.
+     * A reflection failure is logged as a warning and yields null; a missing context yields null silently.
+     *
+     * @param carrier The map containing the context. This should be the same map that
+     *                was populated by a previous call to {@link #injectSpanContext(Span, Map)}.
+     *                May be null or empty, in which case null is returned.
+     * @return The extracted parent span, or null if no context is found, the carrier
+     *         is null/empty, or tracing is disabled (NoopTracer is being used).
+     */
+    public Span extractSpanContext(Map<String, String> carrier) {
+        if (carrier == null || carrier.isEmpty() || tracer instanceof NoopTracer) {
+            return null;
+        }
+
+        try {
+            Object tracingTelemetry = resolveTracingTelemetry();
+            Method getContextPropagatorMethod = tracingTelemetry.getClass().getMethod("getContextPropagator");
+            Object propagator = getContextPropagatorMethod.invoke(tracingTelemetry);
+
+            Method extractMethod = propagator.getClass().getMethod("extract", Map.class);
+            Optional<?> spanOpt = (Optional<?>) extractMethod.invoke(propagator, carrier);
+            if (spanOpt.isPresent()) {
+                return (Span) spanOpt.get();
+            }
+        } catch (Exception e) {
+            log.warn("Failed to extract span context", e);
+        }
+        return null;
+    }
+
+    /**
+     * Reflectively reaches the tracing telemetry object held by the wrapped tracer, by reading the
+     * {@code defaultTracer} field of the tracer and then the {@code tracingTelemetry} field of that object.
+     *
+     * @return the tracing telemetry object.
+     * @throws ReflectiveOperationException if either field is missing or cannot be read.
+     */
+    private Object resolveTracingTelemetry() throws ReflectiveOperationException {
+        Field defaultTracerField = tracer.getClass().getDeclaredField("defaultTracer");
+        defaultTracerField.setAccessible(true);
+        Object defaultTracer = defaultTracerField.get(tracer);
+
+        Field tracingTelemetryField = defaultTracer.getClass().getDeclaredField("tracingTelemetry");
+        tracingTelemetryField.setAccessible(true);
+        return tracingTelemetryField.get(defaultTracer);
+    }
+}
diff --git
a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracerMultiSpanTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracerMultiSpanTests.java new file mode 100644 index 0000000000..fd5e9a60bc --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracerMultiSpanTests.java @@ -0,0 +1,376 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent.tracing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.telemetry.tracing.Span; +import org.opensearch.telemetry.tracing.Tracer; + +/** + * Tests for multi-span scenarios, span identification, and race conditions in MLAgentTracer. + * These tests ensure that: + * 1. Each span has unique identification + * 2. Parent-child relationships are properly maintained + * 3. No race conditions occur in concurrent usage + * 4. 
Span attributes help identify which span is problematic + */ +public class MLAgentTracerMultiSpanTests { + private MLFeatureEnabledSetting mockFeatureSetting; + private Tracer mockTracer; + + @Before + public void setup() { + mockFeatureSetting = mock(MLFeatureEnabledSetting.class); + mockTracer = mock(Tracer.class); + MLAgentTracer.resetForTest(); + } + + private Span createMockSpan(String traceId, String spanId, String spanName) { + Span mockSpan = mock(Span.class); + when(mockSpan.getTraceId()).thenReturn(traceId); + when(mockSpan.getSpanId()).thenReturn(spanId); + when(mockSpan.getSpanName()).thenReturn(spanName); + return mockSpan; + } + + @Test + public void testMultiSpanHierarchyWithUniqueIdentification() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + + // Set up mock spans with unique IDs + Span mockRootSpan = createMockSpan("trace-123", "span-root", MLAgentTracer.AGENT_TASK_SPAN); + Span mockChildSpan = createMockSpan("trace-123", "span-child", MLAgentTracer.AGENT_LLM_CALL_SPAN); + Span mockGrandchildSpan = createMockSpan("trace-123", "span-grandchild", MLAgentTracer.AGENT_TOOL_CALL_SPAN); + + when(mockTracer.startSpan(any())).thenReturn(mockRootSpan, mockChildSpan, mockGrandchildSpan); + + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer tracer = MLAgentTracer.getInstance(); + + // Create root span + Map rootAttributes = new HashMap<>(); + rootAttributes.put("operation", "agent_task"); + rootAttributes.put("request_id", "req-123"); + Span rootSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, rootAttributes, null); + assertNotNull("Root span should not be null", rootSpan); + assertNotNull("Root span should have trace ID", rootSpan.getTraceId()); + assertNotNull("Root span should have span ID", rootSpan.getSpanId()); + + // Create child span + Map childAttributes = new HashMap<>(); + childAttributes.put("operation", "llm_call"); + 
childAttributes.put("model", "gpt-4"); + Span childSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, childAttributes, rootSpan); + assertNotNull("Child span should not be null", childSpan); + assertNotNull("Child span should have trace ID", childSpan.getTraceId()); + assertNotNull("Child span should have span ID", childSpan.getSpanId()); + + // Verify parent-child relationship + assertEquals("Child should have same trace ID as parent", rootSpan.getTraceId(), childSpan.getTraceId()); + assertNotSame("Child should have different span ID than parent", rootSpan.getSpanId(), childSpan.getSpanId()); + + // Create grandchild span + Map grandchildAttributes = new HashMap<>(); + grandchildAttributes.put("operation", "tool_call"); + grandchildAttributes.put("tool_name", "search_index"); + Span grandchildSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, grandchildAttributes, childSpan); + assertNotNull("Grandchild span should not be null", grandchildSpan); + assertEquals("Grandchild should have same trace ID as root", rootSpan.getTraceId(), grandchildSpan.getTraceId()); + assertNotSame("Grandchild should have different span ID than child", childSpan.getSpanId(), grandchildSpan.getSpanId()); + + // End spans in reverse order + tracer.endSpan(grandchildSpan); + tracer.endSpan(childSpan); + tracer.endSpan(rootSpan); + } + + @Test + public void testSpanIdentificationWithAttributes() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + + // Set up mock spans with unique IDs but same trace ID + Span mockTaskSpan = createMockSpan("trace-456", "span-task", MLAgentTracer.AGENT_TASK_SPAN); + Span mockLlmSpan = createMockSpan("trace-456", "span-llm", MLAgentTracer.AGENT_LLM_CALL_SPAN); + Span mockToolSpan = createMockSpan("trace-456", "span-tool", MLAgentTracer.AGENT_TOOL_CALL_SPAN); + + when(mockTracer.startSpan(any())).thenReturn(mockTaskSpan, mockLlmSpan, mockToolSpan); + + 
MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer tracer = MLAgentTracer.getInstance(); + + // Create spans with identifying attributes + Map taskAttributes = new HashMap<>(); + taskAttributes.put("operation", "agent_task"); + taskAttributes.put("request_id", "req-456"); + taskAttributes.put("user_id", "user-789"); + taskAttributes.put("session_id", "session-abc"); + + Span taskSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, taskAttributes, null); + assertNotNull("Task span should not be null", taskSpan); + + // Create multiple child spans with different operations + Map llmAttributes = new HashMap<>(); + llmAttributes.put("operation", "llm_call"); + llmAttributes.put("model", "gpt-4"); + llmAttributes.put("request_id", "req-456"); // Same request ID for correlation + + Span llmSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, llmAttributes, taskSpan); + assertNotNull("LLM span should not be null", llmSpan); + + Map toolAttributes = new HashMap<>(); + toolAttributes.put("operation", "tool_call"); + toolAttributes.put("tool_name", "search_index"); + toolAttributes.put("request_id", "req-456"); // Same request ID for correlation + + Span toolSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, toolAttributes, taskSpan); + assertNotNull("Tool span should not be null", toolSpan); + + // Verify all spans have unique IDs but same trace ID + assertEquals("All spans should have same trace ID", taskSpan.getTraceId(), llmSpan.getTraceId()); + assertEquals("All spans should have same trace ID", taskSpan.getTraceId(), toolSpan.getTraceId()); + assertNotSame("Spans should have different span IDs", taskSpan.getSpanId(), llmSpan.getSpanId()); + assertNotSame("Spans should have different span IDs", taskSpan.getSpanId(), toolSpan.getSpanId()); + assertNotSame("Spans should have different span IDs", llmSpan.getSpanId(), toolSpan.getSpanId()); + + // End spans + tracer.endSpan(toolSpan); + tracer.endSpan(llmSpan); + 
tracer.endSpan(taskSpan); + } + + @Test + public void testConcurrentSpanCreationNoRaceConditions() throws InterruptedException { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + + // Set up mock spans for concurrent testing + Span mockSpan = createMockSpan("trace-concurrent", "span-concurrent", MLAgentTracer.AGENT_TASK_SPAN); + when(mockTracer.startSpan(any())).thenReturn(mockSpan); + + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer tracer = MLAgentTracer.getInstance(); + + int numThreads = 10; + int spansPerThread = 5; + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(numThreads); + AtomicInteger totalSpansCreated = new AtomicInteger(0); + + // Create multiple threads that create spans concurrently + for (int i = 0; i < numThreads; i++) { + final int threadId = i; + executor.submit(() -> { + try { + startLatch.await(); // Wait for all threads to start together + + for (int j = 0; j < spansPerThread; j++) { + Map attributes = new HashMap<>(); + attributes.put("thread_id", String.valueOf(threadId)); + attributes.put("span_index", String.valueOf(j)); + + Span span = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, attributes, null); + assertNotNull("Span should not be null", span); + assertNotNull("Span should have trace ID", span.getTraceId()); + assertNotNull("Span should have span ID", span.getSpanId()); + + totalSpansCreated.incrementAndGet(); + tracer.endSpan(span); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + endLatch.countDown(); + } + }); + } + + // Start all threads simultaneously + startLatch.countDown(); + + // Wait for all threads to complete + boolean completed = endLatch.await(10, TimeUnit.SECONDS); + assertTrue("All threads should complete within timeout", completed); + 
assertEquals("Should create expected number of spans", numThreads * spansPerThread, totalSpansCreated.get()); + + executor.shutdown(); + assertTrue("Executor should shutdown cleanly", executor.awaitTermination(5, TimeUnit.SECONDS)); + } + + @Test + public void testConcurrentParentChildSpanCreation() throws InterruptedException { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + + // Set up mock spans for parent-child testing + Span mockParentSpan = createMockSpan("trace-parent", "span-parent", MLAgentTracer.AGENT_TASK_SPAN); + Span mockChildSpan = createMockSpan("trace-parent", "span-child", MLAgentTracer.AGENT_LLM_CALL_SPAN); + when(mockTracer.startSpan(any())).thenReturn(mockParentSpan, mockChildSpan); + + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer tracer = MLAgentTracer.getInstance(); + + int numThreads = 5; + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(numThreads); + + // Create multiple threads that create parent-child span hierarchies concurrently + for (int i = 0; i < numThreads; i++) { + final int threadId = i; + executor.submit(() -> { + try { + startLatch.await(); // Wait for all threads to start together + + // Create parent span + Map parentAttributes = new HashMap<>(); + parentAttributes.put("thread_id", String.valueOf(threadId)); + + Span parentSpan = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, parentAttributes, null); + assertNotNull("Parent span should not be null", parentSpan); + String parentTraceId = parentSpan.getTraceId(); + String parentSpanId = parentSpan.getSpanId(); + + // Create multiple child spans + for (int j = 0; j < 3; j++) { + Map childAttributes = new HashMap<>(); + childAttributes.put("thread_id", String.valueOf(threadId)); + childAttributes.put("child_index", String.valueOf(j)); + + Span 
childSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, childAttributes, parentSpan); + assertNotNull("Child span should not be null", childSpan); + + // Verify parent-child relationship + assertEquals("Child should have same trace ID as parent", parentTraceId, childSpan.getTraceId()); + assertNotSame("Child should have different span ID than parent", parentSpanId, childSpan.getSpanId()); + + tracer.endSpan(childSpan); + } + + tracer.endSpan(parentSpan); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + endLatch.countDown(); + } + }); + } + + // Start all threads simultaneously + startLatch.countDown(); + + // Wait for all threads to complete + boolean completed = endLatch.await(10, TimeUnit.SECONDS); + assertTrue("All threads should complete within timeout", completed); + + executor.shutdown(); + assertTrue("Executor should shutdown cleanly", executor.awaitTermination(5, TimeUnit.SECONDS)); + } + + @Test + public void testSpanIdentificationWithErrorScenarios() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + + // Set up mock spans for error scenario testing + Span mockTaskSpan = createMockSpan("trace-error", "span-task", MLAgentTracer.AGENT_TASK_SPAN); + Span mockErrorSpan = createMockSpan("trace-error", "span-error", MLAgentTracer.AGENT_LLM_CALL_SPAN); + Span mockSuccessSpan = createMockSpan("trace-error", "span-success", MLAgentTracer.AGENT_TOOL_CALL_SPAN); + when(mockTracer.startSpan(any())).thenReturn(mockTaskSpan, mockErrorSpan, mockSuccessSpan); + + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer tracer = MLAgentTracer.getInstance(); + + // Create a span hierarchy where one span fails + Map taskAttributes = new HashMap<>(); + taskAttributes.put("operation", "agent_task"); + taskAttributes.put("request_id", "req-error-test"); + taskAttributes.put("user_id", "user-error"); + + Span taskSpan = 
tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, taskAttributes, null); + assertNotNull("Task span should not be null", taskSpan); + + // Create child span that will "fail" + Map errorAttributes = new HashMap<>(); + errorAttributes.put("operation", "llm_call"); + errorAttributes.put("model", "gpt-4"); + errorAttributes.put("request_id", "req-error-test"); + + Span errorSpan = tracer.startSpan(MLAgentTracer.AGENT_LLM_CALL_SPAN, errorAttributes, taskSpan); + assertNotNull("Error span should not be null", errorSpan); + + // Simulate an error in the child span + try { + // Simulate some operation that might fail + throw new RuntimeException("Simulated LLM call failure"); + } catch (Exception e) { + // Mark the span as having an error + errorSpan.setError(e); + } + + // Create another child span that succeeds + Map successAttributes = new HashMap<>(); + successAttributes.put("operation", "tool_call"); + successAttributes.put("tool_name", "search_index"); + successAttributes.put("request_id", "req-error-test"); + + Span successSpan = tracer.startSpan(MLAgentTracer.AGENT_TOOL_CALL_SPAN, successAttributes, taskSpan); + assertNotNull("Success span should not be null", successSpan); + + // Verify all spans have proper identification for debugging + assertEquals("All spans should have same trace ID for correlation", taskSpan.getTraceId(), errorSpan.getTraceId()); + assertEquals("All spans should have same trace ID for correlation", taskSpan.getTraceId(), successSpan.getTraceId()); + assertNotSame("Each span should have unique span ID", taskSpan.getSpanId(), errorSpan.getSpanId()); + assertNotSame("Each span should have unique span ID", taskSpan.getSpanId(), successSpan.getSpanId()); + assertNotSame("Each span should have unique span ID", errorSpan.getSpanId(), successSpan.getSpanId()); + + // End spans + tracer.endSpan(successSpan); + tracer.endSpan(errorSpan); + tracer.endSpan(taskSpan); + } + + @Test + public void testNoopTracerMultiSpanBehavior() { + 
when(mockFeatureSetting.isTracingEnabled()).thenReturn(false); + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer tracer = MLAgentTracer.getInstance(); + + // Test that noop tracer still provides proper span identification + Map attributes = new HashMap<>(); + attributes.put("operation", "noop_test"); + attributes.put("request_id", "req-noop"); + + Span span = tracer.startSpan(MLAgentTracer.AGENT_TASK_SPAN, attributes, null); + assertNotNull("Noop span should not be null", span); + + // Noop spans should still have identification for debugging + assertNotNull("Noop span should have trace ID", span.getTraceId()); + assertNotNull("Noop span should have span ID", span.getSpanId()); + + tracer.endSpan(span); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracerTests.java new file mode 100644 index 0000000000..7406f45594 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/tracing/MLAgentTracerTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent.tracing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.telemetry.tracing.Tracer; +import org.opensearch.telemetry.tracing.noop.NoopTracer; + +public class MLAgentTracerTests { + private MLFeatureEnabledSetting mockFeatureSetting; + private Tracer mockTracer; + + @Before + public void setup() { + mockFeatureSetting = 
mock(MLFeatureEnabledSetting.class); + mockTracer = mock(Tracer.class); + MLAgentTracer.resetForTest(); + } + + @Test + public void testExceptionThrownForNotInitialized() { + IllegalStateException exception = assertThrows(IllegalStateException.class, MLAgentTracer::getInstance); + String msg = exception.getMessage(); + assertEquals("MLAgentTracer is not initialized. Call initialize() first before using getInstance().", msg); + } + + @Test + public void testInitializeWithFeatureFlagDisabled() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(false); + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer instance = MLAgentTracer.getInstance(); + assertNotNull(instance); + assertTrue(instance.getTracer() instanceof NoopTracer); + } + + @Test + public void testInitializeWithFeatureFlagEnabledAndDynamicEnabled() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer instance = MLAgentTracer.getInstance(); + assertNotNull(instance); + assertEquals(mockTracer, instance.getTracer()); + } + + @Test + public void testInitializeWithFeatureFlagEnabledAndDynamicDisabled() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(false); + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer instance = MLAgentTracer.getInstance(); + assertNotNull(instance); + assertTrue(instance.getTracer() instanceof NoopTracer); + } + + @Test + public void testStartSpanWorksWithNullTracer() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + MLAgentTracer.initialize(null, mockFeatureSetting); + MLAgentTracer instance = MLAgentTracer.getInstance(); + assertNotNull(instance); + assertTrue(instance.getTracer() instanceof NoopTracer); + // 
Should not throw exception when using NoopTracer + instance.startSpan("test", null, null); + } + + @Test + public void testEndSpanThrowsExceptionIfSpanIsNull() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer instance = MLAgentTracer.getInstance(); + assertThrows(IllegalArgumentException.class, () -> instance.endSpan(null)); + } + + @Test + public void testGetTracerReturnsTracer() { + when(mockFeatureSetting.isTracingEnabled()).thenReturn(true); + when(mockFeatureSetting.isAgentTracingEnabled()).thenReturn(true); + MLAgentTracer.initialize(mockTracer, mockFeatureSetting); + MLAgentTracer instance = MLAgentTracer.getInstance(); + assertEquals(mockTracer, instance.getTracer()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e1d7e78d2b..c97f79f78b 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -218,6 +218,7 @@ import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor; +import org.opensearch.ml.engine.algorithms.agent.tracing.MLAgentTracer; import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizerImpl; import org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation; import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; @@ -805,6 +806,8 @@ public Collection createComponents( mcpToolsHelper = new McpToolsHelper(client, threadPool, toolFactoryWrapper); McpAsyncServerHolder.init(mlIndicesHandler, mcpToolsHelper); + MLAgentTracer.initialize(tracer, mlFeatureEnabledSetting, clusterService); + return 
ImmutableList .of( encryptor, @@ -1156,7 +1159,9 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED, MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED, MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED, - MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED + MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, + MLCommonsSettings.ML_COMMONS_TRACING_ENABLED, + MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED ); return settings; } diff --git a/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java b/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java index cf9cf9c956..75ae9445e3 100644 --- a/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java +++ b/plugin/src/test/java/org/opensearch/ml/settings/MLFeatureEnabledSettingTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; @@ -23,6 +24,7 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_TRACING_ENABLED; import java.util.Set; @@ -70,7 +72,9 @@ public void setUp() { 
ML_COMMONS_MCP_SERVER_ENABLED, ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, ML_COMMONS_METRIC_COLLECTION_ENABLED, - ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED + ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, + ML_COMMONS_TRACING_ENABLED, + ML_COMMONS_AGENT_TRACING_ENABLED ) ) ); @@ -111,4 +115,24 @@ public void testMetricCollectionSettings() { assertFalse(mlFeatureEnabledSetting.isMetricCollectionEnabled()); assertTrue(mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()); } + + @Test + public void testAgentTracingSettings() { + // Test initial values (not set, should be false) + assertFalse(mlFeatureEnabledSetting.isTracingEnabled()); + assertFalse(mlFeatureEnabledSetting.isAgentTracingEnabled()); + + // Simulate settings change: enable both + Settings newSettings = Settings + .builder() + .put(ML_COMMONS_TRACING_ENABLED.getKey(), true) + .put(ML_COMMONS_AGENT_TRACING_ENABLED.getKey(), true) + .build(); + when(clusterService.getSettings()).thenReturn(newSettings); + mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, newSettings); + + // Verify updated values + assertTrue(mlFeatureEnabledSetting.isTracingEnabled()); + assertTrue(mlFeatureEnabledSetting.isAgentTracingEnabled()); + } }