Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lib/bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,8 @@ pub async fn create_worker_selection_pipeline_chat(
let component = distributed_runtime
.namespace(namespace)?
.component(component_name)?;
let client = component.endpoint(GENERATE_ENDPOINT).client().await?;
let endpoint = component.endpoint(GENERATE_ENDPOINT);
let client = endpoint.client().await?;

// Discover the model card by searching all instances with this model name
tracing::debug!("Looking for model: {}", model_name);
Expand Down Expand Up @@ -980,7 +981,7 @@ pub async fn create_worker_selection_pipeline_chat(
let chooser = if router_mode == RouterMode::KV {
Some(
model_manager
.kv_chooser_for(&component, card.kv_cache_block_size, kv_router_config)
.kv_chooser_for(&endpoint, card.kv_cache_block_size, kv_router_config)
.await?,
)
} else {
Expand Down
5 changes: 1 addition & 4 deletions lib/bindings/python/rust/llm/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -998,13 +998,10 @@ async fn create_kv_router_from_endpoint(
block_size: usize,
kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
// Get component from endpoint
let component = endpoint.inner.component();

// Create ModelManager and use it to create KvRouter (ensures registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager
.kv_chooser_for(component, block_size as u32, kv_router_config)
.kv_chooser_for(&endpoint.inner, block_size as u32, kv_router_config)
.await
.map_err(to_pyerr)?;

Expand Down
25 changes: 12 additions & 13 deletions lib/llm/src/discovery/model_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;

use dynamo_runtime::prelude::DistributedRuntimeProvider;
use dynamo_runtime::{
component::{Component, Endpoint},
storage::key_value_store::Key,
};
use dynamo_runtime::{component::Endpoint, storage::key_value_store::Key};

use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
Expand Down Expand Up @@ -292,40 +289,42 @@ impl ModelManager {

pub async fn kv_chooser_for(
&self,
component: &Component,
endpoint: &Endpoint,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
let service_name = component.service_name();
let endpoint_path = endpoint.path();

if let Some(kv_chooser) = self.get_kv_chooser(&service_name) {
if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_path) {
// Check if the existing router has a different block size
if kv_chooser.block_size() != kv_cache_block_size {
tracing::warn!(
component = %service_name,
endpoint = %endpoint_path,
existing_block_size = %kv_chooser.block_size(),
requested_block_size = %kv_cache_block_size,
"KV Router block size mismatch! Component is requesting a different kv_cache_block_size than the existing router. \
"KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
This will cause routing to fail silently. Consider using the same block size or restarting the router."
);
}
return Ok(kv_chooser);
}

let store = component.drt().store();
let client = endpoint.client().await?;
let store = endpoint.component().drt().store();
let router_bucket = store
.get_or_create_bucket(KV_ROUTERS_ROOT_PATH, None)
.await?;
let router_uuid = uuid::Uuid::new_v4();
let router_key = Key::from_raw(format!("{}/{router_uuid}", component.path()));
let router_key = Key::from_raw(format!("{}/{router_uuid}", endpoint.path()));
let json_router_config = serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?;
router_bucket
.insert(&router_key, json_router_config.into(), 0)
.await?;

let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(
component.clone(),
endpoint.clone(),
client,
kv_cache_block_size,
Some(selector),
kv_router_config,
Expand All @@ -335,7 +334,7 @@ impl ModelManager {
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
.insert(service_name, new_kv_chooser.clone());
.insert(endpoint_path, new_kv_chooser.clone());
Ok(new_kv_chooser)
}

Expand Down
3 changes: 2 additions & 1 deletion lib/llm/src/discovery/watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,11 @@ impl ModelWatcher {
// A model that expects pre-processed requests, meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports.

let endpoint = component.endpoint(&endpoint_id.name);
let kv_chooser = if self.router_mode == RouterMode::KV {
Some(
self.manager
.kv_chooser_for(&component, card.kv_cache_block_size, self.kv_router_config)
.kv_chooser_for(&endpoint, card.kv_cache_block_size, self.kv_router_config)
.await?,
)
} else {
Expand Down
14 changes: 12 additions & 2 deletions lib/llm/src/entrypoint/input/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,27 @@ where
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card).into_operator();

// For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV {
let Some(ref chooser) = chooser else {
anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
};
chooser.client().clone()
} else {
client.clone()
};

// Create worker monitor only if busy_threshold is set
let worker_monitor = busy_threshold.map(|threshold| {
Arc::new(crate::discovery::KvWorkerMonitor::new(
Arc::new(client.clone()),
Arc::new(router_client.clone()),
threshold,
)) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
});

let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(),
router_client,
router_mode,
busy_threshold,
worker_monitor,
Expand Down
29 changes: 19 additions & 10 deletions lib/llm/src/kv_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::Component,
component::{Client, Endpoint},
discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
Expand Down Expand Up @@ -213,29 +213,32 @@ pub struct KvRouter {
kv_router_config: KvRouterConfig,

cancellation_token: tokio_util::sync::CancellationToken,

client: Client,
}

impl KvRouter {
pub async fn new(
component: Component,
endpoint: Endpoint,
client: Client,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?;

let instances_rx = client.instance_source.as_ref().clone();
let instance_ids_rx = client.instance_avail_watcher();

// Watch for runtime config updates via discovery interface
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
namespace: endpoint_id.namespace.clone(),
component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
Expand All @@ -249,7 +252,7 @@ impl KvRouter {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
block_size,
Expand All @@ -271,7 +274,7 @@ impl KvRouter {
let scheduler = KvScheduler::start(
component.clone(),
block_size,
instances_rx,
instance_ids_rx,
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
Expand Down Expand Up @@ -306,9 +309,15 @@ impl KvRouter {
block_size,
kv_router_config,
cancellation_token,
client,
})
}

/// Returns a shared reference to the [`Client`] held in this router's `client` field.
///
/// No clone is made here; callers that need an owned client clone the returned
/// reference themselves.
pub fn client(&self) -> &Client {
&self.client
}

/// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
Expand Down
14 changes: 9 additions & 5 deletions lib/llm/src/kv_router/prefill_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,16 @@ impl PrefillRouter {
"Activating prefill router"
);

let client = endpoint.client().await?;

let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the component from the endpoint
// Create KV chooser using the endpoint
let kv_chooser = model_manager
.kv_chooser_for(endpoint.component(), kv_cache_block_size, kv_router_config)
.kv_chooser_for(&endpoint, kv_cache_block_size, kv_router_config)
.await?;

// Build the PushRouter for prefill with KV mode
// Extract client from kv_chooser to ensure shared state
let client = kv_chooser.client().clone();

// Build the PushRouter for prefill with KV mode using the shared client
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
Expand All @@ -127,6 +128,9 @@ impl PrefillRouter {
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create client for simple router
let client = endpoint.client().await?;

// Create simple push router with the frontend's router mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
Expand Down
28 changes: 13 additions & 15 deletions lib/llm/src/kv_router/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
Expand Down Expand Up @@ -96,27 +96,26 @@ impl KvScheduler {
pub async fn start(
component: Component,
block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>,
instance_ids_rx: watch::Receiver<Vec<u64>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
let instance_ids: Vec<u64> = instance_ids_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();

// Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new();
for instance in &instances {
let worker_id = instance.instance_id;
let config = runtime_configs.get(&worker_id).cloned();
for worker_id in &instance_ids {
let config = runtime_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
initial_map.insert(worker_id, config);
initial_map.insert(*worker_id, config);
}
Arc::new(RwLock::new(initial_map))
};
Expand All @@ -132,7 +131,7 @@ impl KvScheduler {
// Spawn background task to monitor and update workers_with_configs
let workers_monitor = workers_with_configs.clone();
let slots_monitor = slots.clone();
let mut instances_monitor_rx = instances_rx.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().primary_token();
tokio::spawn(async move {
Expand All @@ -144,9 +143,9 @@ impl KvScheduler {
tracing::trace!("workers monitoring task shutting down");
break;
}
result = instances_monitor_rx.changed() => {
result = instance_ids_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("endpoint watch sender shutdown in monitor");
tracing::warn!("instance IDs watch sender shutdown in monitor");
break;
}
}
Expand All @@ -159,18 +158,17 @@ impl KvScheduler {
}

// Get the latest values from both channels
let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_instance_ids = instance_ids_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone();

// Build the new workers_with_configs map
let mut new_workers_with_configs = HashMap::new();
for instance in &new_instances {
let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned();
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
new_workers_with_configs.insert(worker_id, config);
new_workers_with_configs.insert(*worker_id, config);
}

// Update workers when instances change
Expand Down
Loading
Loading