diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 0959d1a6c..bed2d9c21 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -478,6 +478,17 @@ DEFINE_bool(enable_constrained_decoding, "that the output meets specific format or structural requirements " "through pre-defined rules."); +// --- concurrent llm worker config --- +DEFINE_uint32(llm_worker_max_concurrency, + 1, + "Maximum concurrency for llm worker parallel execution. A value " + "less than or equal to 1 disables the concurrent llm worker."); + +// --- fixedsteps scheduler config --- +DEFINE_bool(enable_fixedsteps_scheduler, + false, + "Whether to use the fixedsteps scheduler."); + #if defined(USE_NPU) DEFINE_string( npu_kernel_backend, diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 7bcd8043c..d5509542a 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -235,6 +235,9 @@ DECLARE_double(dit_cache_residual_diff_threshold); DECLARE_bool(enable_constrained_decoding); +DECLARE_uint32(llm_worker_max_concurrency); + +DECLARE_bool(enable_fixedsteps_scheduler); #if defined(USE_NPU) DECLARE_string(npu_kernel_backend); #endif diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index a0054bb23..0cad6dd4a 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -18,6 +18,7 @@ cc_library( worker.h worker_impl.h llm_worker_impl.h + concurrent_llm_worker_impl.h vlm_worker_impl.h dit_worker.h embed_worker_impl.h @@ -34,6 +35,7 @@ cc_library( worker.cpp worker_impl.cpp llm_worker_impl.cpp + concurrent_llm_worker_impl.cpp vlm_worker_impl.cpp dit_worker.cpp embed_worker_impl.cpp diff --git a/xllm/core/runtime/concurrent_llm_worker_impl.cpp b/xllm/core/runtime/concurrent_llm_worker_impl.cpp new file mode 100644 index 000000000..a5f5641fa --- /dev/null +++ b/xllm/core/runtime/concurrent_llm_worker_impl.cpp @@ -0,0 +1,465 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "concurrent_llm_worker_impl.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/device_monitor.h" +#include "common/metrics.h" +#include "common/types.h" +#include "core/common/global_flags.h" +#include "framework/kv_cache/kv_cache.h" +#include "framework/model/model_input_params.h" +#include "framework/model_loader.h" +#include "framework/state_dict/state_dict.h" +#if defined(USE_CUDA) || defined(USE_ILU) +#include "layers/cuda/flashinfer_workspace.h" +#endif +#include "models/model_registry.h" +#include "util/threadpool.h" +#include "util/timer.h" + +namespace xllm { + +ConcurrentLLMWorkerImpl::ConcurrentLLMWorkerImpl( + const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : LLMWorkerImpl(parallel_args, device, options), + max_concurrency_(FLAGS_llm_worker_max_concurrency) { + CHECK_GT(max_concurrency_, 0) + << "llm_worker_max_concurrency must be greater than 0"; + + device_.set_device(); + + // Create independent step_threadpool_ dedicated to parallel execution of + // step() Use schedule() to assign tasks, letting ThreadPool automatically + // select idle threads + step_threadpool_ = std::make_unique( + max_concurrency_, [this]() mutable { device_.set_device(); }); + + LOG(INFO) << "ConcurrentLLMWorkerImpl: Created step_threadpool_ with " + << max_concurrency_ << " threads for parallel step execution"; + +#if defined(USE_CUDA) + // initialize flashinfer workspace + layer::FlashinferWorkspace::get_instance().initialize(device_); +#endif +} + +bool ConcurrentLLMWorkerImpl::init_model(ModelContext& context) { + CHECK(model_ == nullptr) << "Model is already initialized."; + + // Create multiple model instances + model_instances_.reserve(max_concurrency_); + executor_instances_.reserve(max_concurrency_); + execute_streams_.reserve(max_concurrency_); + context_instances_.reserve(max_concurrency_); + + for (int32_t i = 0; i < max_concurrency_; ++i) { + // Create corresponding execute stream + auto stream = device_.get_stream_from_pool(); + execute_streams_.push_back(std::move(stream)); + + auto stream_guard = execute_streams_[i]->set_stream_guard(); + // Create independent ModelContext for each model instance + // Use constructor to create new context, ensuring each instance has + // independent context For NPU, this creates a new atb::Context + ModelContext instance_context(context.get_parallel_args(), + context.get_model_args(), + context.get_quant_args(), + context.get_tensor_options()); + context_instances_.push_back(std::move(instance_context)); + + // Create model instance using the corresponding context + auto model_instance = create_llm_model(context_instances_[i]); + CHECK(model_instance != nullptr) << "Failed to create model instance " << i; + model_instances_.push_back(std::move(model_instance)); + + // Create corresponding executor using the corresponding context + auto executor = + std::make_unique(model_instances_[i].get(), + context_instances_[i].get_model_args(), + device_, + options_); + executor_instances_.push_back(std::move(executor)); + + LOG(INFO) << "Created model instance " << i + << " with executor, execute stream and context"; + } + + // For compatibility with base class interface, set base class's model_ and + // model_executor_ to point to the first instance Note: Need to access base + // class's protected 
members model_ and model_executor_ Use reset() to set + // pointers, but note: ownership of model_ actually belongs to + // model_instances_[0]. In destructor, need to release model_ and + // model_executor_ first to avoid double deletion + model_.reset(model_instances_[0].get()); + model_executor_.reset(executor_instances_[0].get()); + + // Complete other initialization (EPLB, BeamSearcher, etc.) + // Note: These are members of base class LLMWorkerImpl, can be accessed + // directly + if (FLAGS_enable_eplb) { + eplb_executor_ = std::make_unique(model_.get(), device_); + } + + if (FLAGS_enable_beam_search_kernel) { + // Use base class's protected member beam_searcher_ + beam_searcher_ = std::make_unique(); + } + + return true; +} + +void ConcurrentLLMWorkerImpl::load_model(std::unique_ptr loader) { + CHECK(!model_instances_.empty()) + << "Model instances are not initialized. Call init_model() first."; + + // Save model weights path to create new loaders for other instances + std::string model_weights_path = loader->model_weights_path(); + + // Load weights for the first model instance (using the original loader) + model_instances_[0]->load_model(std::move(loader)); + LOG(INFO) << "Loaded weights for model instance 0"; + + // Create new loaders and load weights for other model instances + for (size_t i = 1; i < model_instances_.size(); ++i) { + auto model_loader = ModelLoader::create(model_weights_path); + CHECK(model_loader != nullptr) + << "Failed to create ModelLoader for model instance " << i; + model_instances_[i]->load_model(std::move(model_loader)); + LOG(INFO) << "Loaded weights for model instance " << i; + } + + LOG(INFO) << "Loaded weights for all " << model_instances_.size() + << " model instances"; +} + +void ConcurrentLLMWorkerImpl::allocate_instance_id_for_current_thread() { + std::thread::id current_thread_id = std::this_thread::get_id(); + + // Lock to protect the allocation process + std::lock_guard lock(allocation_mutex_); + + // Check if current thread is already in the map (may have been allocated by + // other tasks) + auto it = thread_id_to_instance_id_.find(current_thread_id); + if (it != thread_id_to_instance_id_.end()) { + return; + } + + // Select the smallest unallocated instance id + size_t instance_id = SIZE_MAX; + size_t stream_num = static_cast(max_concurrency_); + for (size_t i = 0; i < stream_num; ++i) { + if (allocated_instance_ids_.find(i) == allocated_instance_ids_.end()) { + instance_id = i; + break; + } + } + + CHECK_NE(instance_id, SIZE_MAX) + << "No available instance id, all " << max_concurrency_ + << " instance ids are allocated"; + + // Establish mapping relationship + thread_id_to_instance_id_[current_thread_id] = instance_id; + allocated_instance_ids_.insert(instance_id); + + LOG(INFO) << "Allocated instance_id " << instance_id << " for thread " + << current_thread_id; +} + +void ConcurrentLLMWorkerImpl::get_thread_model_instance( + CausalLM*& model, + Executor*& executor, + Stream*& execute_stream, + ModelContext*& context) { + std::thread::id current_thread_id = std::this_thread::get_id(); + + // If current thread hasn't been allocated an instance id yet, allocate it + // first + auto it = thread_id_to_instance_id_.find(current_thread_id); + if (it == thread_id_to_instance_id_.end()) { + allocate_instance_id_for_current_thread(); + it = thread_id_to_instance_id_.find(current_thread_id); + } + + CHECK(it != thread_id_to_instance_id_.end()) + << "Failed to find instance id for thread " << current_thread_id; + size_t instance_id = it->second; + // 
LOG(INFO) << "get_thread_model_instance: thread " << current_thread_id + // << " allocated instance_id " << instance_id; + + CHECK_LT(instance_id, model_instances_.size()) + << "Thread model index " << instance_id + << " exceeds model instances size " << model_instances_.size(); + + model = model_instances_[instance_id].get(); + executor = executor_instances_[instance_id].get(); + execute_stream = execute_streams_[instance_id].get(); + context = &context_instances_[instance_id]; +} + +folly::SemiFuture> +ConcurrentLLMWorkerImpl::step_async(const ForwardInput& input) { + ForwardInput input_on_device; + prepare_work_before_execute(input, input_on_device); + + folly::Promise> promise; + auto future = promise.getSemiFuture(); + + // Use schedule() to assign tasks, letting ThreadPool automatically select + // idle threads The logic for allocating instance_id happens when the task + // executes (see lambda below) + step_threadpool_->schedule([this, + input = std::move(input_on_device), + promise = std::move(promise)]() mutable { + // When the task executes, if the current thread hasn't been allocated an + // instance id yet, allocate it The allocation logic will lock, select the + // smallest unallocated instance id, establish a mapping from thread id to + // instance id Once allocation is complete, the mapping relationship is + // saved in thread_id_to_instance_id_. This way, multiple threads complete + // allocation after executing once + + // Handle hierarchy_kv_cache_transfer if needed (from base class logic) + if (hierarchy_kv_cache_transfer_ != nullptr) { + hierarchy_kv_cache_transfer_->set_layer_synchronizer(input.input_params); + } + + // Call step() using the model instance corresponding to the current thread + const auto output = this->step(input); + + // Handle enable_schedule_overlap logic (if needed) + if (!enable_schedule_overlap()) { + promise.setValue(output); + } else { + if (last_step_output_valid_ && !input.input_params.empty_kv_cache) { + // replace step i model input with true output of step i-1 + input = update_input_by_last_step_output(input); + } + + const auto output_overlap = this->step(input); + if (output_overlap.has_value()) { + if (is_driver() || FLAGS_enable_eplb) { + std::unique_lock lock(mtx_); + cv_.wait(lock, [this] { return !is_recorded_; }); + update_last_step_output(output_overlap); + is_recorded_ = true; + cv_.notify_one(); + } else { + update_last_step_output(output_overlap); + } + } else { + if (is_driver() || FLAGS_enable_eplb) { + std::unique_lock lock(mtx_); + cv_.wait(lock, [this] { return !is_recorded_; }); + last_step_output_valid_ = false; + is_recorded_ = true; + cv_.notify_one(); + } else { + last_step_output_valid_ = false; + } + } + promise.setValue(output_overlap); + } + }); + return future; +} + +std::optional ConcurrentLLMWorkerImpl::step( + const ForwardInput& input) { + Timer timer; + auto& sampling_params = input.sampling_params; + + // Get the model, executor, stream and context corresponding to the current + // thread + CausalLM* model = nullptr; + Executor* executor = nullptr; + Stream* execute_stream = nullptr; + ModelContext* context = nullptr; + get_thread_model_instance(model, executor, execute_stream, context); + + c10::StreamGuard stream_guard = execute_stream->set_stream_guard(); + + std::vector> futures; + + if (options_.kv_cache_transfer_mode() == "PUSH" && + !input.transfer_kv_infos.empty()) { +#if defined(USE_NPU) + std::shared_ptr layer_synchronizer = + std::make_shared( + context->get_model_args().n_layers()); + 
const_cast(&(input.input_params))->layer_synchronizer = + layer_synchronizer; + + futures.emplace_back( + kv_cache_transfer_->push_kv_blocks_async(input.transfer_kv_infos, + context->get_parallel_args(), + layer_synchronizer, + is_spec_draft_)); +#endif + } + + if (FLAGS_enable_eplb) { + eplb_executor_->eplb_execute(input.eplb_info); + } + + // Use the executor and model corresponding to the thread + auto hidden_states = executor->forward( + input.token_ids, input.positions, kv_caches_, input.input_params); + if (!hidden_states.defined()) { + return std::nullopt; + } + + torch::Tensor logits; + if (sampling_params.selected_token_idxes.defined()) { + logits = model->logits(hidden_states, sampling_params.selected_token_idxes); + } + + ForwardOutput output; + if (FLAGS_enable_eplb) { + output.expert_load_data = expert_load_data_; + output.prepared_layer_id = eplb_executor_->get_ready_layer_id(); + if (output.prepared_layer_id != -1) { + eplb_executor_->reset_ready_layer_id(); + } + } + + if (!enable_schedule_overlap() && !driver_ && !dp_driver_ && + !options_.enable_speculative_decode()) { + // Synchronize the current thread's stream (if using independent stream) + if (execute_stream != nullptr) { + execute_stream->synchronize(); + } else { + device_.synchronize_default_stream(); + } + + // in p-d disaggregation scene, all micro batches should be in same + // prefill/decode stage, so, to judge transfer_kv_infos.empty, + if (options_.kv_cache_transfer_mode() == "PUSH" && + !input.transfer_kv_infos.empty()) { + auto results = + folly::collectAll(futures).within(std::chrono::seconds(60)).get(); + for (const auto& result : results) { + if (!result.value()) { + LOG(ERROR) << "kv_cache_transfer_ failed"; + return std::nullopt; + } + } + } + if (FLAGS_enable_eplb) { + return output; + } + return std::nullopt; + } + + // driver prepare model output + SampleOutput sample_output; + if (sampling_params.selected_token_idxes.defined()) { + sample_output = sampler_->forward(logits, sampling_params); + output.logits = logits; + + // beam search kernel + BeamSearchOutput beam_search_output; + if (sampling_params.use_beam_search && input.acc_logprob.defined() && + input.acc_logprob.numel() > 0) { + beam_search_output = beam_searcher_->forward(input.acc_logprob, + sample_output.top_tokens, + sample_output.top_logprobs); + } + + // set sample output to output + output.sample_output = sample_output; + // carry over the sampling params + output.do_sample = sampling_params.do_sample; + output.logprobs = sampling_params.logprobs; + output.max_top_logprobs = sampling_params.max_top_logprobs; + // set beam search output to output + output.beam_search_output = beam_search_output; + } + + if (options_.enable_speculative_decode()) { + if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) { + output.sample_output.embeddings = hidden_states; + } else if (sampling_params.selected_token_idxes.defined()) { + auto embeddings = hidden_states.index_select( + /*dim=*/0, sampling_params.selected_token_idxes); + output.sample_output.embeddings = embeddings; + } + } + + // Synchronize the current thread's stream (if using independent stream) + if (execute_stream != nullptr) { + execute_stream->synchronize(); + } else { + device_.synchronize_default_stream(); + } + + if (options_.kv_cache_transfer_mode() == "PUSH" && + !input.transfer_kv_infos.empty()) { + auto results = + folly::collectAll(futures).within(std::chrono::seconds(60)).get(); + for (const auto& result : results) { + if (!result.value()) { + 
LOG(ERROR) << "kv_cache_transfer_ failed"; + return std::nullopt; + } + } + } + + COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); + DeviceMonitor::get_instance().update_active_activation_memory( + device_.index()); + + return output; +} + +void ConcurrentLLMWorkerImpl::update_last_step_output( + const std::optional& output) { + // Implement the same logic as the base class because the base class's method + // is private + if (output.has_value()) { + if (output.value().sample_output.next_tokens.defined()) { + last_step_output_ = std::move(output.value()); + last_step_output_valid_ = true; + } else { + if (FLAGS_enable_eplb) { + last_step_output_ = std::move(output.value()); + } + last_step_output_valid_ = false; + } + } else { + last_step_output_valid_ = false; + } +} + +} // namespace xllm diff --git a/xllm/core/runtime/concurrent_llm_worker_impl.h b/xllm/core/runtime/concurrent_llm_worker_impl.h new file mode 100644 index 000000000..571367d0b --- /dev/null +++ b/xllm/core/runtime/concurrent_llm_worker_impl.h @@ -0,0 +1,119 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "executor.h" +#include "forward_params.h" +#include "framework/model/causal_lm.h" +#include "llm_worker_impl.h" +#include "platform/device.h" +#include "platform/stream.h" +#include "util/threadpool.h" + +namespace xllm { + +// ConcurrentLLMWorkerImpl: LLM Worker supporting multi-stream parallel +// execution Inherits from LLMWorkerImpl, adds support for multiple model +// instances and execute stream pool +class ConcurrentLLMWorkerImpl : public LLMWorkerImpl { + public: + // execute_stream_num: execution parallelism, determines the number of model + // instances and execute streams + explicit ConcurrentLLMWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); + + ~ConcurrentLLMWorkerImpl() override { + // Release model_ and model_executor_ in destructor to avoid double deletion + // Ownership actually belongs to model_instances_[0] and + // executor_instances_[0] + model_.release(); + model_executor_.release(); + } + + // initialize model, cache manager. 
blocking call + bool init_model(ModelContext& context) override; + + // Override load_model to load weights for all model instances + void load_model(std::unique_ptr loader) override; + + // Override step_async to support multi-threaded parallel execution + folly::SemiFuture> step_async( + const ForwardInput& inputs) override; + + std::optional step(const ForwardInput& input) override; + + private: + // Execution parallelism (number of model instances and execute streams) + uint32_t max_concurrency_; + + // Multiple model instances (one per stream) + // Note: Multiple model instances are needed because many model forward() + // methods don't support parallel execution Multiple threads concurrently + // calling the same model's forward() will cause data races and undefined + // behavior Therefore, we need to create independent model instances for each + // parallel execution thread + std::vector> model_instances_; + + // Multiple executor instances (one per stream, corresponding to + // model_instances_) + std::vector> executor_instances_; + + // Execute stream pool (one stream per model instance) + std::vector> execute_streams_; + + // Multiple ModelContext instances (one per model instance) + // Each context instance contains independent model args, parallel args, etc. + std::vector context_instances_; + + // Independent ThreadPool dedicated to parallel execution of step() + std::unique_ptr step_threadpool_; + + // Mapping from thread id to instance id (used to find the instance id for the + // current thread) + std::unordered_map thread_id_to_instance_id_; + + // Set of allocated instance ids (used to select the smallest unallocated + // instance id) + std::set allocated_instance_ids_; + + // Mutex protecting the allocation process + std::mutex allocation_mutex_; + + // Helper method: Get the corresponding model, executor, stream and context + // based on thread ID + void get_thread_model_instance(CausalLM*& model, + Executor*& executor, + Stream*& execute_stream, + ModelContext*& context); + + // Allocate instance id for the current thread (thread-safe) + void allocate_instance_id_for_current_thread(); + + // Update last_step_output (because the base class's update_last_step_output + // is private) + void update_last_step_output(const std::optional& output); +}; + +} // namespace xllm diff --git a/xllm/core/runtime/llm_worker_impl.h b/xllm/core/runtime/llm_worker_impl.h index 597f705c5..282c67b96 100644 --- a/xllm/core/runtime/llm_worker_impl.h +++ b/xllm/core/runtime/llm_worker_impl.h @@ -56,7 +56,7 @@ class LLMWorkerImpl : public WorkerImpl { model_->set_word_embedding(embedding); }; - private: + protected: std::unique_ptr beam_searcher_; }; diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index d4d0124b1..533e49001 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -26,9 +26,11 @@ limitations under the License. 
#include #include "common/metrics.h" +#include "core/common/global_flags.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" #include "framework/state_dict/state_dict.h" +#include "runtime/concurrent_llm_worker_impl.h" #include "runtime/embed_vlm_worker_impl.h" #include "runtime/embed_worker_impl.h" #include "runtime/llm_worker_impl.h" @@ -44,7 +46,11 @@ Worker::Worker(const ParallelArgs& parallel_args, if (options.enable_speculative_decode()) { impl_ = new SpeculativeWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::LLM) { - impl_ = new LLMWorkerImpl(parallel_args, device, options); + if (FLAGS_llm_worker_max_concurrency > 1) { + impl_ = new ConcurrentLLMWorkerImpl(parallel_args, device, options); + } else { + impl_ = new LLMWorkerImpl(parallel_args, device, options); + } } else if (worker_type == WorkerType::VLM) { impl_ = new VLMWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::ELM) { diff --git a/xllm/core/scheduler/CMakeLists.txt b/xllm/core/scheduler/CMakeLists.txt index d694b3b1f..5a127184c 100644 --- a/xllm/core/scheduler/CMakeLists.txt +++ b/xllm/core/scheduler/CMakeLists.txt @@ -17,6 +17,7 @@ cc_library( scheduler.h dit_scheduler.h prefill_only_scheduler.h + fixedsteps_scheduler.h scheduler_factory.h decode_priority_queue.h perf_model.h @@ -29,6 +30,7 @@ cc_library( async_response_processor.cpp dit_scheduler.cpp prefill_only_scheduler.cpp + fixedsteps_scheduler.cpp scheduler_factory.cpp perf_model.cpp DEPS diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index 59af7483b..7476c2c14 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -898,7 +898,6 @@ void ContinuousScheduler::step(const absl::Duration& timeout) { if (all_empty) { return; } - if (!options_.enable_pd_ooc()) { engine_->step(batch); } else { diff --git a/xllm/core/scheduler/fixedsteps_scheduler.cpp b/xllm/core/scheduler/fixedsteps_scheduler.cpp new file mode 100644 index 000000000..71f6298a0 --- /dev/null +++ b/xllm/core/scheduler/fixedsteps_scheduler.cpp @@ -0,0 +1,350 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://github.com/jd-opensource/xllm/blob/main/LICENSE +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "fixedsteps_scheduler.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/global_flags.h" +#include "common/metrics.h" +#include "common/types.h" +#include "distributed_runtime/engine.h" +#include "framework/batch/batch.h" +#include "framework/batch/batch_factory.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "util/threadpool.h" + +namespace xllm { + +namespace { +constexpr size_t kRequestQueueSize = 100000; +} // namespace + +FixedStepsScheduler::FixedStepsScheduler( + Engine* engine, + const ContinuousScheduler::Options& options) + : ContinuousScheduler(engine, options), + step_threadpool_(std::make_unique( + static_cast(FLAGS_llm_worker_max_concurrency))) {} + +bool FixedStepsScheduler::add_request(std::shared_ptr& request) { + CHECK(request != nullptr); + CHECK(!request->sequences().empty()); + + if (request_queue_.write(request)) { //.get() + // take over the ownership of the request + // request.release(); + return true; + } + // queue is full + return false; +} + +void FixedStepsScheduler::handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests) { + // Handle new request prompts first. + // Include those requests that are preempted by others. + // + // Schedule the prefill requests in the waiting priority queue until budgets + // are exhausted. + // When the KV Cache usage reaches the threshold, prefill requests will no + // longer be scheduled to avoid frequent preemption. + // + // NOTE: preempted requests will be pushed into waiting_priority_queue_; + // they may contain many sequences, so we should check here. 
+ bool budget_exhausted = false; + bool blocks_exhausted = false; + while (!waiting_priority_queue_.empty() && remaining_seq_budget > 0 && + remaining_token_budget > 0 && + kv_cache_manager_->kv_cache_utilization() < + FLAGS_prefill_scheduling_memory_usage_threshold) { + std::shared_ptr request(waiting_priority_queue_.top()); + if (request->finished() || request->cancelled()) { + kv_cache_manager_->deallocate(request.get()); + finished_requests.emplace_back(request); + // remove the request from the priority queue + waiting_priority_queue_.pop(); + continue; + } + + const size_t num_sequences = request->sequences().size(); + if (!request->preempted()) { + CHECK(num_sequences == 1) + << "Waiting request should have only one sequence."; + } + + // TODO: FIXME later + // Optimization of the scheduling algorithm under multiple sequences + size_t allocated_tokens = 0; + size_t allocated_seqs = 0; + double allocated_estimate_latency = 0; + bool can_schedule = true; + std::vector prefill_sequences; + std::vector prefill_sequences_budget; + prefill_sequences.reserve(request->sequences().size()); + prefill_sequences_budget.reserve(request->sequences().size()); + for (auto& prefill_sequence : request->sequences()) { + if (prefill_sequence->finished()) { + continue; + } + size_t num_tokens = prefill_sequence->num_need_compute_tokens(); + if (remaining_token_budget < allocated_tokens + num_tokens || + remaining_seq_budget < allocated_seqs + 1) { + can_schedule = false; + budget_exhausted = true; + break; + } + + // allocate KV cache blocks for the sequence + if (!kv_cache_manager_->allocate(prefill_sequence.get())) { + can_schedule = false; + blocks_exhausted = true; + break; + } + + prefill_sequences_budget.emplace_back(num_tokens); + prefill_sequences.emplace_back(prefill_sequence.get()); + allocated_tokens += num_tokens; + allocated_seqs += 1; + } + + if (!can_schedule) { + for (auto& seq : prefill_sequences) { + // release shared blocks + kv_cache_manager_->deallocate(seq); + } + break; + } + + if (prefill_sequences.empty()) { + continue; + } + + remaining_token_budget -= allocated_tokens; + remaining_seq_budget -= allocated_seqs; + waiting_priority_queue_.pop(); + running_requests_.emplace_back(request); + running_sequences_.insert(running_sequences_.end(), + prefill_sequences.begin(), + prefill_sequences.end()); + running_sequences_budgets_.insert(running_sequences_budgets_.end(), + prefill_sequences_budget.begin(), + prefill_sequences_budget.end()); + } + + if (running_sequences_.empty() && !waiting_priority_queue_.empty() && + running_queue_->empty()) { + LOG(ERROR) + << "Request prompt is too long, not enough budget/memory to schedule " + "a single sequence."; + // not enough memory to schedule a single sequence, just finish the request + std::shared_ptr request(waiting_priority_queue_.top()); + waiting_priority_queue_.pop(); + // block_manager_->release_blocks_for(request.get()); + response_processor_->process_failed_request( + request, + {StatusCode::RESOURCE_EXHAUSTED, + "Not enough budget to schedule a single sequence."}); + } +} + +std::vector FixedStepsScheduler::prepare_batch() { + Timer timer; + // propagate new requests to waiting_priority_queue_ + // Include those requests that are preempted by others. + std::shared_ptr request; + // read from request queue then push to waiting priority queue + while (request_queue_.read(request)) { + CHECK(request); + + // expand sequences to the target number if prefix cache is disabled. 
+ if (!enable_prefix_cache_) { + // expand sequences to the target number + request->expand_sequences(false); + } + + if (request->sequences()[0]->kv_state().kv_cache_tokens_num() == 0) { + waiting_priority_queue_.push(request); + } else { + // request from prefill instance in disagg pd mode. + running_requests_.emplace_back(request); + } + } + + std::vector> finished_requests; + for (auto it = running_requests_.rbegin(); it != running_requests_.rend(); + ++it) { + if (*it == nullptr) { + continue; + } + std::shared_ptr request = *it; + request->update_connection_status(); + if (request->finished() || request->cancelled()) { + kv_cache_manager_->deallocate(request.get()); + finished_requests.emplace_back(request); + // finished request is set to nullptr + *it = nullptr; + } + } + running_requests_.clear(); + running_sequences_.clear(); + running_sequences_budgets_.clear(); + + // remaining budget for the current batch + size_t remaining_token_budget = options_.max_tokens_per_batch(); + size_t remaining_seq_budget = std::max(options_.max_seqs_per_batch(), 1); + size_t num_preempted_requests = 0; + + handle_prefill_requests( + remaining_token_budget, remaining_seq_budget, finished_requests); + + // only forward once, no decode requests + // handle_decode_requests( + // remaining_token_budget, remaining_seq_budget, num_preempted_requests); + + if (!finished_requests.empty()) { + response_processor_->process_completed_requests(finished_requests); + } + + // update the batch + auto batches = + BatchFactory::get_instance(options_.dp_size()) + ->create_batches(running_requests_, + running_sequences_, + running_sequences_budgets_, + kv_cache_manager_->get_swap_block_transfer_infos()); + + bool is_batches_empty = + (std::all_of(batches.begin(), batches.end(), [](const Batch& one_batch) { + return one_batch.empty(); + })); + if (!is_batches_empty) { + // only update the scheduling latency when there are requests to process + COUNTER_ADD(scheduling_latency_seconds, timer.elapsed_seconds()); + } + + GAUGE_SET(num_pending_requests, + pending_requests_.load(std::memory_order_relaxed)); + GAUGE_SET(num_running_requests, running_requests_.size()); + GAUGE_SET(num_waiting_requests, + waiting_priority_queue_.size() + running_queue_->size()); + + GAUGE_ADD(num_preempted_requests, num_preempted_requests); + + GAUGE_SET(num_running_sequences, running_sequences_.size()); + + GAUGE_SET(kv_cache_utilization_perc, + kv_cache_manager_->kv_cache_utilization()); + if (!FLAGS_enable_continuous_kvcache) { + GAUGE_SET(num_blocks_in_prefix_cache, + kv_cache_manager_->num_blocks_in_prefix_cache().size()); + GAUGE_SET(num_free_blocks, kv_cache_manager_->num_free_blocks().size()); + GAUGE_SET(num_used_blocks, kv_cache_manager_->num_used_blocks().size()); + } + return batches; +} + +ScheduleResult FixedStepsScheduler::schedule_request( + const absl::Duration& timeout) { + const auto deadline = absl::Now() + timeout; + ScheduleResult result; + while (true) { + result.batches = prepare_batch(); + bool all_empty = + std::all_of(result.batches.begin(), + result.batches.end(), + [](const Batch& one_batch) { return one_batch.empty(); }); + if (!all_empty) { + // Move running_requests_ and running_sequences_ into result + result.requests = std::move(running_requests_); + result.sequences = std::move(running_sequences_); + return result; + } + const auto now = absl::Now(); + if (now > deadline) { + break; + } + // wait for new requests to arrive + constexpr uint64_t kStepSleepTimeMs = 1; + const auto time_to_sleep = + 
std::min(absl::Milliseconds(kStepSleepTimeMs), deadline - now); + absl::SleepFor(time_to_sleep); + } + // return empty result + return result; +} + +// step the scheduler forward by one step +// may get blocked if there are no requests to process +void FixedStepsScheduler::step(const absl::Duration& timeout) { + if (!options_.enable_schedule_overlap()) { + // get a new batch of requests + ScheduleResult result = schedule_request(timeout); + bool all_empty = + std::all_of(result.batches.begin(), + result.batches.end(), + [](const Batch& one_batch) { return one_batch.empty(); }); + if (all_empty) { + return; + } + + // Submit task to thread pool for asynchronous execution + // After engine_->step() completes, process finished/cancelled requests + auto function = [this, + batches = std::move(result.batches), + requests = std::move(result.requests), + sequences = std::move(result.sequences)]() mutable { + engine_->step(batches); + kv_cache_manager_->reset_transfer_infos(); + + // After step completes, check and process finished/cancelled requests + std::vector> finished_requests; + for (auto& request : requests) { + if (request) { + request->update_connection_status(); + if (request->finished() || request->cancelled()) { + kv_cache_manager_->deallocate(request.get()); + finished_requests.emplace_back(request); + } + } + } + + // Process finished requests + if (!finished_requests.empty()) { + response_processor_->process_completed_requests(finished_requests); + } + }; + if (FLAGS_llm_worker_max_concurrency > 1) { + step_threadpool_->schedule(function); + } else { + function(); + } + + // Return immediately to allow the next step() call to execute in parallel + } else { + LOG(ERROR) << "FixedStepsScheduler::step() not supported with " + "enable_schedule_overlap"; + } +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/fixedsteps_scheduler.h b/xllm/core/scheduler/fixedsteps_scheduler.h new file mode 100644 index 000000000..9ffc56e00 --- /dev/null +++ b/xllm/core/scheduler/fixedsteps_scheduler.h @@ -0,0 +1,71 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://github.com/jd-opensource/xllm/blob/main/LICENSE +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "async_response_processor.h" +#include "common/macros.h" +#include "common/types.h" +#include "framework/batch/batch.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "runtime/xservice_client.h" +#include "scheduler.h" +#include "scheduler/continuous_scheduler.h" +#include "util/threadpool.h" + +namespace xllm { +class Engine; + +// Return value structure for schedule_request +struct ScheduleResult { + std::vector batches; + std::vector> requests; + std::vector sequences; +}; + +class FixedStepsScheduler final : public ContinuousScheduler { + public: + FixedStepsScheduler(Engine* engine, + const ContinuousScheduler::Options& options); + virtual ~FixedStepsScheduler() = default; + + bool add_request(std::shared_ptr& request) override; + + // step the scheduler forward by one step + // may get blocked if there are no requests to process + void step(const absl::Duration& timeout) override; + + private: + ScheduleResult schedule_request(const absl::Duration& timeout); + + // build a batch of requests from the priority queue + virtual std::vector prepare_batch(); + + void handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests); + + // Scheduler thread pool for parallel execution of step() + std::unique_ptr step_threadpool_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/scheduler_factory.cpp b/xllm/core/scheduler/scheduler_factory.cpp index 8be5a8b84..f5de5e131 100644 --- a/xllm/core/scheduler/scheduler_factory.cpp +++ b/xllm/core/scheduler/scheduler_factory.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "scheduler/continuous_scheduler.h" #include "scheduler/disagg_pd_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixedsteps_scheduler.h" #include "scheduler/pd_ooc_scheduler.h" #include "scheduler/prefill_only_scheduler.h" #include "scheduler/zero_eviction_scheduler.h" @@ -28,6 +29,9 @@ namespace xllm { std::unique_ptr create_continuous_scheduler( Engine* engine, ContinuousScheduler::Options options) { + if (FLAGS_enable_fixedsteps_scheduler) { + return std::make_unique(engine, options); + } if (options.enable_disagg_pd()) { if (options.enable_pd_ooc()) { return std::make_unique(engine, options);
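Usage note: as the worker.cpp and scheduler_factory.cpp hunks above show, Worker constructs a ConcurrentLLMWorkerImpl whenever llm_worker_max_concurrency is greater than 1, create_continuous_scheduler returns a FixedStepsScheduler when enable_fixedsteps_scheduler is set, and FixedStepsScheduler::step() only offloads engine_->step() onto its thread pool when llm_worker_max_concurrency is also greater than 1. A minimal sketch of the launch flags needed to exercise the parallel path, assuming standard gflags command-line parsing (the server binary name, the concurrency value of 4, and any other serving flags are deployment-specific examples, not part of this diff):

  --llm_worker_max_concurrency=4
  --enable_fixedsteps_scheduler=true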