Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lib/bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,8 @@ pub async fn create_worker_selection_pipeline_chat(
let component = distributed_runtime
.namespace(namespace)?
.component(component_name)?;
let client = component.endpoint(GENERATE_ENDPOINT).client().await?;
let endpoint = component.endpoint(GENERATE_ENDPOINT);
let client = endpoint.client().await?;

// Discover the model card by searching all instances with this model name
tracing::debug!("Looking for model: {}", model_name);
Expand Down Expand Up @@ -980,7 +981,7 @@ pub async fn create_worker_selection_pipeline_chat(
let chooser = if router_mode == RouterMode::KV {
Some(
model_manager
.kv_chooser_for(&component, card.kv_cache_block_size, kv_router_config)
.kv_chooser_for(&endpoint, card.kv_cache_block_size, kv_router_config)
.await?,
)
} else {
Expand Down
5 changes: 1 addition & 4 deletions lib/bindings/python/rust/llm/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,13 +993,10 @@ async fn create_kv_router_from_endpoint(
block_size: usize,
kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
// Get component from endpoint
let component = endpoint.inner.component();

// Create ModelManager and use it to create KvRouter (ensures registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager
.kv_chooser_for(component, block_size as u32, kv_router_config)
.kv_chooser_for(&endpoint.inner, block_size as u32, kv_router_config)
.await
.map_err(to_pyerr)?;

Expand Down
25 changes: 12 additions & 13 deletions lib/llm/src/discovery/model_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;

use dynamo_runtime::prelude::DistributedRuntimeProvider;
use dynamo_runtime::{
component::{Component, Endpoint},
storage::key_value_store::Key,
};
use dynamo_runtime::{component::Endpoint, storage::key_value_store::Key};

use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
Expand Down Expand Up @@ -292,40 +289,42 @@ impl ModelManager {

pub async fn kv_chooser_for(
&self,
component: &Component,
endpoint: &Endpoint,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
let service_name = component.service_name();
let endpoint_path = endpoint.path();

if let Some(kv_chooser) = self.get_kv_chooser(&service_name) {
if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_path) {
// Check if the existing router has a different block size
if kv_chooser.block_size() != kv_cache_block_size {
tracing::warn!(
component = %service_name,
endpoint = %endpoint_path,
existing_block_size = %kv_chooser.block_size(),
requested_block_size = %kv_cache_block_size,
"KV Router block size mismatch! Component is requesting a different kv_cache_block_size than the existing router. \
"KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
This will cause routing to fail silently. Consider using the same block size or restarting the router."
);
}
return Ok(kv_chooser);
}

let store = component.drt().store();
let client = endpoint.client().await?;
let store = endpoint.component().drt().store();
let router_bucket = store
.get_or_create_bucket(KV_ROUTERS_ROOT_PATH, None)
.await?;
let router_uuid = uuid::Uuid::new_v4();
let router_key = Key::from_raw(format!("{}/{router_uuid}", component.path()));
let router_key = Key::from_raw(format!("{}/{router_uuid}", endpoint.path()));
let json_router_config = serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?;
router_bucket
.insert(&router_key, json_router_config.into(), 0)
.await?;

let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(
component.clone(),
endpoint.clone(),
Some(client),
kv_cache_block_size,
Some(selector),
kv_router_config,
Expand All @@ -335,7 +334,7 @@ impl ModelManager {
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
.insert(service_name, new_kv_chooser.clone());
.insert(endpoint_path, new_kv_chooser.clone());
Ok(new_kv_chooser)
}

Expand Down
3 changes: 2 additions & 1 deletion lib/llm/src/discovery/watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,11 @@ impl ModelWatcher {
// A model that expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports.

let endpoint = component.endpoint(&endpoint_id.name);
let kv_chooser = if self.router_mode == RouterMode::KV {
Some(
self.manager
.kv_chooser_for(&component, card.kv_cache_block_size, self.kv_router_config)
.kv_chooser_for(&endpoint, card.kv_cache_block_size, self.kv_router_config)
.await?,
)
} else {
Expand Down
14 changes: 12 additions & 2 deletions lib/llm/src/entrypoint/input/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,27 @@ where
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card).into_operator();

// For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV {
let Some(ref chooser) = chooser else {
anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
};
chooser.client().clone()
} else {
client.clone()
};

// Create worker monitor only if busy_threshold is set
let worker_monitor = busy_threshold.map(|threshold| {
Arc::new(crate::discovery::KvWorkerMonitor::new(
Arc::new(client.clone()),
Arc::new(router_client.clone()),
threshold,
)) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
});

let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(),
router_client,
router_mode,
busy_threshold,
worker_monitor,
Expand Down
34 changes: 24 additions & 10 deletions lib/llm/src/kv_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::Component,
component::{Client, Endpoint},
discovery::{DiscoveryQuery, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
Expand Down Expand Up @@ -211,29 +211,37 @@ pub struct KvRouter {
kv_router_config: KvRouterConfig,

cancellation_token: tokio_util::sync::CancellationToken,

client: Client,
}

impl KvRouter {
pub async fn new(
component: Component,
endpoint: Endpoint,
client: Option<Client>,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?;

let instances_rx = client.instance_source.as_ref().clone();
let client = match client {
Some(c) => c,
None => endpoint.client().await?,
};

let instance_ids_rx = client.instance_avail_watcher();

// Watch for runtime config updates via discovery interface
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
namespace: endpoint_id.namespace.clone(),
component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
Expand All @@ -247,7 +255,7 @@ impl KvRouter {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
block_size,
Expand All @@ -265,7 +273,7 @@ impl KvRouter {
let scheduler = KvScheduler::start(
component.clone(),
block_size,
instances_rx,
instance_ids_rx,
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
Expand Down Expand Up @@ -300,9 +308,15 @@ impl KvRouter {
block_size,
kv_router_config,
cancellation_token,
client,
})
}

/// Get a reference to the client used by this KvRouter
pub fn client(&self) -> &Client {
&self.client
}

/// Give these tokens, find the worker with the best match in it's KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
Expand Down
14 changes: 9 additions & 5 deletions lib/llm/src/kv_router/prefill_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,16 @@ impl PrefillRouter {
"Activating prefill router"
);

let client = endpoint.client().await?;

let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the component from the endpoint
// Create KV chooser using the endpoint
let kv_chooser = model_manager
.kv_chooser_for(endpoint.component(), kv_cache_block_size, kv_router_config)
.kv_chooser_for(&endpoint, kv_cache_block_size, kv_router_config)
.await?;

// Build the PushRouter for prefill with KV mode
// Extract client from kv_chooser to ensure shared state
let client = kv_chooser.client().clone();

// Build the PushRouter for prefill with KV mode using the shared client
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
Expand All @@ -127,6 +128,9 @@ impl PrefillRouter {
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create client for simple router
let client = endpoint.client().await?;

// Create simple push router with the frontend's router mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
Expand Down
28 changes: 13 additions & 15 deletions lib/llm/src/kv_router/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
Expand Down Expand Up @@ -96,27 +96,26 @@ impl KvScheduler {
pub async fn start(
component: Component,
block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>,
instance_ids_rx: watch::Receiver<Vec<u64>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
let instance_ids: Vec<u64> = instance_ids_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();

// Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new();
for instance in &instances {
let worker_id = instance.instance_id;
let config = runtime_configs.get(&worker_id).cloned();
for worker_id in &instance_ids {
let config = runtime_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
initial_map.insert(worker_id, config);
initial_map.insert(*worker_id, config);
}
Arc::new(RwLock::new(initial_map))
};
Expand All @@ -132,7 +131,7 @@ impl KvScheduler {
// Spawn background task to monitor and update workers_with_configs
let workers_monitor = workers_with_configs.clone();
let slots_monitor = slots.clone();
let mut instances_monitor_rx = instances_rx.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().primary_token();
tokio::spawn(async move {
Expand All @@ -144,9 +143,9 @@ impl KvScheduler {
tracing::trace!("workers monitoring task shutting down");
break;
}
result = instances_monitor_rx.changed() => {
result = instance_ids_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("endpoint watch sender shutdown in monitor");
tracing::warn!("instance IDs watch sender shutdown in monitor");
break;
}
}
Expand All @@ -159,18 +158,17 @@ impl KvScheduler {
}

// Get the latest values from both channels
let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_instance_ids = instance_ids_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone();

// Build the new workers_with_configs map
let mut new_workers_with_configs = HashMap::new();
for instance in &new_instances {
let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned();
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
new_workers_with_configs.insert(worker_id, config);
new_workers_with_configs.insert(*worker_id, config);
}

// Update workers when instances change
Expand Down
Loading
Loading