@@ -176,6 +176,9 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.use_attention_dp_config = model_engine.pytorch_backend_config.use_attention_dp_config
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -214,6 +217,9 @@ def __init__(self,
         self.draft_model_engine.warmup(self.resource_manager)

         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters_count = 0
+        self.adp_ctx_batching_wait_iters_count = 0

         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1088,8 +1094,66 @@ def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
         scheduled_requests = ScheduledRequests()
+        context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp:
+            num_scheduled_context_requests = len(
+                scheduler_output.context_requests)
+            num_scheduled_generation_requests = len(
+                scheduler_output.generation_requests)
+            num_scheduled_tokens = sum([
+                len(req.get_tokens(0)) for req in context_requests
+            ]) + num_scheduled_generation_requests
+            responses_list = self.dist.tp_allgather([
+                num_scheduled_context_requests,
+                num_scheduled_generation_requests, num_scheduled_tokens
+            ])
+            all_ranks_num_scheduled_context_requests = [
+                response[0] for response in responses_list
+            ]
+            all_ranks_num_scheduled_generation_requests = [
+                response[1] for response in responses_list
+            ]
+            all_ranks_num_scheduled_tokens = [
+                response[2] for response in responses_list
+            ]
+
+            all_ranks_have_free_ctx_slots = all([
+                num_gen < self.max_batch_size
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_multi_gen = all([
+                num_gen > 1
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+            all_ranks_have_ctx_requests = all([
+                num_ctx > 0
+                for num_ctx in all_ranks_num_scheduled_context_requests
+            ])
+
+            all_ranks_have_gen_requests = all([
+                num_gen > 0
+                for num_gen in all_ranks_num_scheduled_generation_requests
+            ])
+
+            if self.use_attention_dp_config:
+                # Wait for all ranks to have context requests
+                if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                    self.adp_ctx_waiting_iters_count = 0
+                    # Balance the number of context requests across ranks
+                    if all_ranks_have_gen_requests:
+                        if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
+                            self.adp_ctx_batching_wait_iters_count += 1
+                            context_requests = []
+                        else:
+                            self.adp_ctx_batching_wait_iters_count = 0
+                else:
+                    self.adp_ctx_waiting_iters_count += 1
+                    context_requests = []
+                    if self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters or not all_ranks_have_gen_requests:
+                        self.adp_ctx_waiting_iters_count = 0
+                        context_requests = scheduler_output.context_requests

-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests.context_requests = context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
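For reference, the hold-or-release decision added to `_schedule` can be read in isolation as a pure function over the already all-gathered per-rank counts. The sketch below is illustrative only and is not part of the patch; the names `AdpState` and `decide_context_requests` are hypothetical, and it assumes the counts were exchanged via `tp_allgather` exactly as in the hunk above.

# Illustrative sketch (not part of the patch): the attention-DP gating logic
# from _schedule, expressed as a standalone function. Names are hypothetical.
from dataclasses import dataclass
from typing import List


@dataclass
class AdpState:
    waiting_iters: int = 0        # mirrors adp_ctx_waiting_iters_count
    batching_wait_iters: int = 0  # mirrors adp_ctx_batching_wait_iters_count


def decide_context_requests(ctx_counts: List[int], gen_counts: List[int],
                            max_batch_size: int, time_out_iters: int,
                            batching_wait_iters: int, state: AdpState) -> bool:
    """Return True if this iteration should schedule its context requests."""
    all_have_free_slots = all(n < max_batch_size for n in gen_counts)
    all_have_ctx = all(n > 0 for n in ctx_counts)
    all_have_gen = all(n > 0 for n in gen_counts)

    if all_have_free_slots and all_have_ctx:
        state.waiting_iters = 0
        if all_have_gen:
            if state.batching_wait_iters < batching_wait_iters:
                # Hold context requests for a few iterations so ranks
                # accumulate comparable amounts of context work.
                state.batching_wait_iters += 1
                return False
            state.batching_wait_iters = 0
        return True

    state.waiting_iters += 1
    if state.waiting_iters >= time_out_iters or not all_have_gen:
        # Timed out waiting on the other ranks: schedule what we have anyway.
        state.waiting_iters = 0
        return True
    return False

When the function returns False, the rank drops its context requests for the current iteration, matching `context_requests = []` in the patch; on timeout it falls back to `scheduler_output.context_requests`.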