@@ -176,6 +176,9 @@ def __init__(self,
176
176
self .enable_iter_perf_stats = model_engine .pytorch_backend_config .enable_iter_perf_stats
177
177
self .enable_iter_req_stats = model_engine .pytorch_backend_config .enable_iter_req_stats
178
178
self .stream_interval = model_engine .pytorch_backend_config .stream_interval
179
+ self .use_attention_dp_config = model_engine .pytorch_backend_config .use_attention_dp_config
180
+ self .attention_dp_time_out_iters = model_engine .pytorch_backend_config .attention_dp_time_out_iters
181
+ self .attention_dp_batching_wait_iters = model_engine .pytorch_backend_config .attention_dp_batching_wait_iters
179
182
self .num_fetch_requests_cur_rank = 0
180
183
self .num_fetch_requests = 0
181
184
self .shutdown_event = threading .Event ()
@@ -214,6 +217,9 @@ def __init__(self,
214
217
self .draft_model_engine .warmup (self .resource_manager )
215
218
216
219
self .is_shutdown = False
220
+ self .max_batch_size = max_batch_size
221
+ self .adp_ctx_waiting_iters_count = 0
222
+ self .adp_ctx_batching_wait_iters_count = 0
217
223
218
224
# request fetcher initialization
219
225
self .executor_request_queue = ExecutorRequestQueue (
@@ -1119,8 +1125,66 @@ def _schedule(self):
1119
1125
scheduler_output = self .scheduler .schedule_request (
1120
1126
self .active_requests , self .inflight_req_ids )
1121
1127
scheduled_requests = ScheduledRequests ()
1128
+ context_requests = scheduler_output .context_requests
1129
+ if self .enable_attention_dp :
1130
+ num_scheduled_context_requests = len (
1131
+ scheduler_output .context_requests )
1132
+ num_scheduled_generation_requests = len (
1133
+ scheduler_output .generation_requests )
1134
+ num_scheduled_tokens = sum ([
1135
+ len (req .get_tokens (0 )) for req in context_requests
1136
+ ]) + num_scheduled_generation_requests
1137
+ responses_list = self .dist .tp_allgather ([
1138
+ num_scheduled_context_requests ,
1139
+ num_scheduled_generation_requests , num_scheduled_tokens
1140
+ ])
1141
+ all_ranks_num_scheduled_context_requests = [
1142
+ response [0 ] for response in responses_list
1143
+ ]
1144
+ all_ranks_num_scheduled_generation_requests = [
1145
+ response [1 ] for response in responses_list
1146
+ ]
1147
+ all_ranks_num_scheduled_tokens = [
1148
+ response [2 ] for response in responses_list
1149
+ ]
1150
+
1151
+ all_ranks_have_free_ctx_slots = all ([
1152
+ num_gen < self .max_batch_size
1153
+ for num_gen in all_ranks_num_scheduled_generation_requests
1154
+ ])
1155
+ all_ranks_have_multi_gen = all ([
1156
+ num_gen > 1
1157
+ for num_gen in all_ranks_num_scheduled_generation_requests
1158
+ ])
1159
+ all_ranks_have_ctx_requests = all ([
1160
+ num_ctx > 0
1161
+ for num_ctx in all_ranks_num_scheduled_context_requests
1162
+ ])
1163
+
1164
+ all_ranks_have_gen_requests = all ([
1165
+ num_gen > 0
1166
+ for num_gen in all_ranks_num_scheduled_generation_requests
1167
+ ])
1168
+
1169
+ if self .use_attention_dp_config :
1170
+ # wait until all ranks have context requests
1171
+ if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests :
1172
+ self .adp_ctx_waiting_iters_count = 0
1173
+ # balance number of context requests across ranks
1174
+ if all_ranks_have_gen_requests :
1175
+ if self .adp_ctx_batching_wait_iters_count < self .attention_dp_batching_wait_iters :
1176
+ self .adp_ctx_batching_wait_iters_count += 1
1177
+ context_requests = []
1178
+ else :
1179
+ self .adp_ctx_batching_wait_iters_count = 0
1180
+ else :
1181
+ self .adp_ctx_waiting_iters_count += 1
1182
+ context_requests = []
1183
+ if self .adp_ctx_waiting_iters_count >= self .attention_dp_time_out_iters or not all_ranks_have_gen_requests :
1184
+ self .adp_ctx_waiting_iters_count = 0
1185
+ context_requests = scheduler_output .context_requests
1122
1186
1123
- scheduled_requests .context_requests = scheduler_output . context_requests
1187
+ scheduled_requests .context_requests = context_requests
1124
1188
scheduled_requests .generation_requests = scheduler_output .generation_requests
1125
1189
scheduled_requests .paused_requests = scheduler_output .paused_requests
1126
1190
return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests
0 commit comments