diff --git a/components/backends/sglang/src/dynamo/sglang/register.py b/components/backends/sglang/src/dynamo/sglang/register.py index 3b397ae5b1..d76dd7ddb4 100644 --- a/components/backends/sglang/src/dynamo/sglang/register.py +++ b/components/backends/sglang/src/dynamo/sglang/register.py @@ -23,7 +23,7 @@ async def register_llm_with_runtime_config( Returns: bool: True if registration succeeded, False if it failed """ - runtime_config = await _get_runtime_config(engine, dynamo_args) + runtime_config = await _get_runtime_config(engine, server_args, dynamo_args) input_type = ModelInput.Tokens output_type = ModelType.Chat | ModelType.Completions if not server_args.skip_tokenizer_init: @@ -51,13 +51,25 @@ async def register_llm_with_runtime_config( async def _get_runtime_config( - engine: sgl.Engine, dynamo_args: DynamoArgs + engine: sgl.Engine, server_args: ServerArgs, dynamo_args: DynamoArgs ) -> Optional[ModelRuntimeConfig]: """Get runtime config from SGLang engine""" runtime_config = ModelRuntimeConfig() # set reasoning parser and tool call parser runtime_config.reasoning_parser = dynamo_args.reasoning_parser runtime_config.tool_call_parser = dynamo_args.tool_call_parser + + # In SGLang, these are server_args, not scheduler_info (unlike vLLM) + # Note: If --max-running-requests is not specified, SGLang uses an internal default + # undocumented value. The value here will be None if not explicitly set by user. + max_running_requests = getattr(server_args, "max_running_requests", None) + if max_running_requests: + runtime_config.max_num_seqs = max_running_requests + + max_prefill_tokens = getattr(server_args, "max_prefill_tokens", None) + if max_prefill_tokens: + runtime_config.max_num_batched_tokens = max_prefill_tokens + try: # Try to check if the engine has a scheduler attribute with the computed values if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: @@ -77,7 +89,10 @@ async def _get_runtime_config( f"(max_total_tokens={max_total_tokens}, page_size={page_size})" ) - # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info + # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info. + # SGLang separates configuration (server_args) from runtime stats (scheduler_info). + # In contrast, vLLM exposes both config and runtime values through engine config. + # These are config parameters, so they must be retrieved from server_args only. return runtime_config diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py index 6fcf424cc6..93999101ea 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/main.py +++ b/components/backends/trtllm/src/dynamo/trtllm/main.py @@ -281,9 +281,29 @@ async def init(runtime: DistributedRuntime, config: Config): # TODO: fix this once we have a better way to get total_kv_blocks runtime_config = ModelRuntimeConfig() + # Set values from config that are available immediately + # Note: We populate max_num_seqs and max_num_batched_tokens from config + # to ensure Prometheus metrics are available even without engine stats + + # Naming clarification: + # - In vLLM: max_num_seqs = maximum concurrent requests (this is an unusual name due to vLLM's historic reasons) + # - In TensorRT-LLM: max_batch_size = maximum concurrent requests (clearer name) + # Both parameters control the same thing: how many requests can be processed simultaneously + runtime_config.max_num_seqs = config.max_batch_size + runtime_config.max_num_batched_tokens = config.max_num_tokens runtime_config.reasoning_parser = config.reasoning_parser runtime_config.tool_call_parser = config.tool_call_parser + logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}") + logging.info( + f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}" + ) + + # The get_engine_runtime_config function exists but is not called here due to: + # 1. get_stats_async requires active requests to work properly + # 2. We need runtime config during registration, before any requests are made + # 3. total_kv_blocks would ideally come from engine stats but is not critical for basic operation + # publisher will be set later if publishing is enabled. handler_config = RequestHandlerConfig( component=component, diff --git a/deploy/metrics/README.md b/deploy/metrics/README.md index aa12ad9333..306d5e3bc9 100644 --- a/deploy/metrics/README.md +++ b/deploy/metrics/README.md @@ -79,7 +79,30 @@ When using Dynamo HTTP Frontend (`--framework VLLM` or `--framework TRTLLM`), th - `dynamo_frontend_requests_total`: Total LLM requests (counter) - `dynamo_frontend_time_to_first_token_seconds`: Time to first token (histogram) -**Note**: The `dynamo_frontend_inflight_requests_total` metric tracks requests from HTTP handler start until the complete response is finished, while `dynamo_frontend_queued_requests_total` tracks requests from HTTP handler start until first token generation begins (including prefill time). HTTP queue time is a subset of inflight time. +##### Model Configuration Metrics + +The frontend also exposes model configuration metrics with the `dynamo_frontend_model_*` prefix. These metrics are populated from the worker backend registration service when workers register with the system: + +**Runtime Config Metrics (from ModelRuntimeConfig):** +These metrics come from the runtime configuration provided by worker backends during registration. + +- `dynamo_frontend_model_total_kv_blocks`: Total KV blocks available for a worker serving the model (gauge) +- `dynamo_frontend_model_max_num_seqs`: Maximum number of sequences for a worker serving the model (gauge) +- `dynamo_frontend_model_max_num_batched_tokens`: Maximum number of batched tokens for a worker serving the model (gauge) + +**MDC Metrics (from ModelDeploymentCard):** +These metrics come from the Model Deployment Card information provided by worker backends during registration. + +- `dynamo_frontend_model_context_length`: Maximum context length for a worker serving the model (gauge) +- `dynamo_frontend_model_kv_cache_block_size`: KV cache block size for a worker serving the model (gauge) +- `dynamo_frontend_model_migration_limit`: Request migration limit for a worker serving the model (gauge) + +**Worker Management Metrics:** +- `dynamo_frontend_model_workers`: Number of worker instances currently serving the model (gauge) + +**Important Notes:** +- The `dynamo_frontend_inflight_requests_total` metric tracks requests from HTTP handler start until the complete response is finished, while `dynamo_frontend_queued_requests_total` tracks requests from HTTP handler start until first token generation begins (including prefill time). HTTP queue time is a subset of inflight time. +- **Model Name Deduplication**: When multiple worker instances register with the same model name, only the first instance's configuration metrics (runtime config and MDC metrics) will be populated. Subsequent instances with duplicate model names will be skipped for configuration metric updates, though the worker count metric will reflect all instances. #### Request Processing Flow diff --git a/lib/llm/src/http/service/metrics.rs b/lib/llm/src/http/service/metrics.rs index 4e320630a5..95c01e15bd 100644 --- a/lib/llm/src/http/service/metrics.rs +++ b/lib/llm/src/http/service/metrics.rs @@ -18,6 +18,13 @@ use std::{ time::{Duration, Instant}, }; +use crate::discovery::ModelEntry; +use crate::local_model::runtime_config::ModelRuntimeConfig; +use crate::model_card::{ModelDeploymentCard, ROOT_PATH as MDC_ROOT_PATH}; +use dynamo_runtime::metrics::prometheus_names::clamp_u64_to_i64; +use dynamo_runtime::slug::Slug; +use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}; + pub use prometheus::Registry; use super::RouteDoc; @@ -32,6 +39,16 @@ pub struct Metrics { output_sequence_length: HistogramVec, time_to_first_token: HistogramVec, inter_token_latency: HistogramVec, + + // Runtime configuration metrics. Note: Some of these metrics represent counter-like values from + // source systems, but are implemented as gauges because they are copied/synchronized from upstream + // counter values rather than being directly incremented. + model_total_kv_blocks: IntGaugeVec, + model_max_num_seqs: IntGaugeVec, + model_max_num_batched_tokens: IntGaugeVec, + model_context_length: IntGaugeVec, + model_kv_cache_block_size: IntGaugeVec, + model_migration_limit: IntGaugeVec, } // Inflight tracks requests from HTTP handler start until complete response is finished. @@ -126,6 +143,26 @@ impl Metrics { /// - `{prefix}_output_sequence_tokens` - HistogramVec for output sequence length in tokens /// - `{prefix}_time_to_first_token_seconds` - HistogramVec for time to first token in seconds /// - `{prefix}_inter_token_latency_seconds` - HistogramVec for inter-token latency in seconds + /// + /// ## Model Configuration Metrics + /// + /// Runtime config metrics (from ModelRuntimeConfig): + /// - `{prefix}_model_total_kv_blocks` - IntGaugeVec for total KV cache blocks available for a worker serving the model + /// - `{prefix}_model_max_num_seqs` - IntGaugeVec for maximum sequences for a worker serving the model + /// - `{prefix}_model_max_num_batched_tokens` - IntGaugeVec for maximum batched tokens for a worker serving the model + /// + /// MDC metrics (from ModelDeploymentCard): + /// - `{prefix}_model_context_length` - IntGaugeVec for maximum context length for a worker serving the model + /// - `{prefix}_model_kv_cache_block_size` - IntGaugeVec for KV cache block size for a worker serving the model + /// - `{prefix}_model_migration_limit` - IntGaugeVec for request migration limit for a worker serving the model + /// + /// ## Runtime Config Polling Configuration + /// + /// The polling behavior can be configured via environment variables: + /// - `DYN_HTTP_SVC_CONFIG_METRICS_POLL_INTERVAL_SECS`: Poll interval in seconds (must be > 0, supports fractional seconds, defaults to 8) + /// + /// Metrics are never removed to preserve historical data. Runtime config and MDC + /// metrics are updated when models are discovered and their configurations are available. pub fn new() -> Self { let raw_prefix = std::env::var(frontend_service::METRICS_PREFIX_ENV) .unwrap_or_else(|_| name_prefix::FRONTEND.to_string()); @@ -235,6 +272,64 @@ impl Metrics { ) .unwrap(); + // Runtime configuration metrics + // Note: Some of these metrics represent counter-like values from source systems, + // but are implemented as gauges because they are copied/synchronized from upstream + // counter values rather than being directly incremented. + let model_total_kv_blocks = IntGaugeVec::new( + Opts::new( + frontend_metric_name(frontend_service::MODEL_TOTAL_KV_BLOCKS), + "Total KV cache blocks available for a worker serving the model", + ), + &["model"], + ) + .unwrap(); + + let model_max_num_seqs = IntGaugeVec::new( + Opts::new( + frontend_metric_name(frontend_service::MODEL_MAX_NUM_SEQS), + "Maximum number of sequences for a worker serving the model", + ), + &["model"], + ) + .unwrap(); + + let model_max_num_batched_tokens = IntGaugeVec::new( + Opts::new( + frontend_metric_name(frontend_service::MODEL_MAX_NUM_BATCHED_TOKENS), + "Maximum number of batched tokens for a worker serving the model", + ), + &["model"], + ) + .unwrap(); + + let model_context_length = IntGaugeVec::new( + Opts::new( + frontend_metric_name(frontend_service::MODEL_CONTEXT_LENGTH), + "Maximum context length in tokens for a worker serving the model", + ), + &["model"], + ) + .unwrap(); + + let model_kv_cache_block_size = IntGaugeVec::new( + Opts::new( + frontend_metric_name(frontend_service::MODEL_KV_CACHE_BLOCK_SIZE), + "KV cache block size in tokens for a worker serving the model", + ), + &["model"], + ) + .unwrap(); + + let model_migration_limit = IntGaugeVec::new( + Opts::new( + frontend_metric_name(frontend_service::MODEL_MIGRATION_LIMIT), + "Maximum number of request migrations allowed for the model", + ), + &["model"], + ) + .unwrap(); + Metrics { request_counter, inflight_gauge, @@ -245,6 +340,12 @@ impl Metrics { output_sequence_length, time_to_first_token, inter_token_latency, + model_total_kv_blocks, + model_max_num_seqs, + model_max_num_batched_tokens, + model_context_length, + model_kv_cache_block_size, + model_migration_limit, } } @@ -333,9 +434,230 @@ impl Metrics { registry.register(Box::new(self.output_sequence_length.clone()))?; registry.register(Box::new(self.time_to_first_token.clone()))?; registry.register(Box::new(self.inter_token_latency.clone()))?; + + // Register runtime configuration metrics + registry.register(Box::new(self.model_total_kv_blocks.clone()))?; + registry.register(Box::new(self.model_max_num_seqs.clone()))?; + registry.register(Box::new(self.model_max_num_batched_tokens.clone()))?; + registry.register(Box::new(self.model_context_length.clone()))?; + registry.register(Box::new(self.model_kv_cache_block_size.clone()))?; + registry.register(Box::new(self.model_migration_limit.clone()))?; + + Ok(()) + } + + /// Update runtime configuration metrics for a model + /// This should be called when model runtime configuration is available or updated + pub fn update_runtime_config_metrics( + &self, + model_name: &str, + runtime_config: &ModelRuntimeConfig, + ) { + if let Some(total_kv_blocks) = runtime_config.total_kv_blocks { + self.model_total_kv_blocks + .with_label_values(&[model_name]) + .set(clamp_u64_to_i64(total_kv_blocks)); + } + + if let Some(max_num_seqs) = runtime_config.max_num_seqs { + self.model_max_num_seqs + .with_label_values(&[model_name]) + .set(clamp_u64_to_i64(max_num_seqs)); + } + + if let Some(max_batched_tokens) = runtime_config.max_num_batched_tokens { + self.model_max_num_batched_tokens + .with_label_values(&[model_name]) + .set(clamp_u64_to_i64(max_batched_tokens)); + } + } + + /// Update model deployment card metrics for a model + /// This should be called when model deployment card information is available + pub fn update_mdc_metrics( + &self, + model_name: &str, + context_length: u32, + kv_cache_block_size: u32, + migration_limit: u32, + ) { + self.model_context_length + .with_label_values(&[model_name]) + .set(context_length as i64); + + self.model_kv_cache_block_size + .with_label_values(&[model_name]) + .set(kv_cache_block_size as i64); + + self.model_migration_limit + .with_label_values(&[model_name]) + .set(migration_limit as i64); + } + + /// Update metrics from a ModelEntry + /// This is a convenience method that extracts runtime config from a ModelEntry + /// and updates the appropriate metrics + pub fn update_metrics_from_model_entry(&self, model_entry: &ModelEntry) { + if let Some(runtime_config) = &model_entry.runtime_config { + self.update_runtime_config_metrics(&model_entry.name, runtime_config); + } + } + + /// Update metrics from a ModelEntry and its ModelDeploymentCard + /// This updates both runtime config metrics and MDC-specific metrics + pub async fn update_metrics_from_model_entry_with_mdc( + &self, + model_entry: &ModelEntry, + etcd_client: &dynamo_runtime::transports::etcd::Client, + ) -> anyhow::Result<()> { + // Update runtime config metrics + if let Some(runtime_config) = &model_entry.runtime_config { + self.update_runtime_config_metrics(&model_entry.name, runtime_config); + } + + // Load and update MDC metrics + let model_slug = Slug::from_string(&model_entry.name); + let store: Box = Box::new(EtcdStorage::new(etcd_client.clone())); + let card_store = Arc::new(KeyValueStoreManager::new(store)); + + match card_store + .load::(MDC_ROOT_PATH, &model_slug) + .await + { + Ok(Some(mdc)) => { + self.update_mdc_metrics( + &model_entry.name, + mdc.context_length, + mdc.kv_cache_block_size, + mdc.migration_limit, + ); + tracing::debug!( + model = %model_entry.name, + "Successfully updated MDC metrics" + ); + } + Ok(None) => { + tracing::debug!( + model = %model_entry.name, + "No MDC found in storage, skipping MDC metrics" + ); + } + Err(e) => { + tracing::debug!( + model = %model_entry.name, + error = %e, + "Failed to load MDC for metrics update" + ); + } + } + Ok(()) } + /// Start a background task that periodically updates runtime config metrics + /// + /// ## Why Polling is Required + /// + /// Polling is necessary because new models may come online at any time through the distributed + /// discovery system. The ModelManager is continuously updated as workers register/deregister + /// with etcd, and we need to periodically check for these changes to expose their metrics. + /// + /// ## Behavior + /// + /// - Polls the ModelManager for current models and updates metrics accordingly + /// - Models are never removed from metrics to preserve historical data + /// - If multiple model instances have the same name, only the first instance's metrics are used + /// - Subsequent instances with duplicate names will be skipped + /// + /// ## MDC (Model Deployment Card) Behavior + /// + /// Currently, we don't overwrite an MDC. The first worker to start wins, and we assume + /// that all other workers claiming to serve that model really are using the same configuration. + /// Later, every worker will have its own MDC, and the frontend will validate that they + /// checksum the same. For right now, you can assume they have the same MDC, because + /// they aren't allowed to change it. + /// + /// The task will run until the provided cancellation token is cancelled. + pub fn start_runtime_config_polling_task( + metrics: Arc, + manager: Arc, + etcd_client: Option, + poll_interval: Duration, + cancel_token: tokio_util::sync::CancellationToken, + ) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut interval = tokio::time::interval(poll_interval); + let mut known_models = std::collections::HashSet::new(); + + tracing::info!( + interval_secs = poll_interval.as_secs(), + "Starting runtime config metrics polling task (metrics never removed)" + ); + + loop { + tokio::select! { + _ = cancel_token.cancelled() => { + tracing::info!("Runtime config metrics polling task cancelled"); + break; + } + _ = interval.tick() => { + // Continue with polling logic + } + } + + // Get current model entries from the manager + let current_entries = manager.get_model_entries(); + let mut current_models = std::collections::HashSet::new(); + + // Note: If multiple model instances have the same name, only the first instance's config metrics are recorded. + // Subsequent instances with duplicate names will be skipped for config updates. + // This is based on the assumption that all workers serving the same model have identical + // configuration values (MDC content, runtime config, etc.). This assumption holds because + // workers are not allowed to change their configuration after registration. + + // Update configuration metrics for current models + for entry in current_entries { + // Skip config processing if we've already seen this model name + if !current_models.insert(entry.name.clone()) { + tracing::debug!( + model_name = %entry.name, + endpoint = ?entry.endpoint_id, + "Skipping duplicate model instance - only first instance config metrics are recorded" + ); + continue; + } + + // Update runtime config metrics if available + if let Some(runtime_config) = &entry.runtime_config { + metrics.update_runtime_config_metrics(&entry.name, runtime_config); + } + + // Optionally load MDC for additional metrics if etcd is available + if let Some(ref etcd) = etcd_client + && let Err(e) = metrics + .update_metrics_from_model_entry_with_mdc(&entry, etcd) + .await + { + tracing::debug!( + model = %entry.name, + error = %e, + "Failed to update MDC metrics (this is normal if MDC is not available)" + ); + } + } + + // Update our known models set + known_models.extend(current_models.iter().cloned()); + + tracing::trace!( + active_models = current_models.len(), + total_known_models = known_models.len(), + "Updated runtime config metrics for active models" + ); + } + }) + } + /// Create a new [`InflightGuard`] for the given model and annotate if its a streaming request, /// and the kind of endpoint that was hit /// diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 99b8dad095..866a0fe704 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -133,6 +133,9 @@ pub struct HttpService { tls_cert_path: Option, tls_key_path: Option, route_docs: Vec, + + // Metrics polling configuration + etcd_client: Option, } #[derive(Clone, Builder)] @@ -201,6 +204,22 @@ impl HttpService { let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" }; tracing::info!(protocol, address, "Starting HTTP(S) service"); + // Start background task to poll runtime config metrics with proper cancellation + let poll_interval_secs = std::env::var("DYN_HTTP_SVC_CONFIG_METRICS_POLL_INTERVAL_SECS") + .ok() + .and_then(|s| s.parse::().ok()) + .filter(|&secs| secs > 0.0) // Guard against zero or negative values + .unwrap_or(8.0); + let poll_interval = Duration::from_secs_f64(poll_interval_secs); + + let _polling_task = super::metrics::Metrics::start_runtime_config_polling_task( + self.state.metrics_clone(), + self.state.manager_clone(), + self.etcd_client.clone(), + poll_interval, + cancel_token.child_token(), + ); + let router = self.router.clone(); let observer = cancel_token.child_token(); @@ -294,6 +313,7 @@ impl HttpServiceConfigBuilder { let config: HttpServiceConfig = self.build_internal()?; let model_manager = Arc::new(ModelManager::new()); + let etcd_client = config.etcd_client.clone(); let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client)); state @@ -313,6 +333,8 @@ impl HttpServiceConfigBuilder { let registry = metrics::Registry::new(); state.metrics_clone().register(®istry)?; + // Note: Metrics polling task will be started in run() method to have access to cancellation token + let mut router = axum::Router::new(); let mut all_docs = Vec::new(); @@ -344,6 +366,7 @@ impl HttpServiceConfigBuilder { tls_cert_path: config.tls_cert_path, tls_key_path: config.tls_key_path, route_docs: all_docs, + etcd_client, }) } diff --git a/lib/llm/tests/http_metrics.rs b/lib/llm/tests/http_metrics.rs index 33cb4608fd..5d76ab53f4 100644 --- a/lib/llm/tests/http_metrics.rs +++ b/lib/llm/tests/http_metrics.rs @@ -1,16 +1,66 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use dynamo_llm::http::service::metrics::Endpoint; -use dynamo_llm::http::service::service_v2::HttpService; -use dynamo_runtime::CancellationToken; +use anyhow::Error; +use async_stream::stream; +use dynamo_llm::{ + http::service::metrics::Endpoint, + http::service::service_v2::HttpService, + protocols::{ + Annotated, + openai::chat_completions::{ + NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, + }, + }, +}; use dynamo_runtime::metrics::prometheus_names::frontend_service::METRICS_PREFIX_ENV; -use std::time::Duration; +use dynamo_runtime::{ + CancellationToken, + pipeline::{ + AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait, + }, +}; +use std::{sync::Arc, time::Duration}; #[path = "common/ports.rs"] mod ports; use ports::get_random_port; +// Mock engine for testing metrics +struct MockModelEngine {} + +#[async_trait] +impl + AsyncEngine< + SingleIn, + ManyOut>, + Error, + > for MockModelEngine +{ + async fn generate( + &self, + request: SingleIn, + ) -> Result>, Error> { + let (request, context) = request.transfer(()); + let ctx = context.context(); + + let mut generator = request.response_generator(ctx.id().to_string()); + + let stream = stream! { + // Simulate some processing time + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // Generate 5 response chunks + for i in 0..5 { + let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None, None); + yield Annotated::from_data(output); + } + }; + + Ok(ResponseStream::new(Box::pin(stream), ctx)) + } +} + #[tokio::test] async fn test_metrics_prefix_default() { // Test default prefix when no env var is set @@ -37,7 +87,12 @@ async fn test_metrics_prefix_default() { .text() .await .unwrap(); + + // Assert metrics that are actually present in the default configuration assert!(body.contains("dynamo_frontend_requests_total")); + assert!(body.contains("dynamo_frontend_inflight_requests_total")); + assert!(body.contains("dynamo_frontend_request_duration_seconds")); + assert!(body.contains("dynamo_frontend_client_disconnects")); token.cancel(); let _ = handle.await; @@ -129,3 +184,387 @@ async fn wait_for_metrics_ready(port: u16) { } } } + +#[tokio::test] +async fn test_metrics_with_mock_model() { + // Test metrics collection with a mock model serving requests + // Ensure we use the default prefix + temp_env::async_with_vars([(METRICS_PREFIX_ENV, None::<&str>)], async { + let port = get_random_port().await; + let service = HttpService::builder() + .port(port) + .enable_chat_endpoints(true) + .build() + .unwrap(); + + let state = service.state_clone(); + let manager = state.manager(); + + // Start the HTTP service + let token = CancellationToken::new(); + let cancel_token = token.clone(); + let task = tokio::spawn(async move { service.run(token.clone()).await }); + + // Add mock model engine + let mock_engine = Arc::new(MockModelEngine {}); + manager + .add_chat_completions_model("mockmodel", mock_engine) + .unwrap(); + + // Wait for service to be ready + wait_for_metrics_ready(port).await; + + let client = reqwest::Client::new(); + + // Create a chat completion request + let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User( + dynamo_async_openai::types::ChatCompletionRequestUserMessage { + content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "Hello, mock model!".to_string(), + ), + name: None, + }, + ); + + let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default() + .model("mockmodel") + .messages(vec![message]) + .max_tokens(50u32) + .stream(true) + .build() + .expect("Failed to build request"); + + // Make the request to the HTTP service + let response = client + .post(format!("http://localhost:{}/v1/chat/completions", port)) + .json(&request) + .send() + .await + .unwrap(); + + assert!( + response.status().is_success(), + "Request failed: {:?}", + response + ); + + // Consume the response stream to complete the request + let _response_body = response.bytes().await.unwrap(); + + // Give some time for metrics to be updated + tokio::time::sleep(Duration::from_millis(100)).await; + + // Fetch and verify metrics + let metrics_response = client + .get(format!("http://localhost:{}/metrics", port)) + .send() + .await + .unwrap(); + + assert!(metrics_response.status().is_success()); + let metrics_body = metrics_response.text().await.unwrap(); + + println!("=== METRICS WITH MOCK MODEL ==="); + println!("{}", metrics_body); + println!("=== END METRICS ==="); + + // Assert that key metrics are present with the mockmodel + assert!(metrics_body.contains("dynamo_frontend_requests_total")); + assert!(metrics_body.contains("model=\"mockmodel\"")); + assert!(metrics_body.contains("dynamo_frontend_inflight_requests_total")); + assert!(metrics_body.contains("dynamo_frontend_request_duration_seconds")); + assert!(metrics_body.contains("dynamo_frontend_output_sequence_tokens")); + assert!(metrics_body.contains("dynamo_frontend_queued_requests_total")); + + // Verify specific request counter incremented + assert!(metrics_body.contains("endpoint=\"chat_completions\"")); + assert!(metrics_body.contains("request_type=\"stream\"")); + assert!(metrics_body.contains("status=\"success\"")); + + // Clean up + cancel_token.cancel(); + task.await.unwrap().unwrap(); + }) + .await; +} + +// Integration tests that require distributed runtime with etcd +#[cfg(feature = "integration")] +mod integration_tests { + use super::*; + use dynamo_llm::{ + discovery::ModelEntry, engines::make_echo_engine, entrypoint::EngineConfig, + local_model::LocalModelBuilder, + }; + use dynamo_runtime::DistributedRuntime; + + #[tokio::test] + #[ignore = "Requires etcd and distributed runtime"] + async fn test_metrics_with_mdc_registration() { + // Integration test for metrics collection with full MDC registration (like real model servers) + temp_env::async_with_vars([ + (METRICS_PREFIX_ENV, None::<&str>), + ("DYN_HTTP_SVC_CONFIG_METRICS_POLL_INTERVAL_SECS", Some("0.6")), // Fast polling for tests (600ms) + ], async { + let port = get_random_port().await; + + // Create distributed runtime (required for MDC registration) + let runtime = dynamo_runtime::Runtime::from_settings().unwrap(); + let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()) + .await + .unwrap(); + + // Create LocalModel with realistic configuration for testing + let mut local_model = LocalModelBuilder::default() + .model_name(Some("test-mdc-model".to_string())) + .build() + .await + .unwrap(); + + // Create EngineConfig with EchoEngine + let engine_config = EngineConfig::StaticFull { + engine: make_echo_engine(), + model: Box::new(local_model.clone()), + is_static: false, // This enables MDC registration! + }; + + let service = HttpService::builder() + .port(port) + .enable_chat_endpoints(true) + .with_etcd_client(distributed_runtime.etcd_client()) + .build() + .unwrap(); + + // Set up model watcher to discover models from etcd (like production) + // This is crucial for the polling task to find model entries + use dynamo_llm::discovery::{ModelWatcher, MODEL_ROOT_PATH}; + use dynamo_runtime::pipeline::RouterMode; + + let model_watcher = ModelWatcher::new( + distributed_runtime.clone(), + service.state().manager_clone(), + RouterMode::RoundRobin, + None, + None, + ); + + // Start watching etcd for model registrations + if let Some(etcd_client) = distributed_runtime.etcd_client() { + let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await.unwrap(); + let (_prefix, _watcher, receiver) = models_watcher.dissolve(); + + // Spawn watcher task to discover models from etcd + let _watcher_task = tokio::spawn(async move { + model_watcher.watch(receiver, None).await; + }); + + } + + // Set up the engine following the StaticFull pattern from http.rs + let EngineConfig::StaticFull { engine, model, .. } = engine_config else { + panic!("Expected StaticFull config"); + }; + + let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine)); + let manager = service.model_manager(); + manager + .add_chat_completions_model(model.service_name(), engine.clone()) + .unwrap(); + + // Now do the proper MDC registration via LocalModel::attach() + // Create a component and endpoint for proper registration + let namespace = distributed_runtime.namespace("test-namespace").unwrap(); + let test_component = namespace.component("test-mdc-component").unwrap(); + let test_endpoint = test_component.endpoint("test-mdc-endpoint"); + + // This will store the MDC in etcd and create the ModelEntry for discovery + local_model + .attach( + &test_endpoint, + dynamo_llm::model_type::ModelType::Chat, + dynamo_llm::model_type::ModelInput::Text, + ) + .await + .unwrap(); + + + // Start the HTTP service + let token = CancellationToken::new(); + let cancel_token = token.clone(); + let service_for_task = service.clone(); + let task = tokio::spawn(async move { service_for_task.run(token.clone()).await }); + + // Wait for service to be ready + wait_for_metrics_ready(port).await; + + // Wait for MDC registration to complete by checking if the model appears + // This simulates the real polling that happens in production + let start = tokio::time::Instant::now(); + let timeout = Duration::from_secs(10); + loop { + if start.elapsed() > timeout { + break; // Continue with test even if MDC metrics aren't ready + } + + // Check if our model is registered in the manager (indicates MDC registration completed) + let model_service_name = model.service_name(); + if manager.has_model_any(model_service_name) { + tracing::info!("MDC registration completed for {}", model_service_name); + break; + } + + tokio::time::sleep(Duration::from_millis(100)).await; + } + + // Give a bit more time for background metrics collection + tokio::time::sleep(Duration::from_millis(200)).await; + + let client = reqwest::Client::new(); + + // Create a chat completion request + let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User( + dynamo_async_openai::types::ChatCompletionRequestUserMessage { + content: + dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "Hello, MDC model!".to_string(), + ), + name: None, + }, + ); + + let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default() + .model(model.service_name()) + .messages(vec![message]) + .max_tokens(50u32) + .stream(true) + .build() + .expect("Failed to build request"); + + // Make the request to the HTTP service + let response = client + .post(format!("http://localhost:{}/v1/chat/completions", port)) + .json(&request) + .send() + .await + .unwrap(); + + assert!( + response.status().is_success(), + "Request failed: {:?}", + response + ); + + // Consume the response stream to complete the request + let _response_body = response.bytes().await.unwrap(); + + // Wait for the fast polling interval (600ms) for MDC metrics + tokio::time::sleep(Duration::from_millis(5)).await; + + // Fetch and verify metrics + let metrics_response = client + .get(format!("http://localhost:{}/metrics", port)) + .send() + .await + .unwrap(); + + assert!(metrics_response.status().is_success()); + let metrics_body = metrics_response.text().await.unwrap(); + + println!("=== METRICS WITH FULL MDC REGISTRATION ==="); + println!("{}", metrics_body); + println!("=== END METRICS ==="); + + // Assert basic metrics are present (using service_name from the model) + let model_name = model.service_name(); + assert!(metrics_body.contains("dynamo_frontend_requests_total")); + assert!(metrics_body.contains(&format!("model=\"{}\"", model_name))); + assert!(metrics_body.contains("dynamo_frontend_inflight_requests_total")); + assert!(metrics_body.contains("dynamo_frontend_request_duration_seconds")); + assert!(metrics_body.contains("dynamo_frontend_output_sequence_tokens")); + assert!(metrics_body.contains("dynamo_frontend_queued_requests_total")); + + // Assert MDC-based model configuration metrics are present + // These MUST be present for the test to pass + assert!(metrics_body.contains("dynamo_frontend_model_context_length"), + "MDC metrics not found! Metrics body: {}", metrics_body); + + assert!(metrics_body.contains("dynamo_frontend_model_kv_cache_block_size")); + assert!(metrics_body.contains("dynamo_frontend_model_migration_limit")); + + // Note: The following metrics are not present in this test because they require + // actual inference engines (vllm/sglang/trtllm *.py) with real runtime configurations: + // - dynamo_frontend_model_total_kv_blocks (requires actual KV cache from real engines) + // - dynamo_frontend_model_max_num_seqs (requires actual batching config from real engines) + // - dynamo_frontend_model_max_num_batched_tokens (requires actual batching config from real engines) + + + // Verify specific request counter incremented + assert!(metrics_body.contains("endpoint=\"chat_completions\"")); + assert!(metrics_body.contains("request_type=\"stream\"")); + assert!(metrics_body.contains("status=\"success\"")); + + // Now test the complete lifecycle: remove the model from etcd + + // Get all model entries to find the one we need to delete + if let Some(etcd_client) = distributed_runtime.etcd_client() { + let kvs = etcd_client.kv_get_prefix("models").await.unwrap(); + + // Find our model's etcd key + let mut model_key_to_delete = None; + for kv in kvs { + if let Ok(model_entry) = serde_json::from_slice::(kv.value()) + && model_entry.name == "test-mdc-model" + { + model_key_to_delete = Some(kv.key_str().unwrap().to_string()); + break; + } + } + + if let Some(key) = model_key_to_delete { + etcd_client.kv_delete(key.as_str(), None).await.unwrap(); + + // Poll every 80ms for up to 2 seconds to check when worker count drops to 0 + + let start_time = tokio::time::Instant::now(); + let timeout = Duration::from_millis(2000); + let mut worker_count_dropped = false; + + while start_time.elapsed() < timeout { + // Check if the model was removed from the manager + let has_model = manager.has_model_any(model.service_name()); + + // Fetch current metrics + let metrics_response = client + .get(format!("http://localhost:{}/metrics", port)) + .send() + .await + .unwrap(); + + if metrics_response.status().is_success() { + + // Since model_workers metric was removed, just check if model is gone from manager + if !has_model { + worker_count_dropped = true; + break; + } + } + + tokio::time::sleep(Duration::from_millis(80)).await; + } + + // Assert that model was removed from manager + assert!(worker_count_dropped, + "Model should be removed from manager after etcd removal and polling cycles"); + + } else { + } + } + + + // Clean up + cancel_token.cancel(); + task.await.unwrap().unwrap(); + }) + .await; + } +} diff --git a/lib/runtime/src/metrics/prometheus_names.rs b/lib/runtime/src/metrics/prometheus_names.rs index bec6358451..0a382e2b1b 100644 --- a/lib/runtime/src/metrics/prometheus_names.rs +++ b/lib/runtime/src/metrics/prometheus_names.rs @@ -94,6 +94,28 @@ pub mod frontend_service { /// Inter-token latency in seconds pub const INTER_TOKEN_LATENCY_SECONDS: &str = "inter_token_latency_seconds"; + /// Model configuration metrics + /// + /// Runtime config metrics (from ModelRuntimeConfig): + /// Total KV blocks available for a worker serving the model + pub const MODEL_TOTAL_KV_BLOCKS: &str = "model_total_kv_blocks"; + + /// Maximum number of sequences for a worker serving the model (runtime config) + pub const MODEL_MAX_NUM_SEQS: &str = "model_max_num_seqs"; + + /// Maximum number of batched tokens for a worker serving the model (runtime config) + pub const MODEL_MAX_NUM_BATCHED_TOKENS: &str = "model_max_num_batched_tokens"; + + /// MDC metrics (from ModelDeploymentCard): + /// Maximum context length for a worker serving the model (MDC) + pub const MODEL_CONTEXT_LENGTH: &str = "model_context_length"; + + /// KV cache block size for a worker serving the model (MDC) + pub const MODEL_KV_CACHE_BLOCK_SIZE: &str = "model_kv_cache_block_size"; + + /// Request migration limit for a worker serving the model (MDC) + pub const MODEL_MIGRATION_LIMIT: &str = "model_migration_limit"; + /// Status label values pub mod status { /// Value for successful requests @@ -421,6 +443,33 @@ pub fn build_component_metric_name(metric_name: &str) -> String { format!("{}_{}", name_prefix::COMPONENT, sanitized_name) } +/// Safely converts a u64 value to i64 for Prometheus metrics +/// +/// Since Prometheus IntGaugeVec uses i64 but our data types use u64, +/// this function clamps large u64 values to i64::MAX to prevent overflow +/// and ensure metrics remain positive. +/// +/// # Arguments +/// * `value` - The u64 value to convert +/// +/// # Returns +/// An i64 value, clamped to i64::MAX if the input exceeds i64::MAX +/// +/// # Examples +/// ``` +/// use dynamo_runtime::metrics::prometheus_names::clamp_u64_to_i64; +/// +/// assert_eq!(clamp_u64_to_i64(100), 100); +/// assert_eq!(clamp_u64_to_i64(u64::MAX), i64::MAX); +/// ``` +pub fn clamp_u64_to_i64(value: u64) -> i64 { + if value > i64::MAX as u64 { + i64::MAX + } else { + value as i64 + } +} + #[cfg(test)] mod tests { use super::*; @@ -645,4 +694,20 @@ mod tests { // Test that empty input panics with clear message build_component_metric_name(""); } + + #[test] + fn test_clamp_u64_to_i64() { + // Test normal values within i64 range + assert_eq!(clamp_u64_to_i64(0), 0); + assert_eq!(clamp_u64_to_i64(100), 100); + assert_eq!(clamp_u64_to_i64(1000000), 1000000); + + // Test maximum i64 value + assert_eq!(clamp_u64_to_i64(i64::MAX as u64), i64::MAX); + + // Test values that exceed i64::MAX + assert_eq!(clamp_u64_to_i64(u64::MAX), i64::MAX); + assert_eq!(clamp_u64_to_i64((i64::MAX as u64) + 1), i64::MAX); + assert_eq!(clamp_u64_to_i64((i64::MAX as u64) + 1000), i64::MAX); + } }