@@ -24,6 +24,7 @@ class TestAscendSchedulerConfig(TestBase):
24
24
def setUp (self ):
25
25
self .basic_scheduler_config = SchedulerConfig (
26
26
max_num_batched_tokens = 8192 ,
27
+ max_model_len = 8192 ,
27
28
is_multimodal_model = False ,
28
29
send_delta_data = False ,
29
30
scheduler_delay_factor = 0 ,
@@ -51,6 +52,7 @@ def test_initialize_from_config_with_override(self):
51
52
num_scheduler_steps = 1 ,
52
53
scheduler_cls = "vllm_ascend.core.scheduler.AscendScheduler" ,
53
54
max_num_batched_tokens = 2048 ,
55
+ max_model_len = 2048 ,
54
56
),
55
57
)
56
58
self .assertEqual (ascend_config .enable_chunked_prefill , False )
@@ -65,7 +67,11 @@ def test_not_implemented_policy(self):
65
67
with self .assertRaises (NotImplementedError ) as context :
66
68
AscendSchedulerConfig .initialize_from_config (
67
69
self .basic_scheduler_config ,
68
- AscendSchedulerConfig (policy = "custom_policy" , ),
70
+ AscendSchedulerConfig (
71
+ policy = "custom_policy" ,
72
+ max_num_batched_tokens = 2048 ,
73
+ max_model_len = 2048 ,
74
+ ),
69
75
)
70
76
self .assertIn (
71
77
"currently AscendScheduler only supports fcfs policy" ,
@@ -83,7 +89,11 @@ def test_not_implemented_multi_step(self):
83
89
with self .assertRaises (NotImplementedError ) as context :
84
90
AscendSchedulerConfig .initialize_from_config (
85
91
self .basic_scheduler_config ,
86
- AscendSchedulerConfig (num_scheduler_steps = 2 ),
92
+ AscendSchedulerConfig (
93
+ num_scheduler_steps = 2 ,
94
+ max_num_batched_tokens = 2048 ,
95
+ max_model_len = 2048 ,
96
+ ),
87
97
)
88
98
self .assertIn (
89
99
"currently AscendScheduler doesn't support multi-step" ,
@@ -94,7 +104,12 @@ def test_not_implemented_send_delta_data(self):
94
104
with self .assertRaises (NotImplementedError ) as context :
95
105
AscendSchedulerConfig .initialize_from_config (
96
106
self .basic_scheduler_config ,
97
- AscendSchedulerConfig (send_delta_data = True ))
107
+ AscendSchedulerConfig (
108
+ send_delta_data = True ,
109
+ max_num_batched_tokens = 2048 ,
110
+ max_model_len = 2048 ,
111
+ ),
112
+ )
98
113
self .assertIn (
99
114
"currently AscendScheduler doesn't support send_delta_data" ,
100
115
str (context .exception ),
@@ -104,7 +119,12 @@ def test_not_implemented_delay_factor(self):
104
119
with self .assertRaises (NotImplementedError ) as context :
105
120
AscendSchedulerConfig .initialize_from_config (
106
121
self .basic_scheduler_config ,
107
- AscendSchedulerConfig (delay_factor = 1 ))
122
+ AscendSchedulerConfig (
123
+ delay_factor = 1 ,
124
+ max_num_batched_tokens = 2048 ,
125
+ max_model_len = 2048 ,
126
+ ),
127
+ )
108
128
self .assertIn (
109
129
"currently AscendScheduler doesn't support scheduler_delay_factor" ,
110
130
str (context .exception ),
@@ -115,3 +135,33 @@ def test_no_override(self):
115
135
self .basic_scheduler_config , {})
116
136
self .assertEqual (ascend_config .max_num_encoder_input_tokens , 8192 )
117
137
self .assertEqual (ascend_config .encoder_cache_size , 8192 )
138
+
139
def test_valid_config_with_chunked_prefill(self):
    """Accept an override that turns chunked prefill on.

    The override asks for ``max_num_batched_tokens=2048`` together with
    ``max_model_len=4096``; the resulting config must report exactly
    those values and have ``enable_chunked_prefill`` set.
    """
    override = AscendSchedulerConfig(
        enable_chunked_prefill=True,
        max_num_batched_tokens=2048,
        max_model_len=4096,
    )
    merged = AscendSchedulerConfig.initialize_from_config(
        self.basic_scheduler_config, override)

    # The override values win over the 8192 defaults set up in setUp().
    self.assertEqual(merged.max_num_batched_tokens, 2048)
    self.assertEqual(merged.max_model_len, 4096)
    self.assertTrue(merged.enable_chunked_prefill)
151
+
152
def test_invalid_config_without_chunked_prefill(self):
    """Reject the same limits when chunked prefill is turned off.

    With ``enable_chunked_prefill=False``, ``max_num_batched_tokens=2048``
    and ``max_model_len=4096`` a ``ValueError`` is expected, and its
    message must name the feature and both offending values.
    """
    expected_fragments = (
        "Ascend scheduler is enabled without chunked prefill feature",
        "max_num_batched_tokens (2048)",
        "max_model_len (4096)",
    )

    with self.assertRaises(ValueError) as context:
        # Keep the construction inside the ``with`` block so the error is
        # captured regardless of which call performs the validation.
        override = AscendSchedulerConfig(
            enable_chunked_prefill=False,
            max_num_batched_tokens=2048,
            max_model_len=4096,
        )
        AscendSchedulerConfig.initialize_from_config(
            self.basic_scheduler_config, override)

    message = str(context.exception)
    for fragment in expected_fragments:
        self.assertIn(fragment, message)
0 commit comments