@@ -152,10 +152,13 @@ class PipelineConfig:
152
152
emb_lookup_stream (str): The stream to use for embedding lookups.
153
153
Only used by certain pipeline types (e.g., "fused").
154
154
Default is "data_dist".
155
+ apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
156
+ Default is False.
155
157
"""
156
158
157
159
pipeline : str = "base"
158
160
emb_lookup_stream : str = "data_dist"
161
+ apply_jit : bool = False
159
162
160
163
161
164
@dataclass
@@ -164,6 +167,7 @@ class ModelSelectionConfig:
164
167
165
168
# Common config for all model types
166
169
batch_size : int = 8192
170
+ batch_sizes : Optional [List [int ]] = None
167
171
num_float_features : int = 10
168
172
feature_pooling_avg : int = 10
169
173
use_offsets : bool = False
@@ -216,6 +220,7 @@ def main(
216
220
model_config = create_model_config (
217
221
model_name = model_selection .model_name ,
218
222
batch_size = model_selection .batch_size ,
223
+ batch_sizes = model_selection .batch_sizes ,
219
224
num_float_features = model_selection .num_float_features ,
220
225
feature_pooling_avg = model_selection .feature_pooling_avg ,
221
226
use_offsets = model_selection .use_offsets ,
@@ -282,6 +287,15 @@ def runner(
282
287
compute_device = ctx .device .type ,
283
288
)
284
289
290
+ batch_sizes = model_config .batch_sizes
291
+
292
+ if batch_sizes is None :
293
+ batch_sizes = [model_config .batch_size ] * run_option .num_batches
294
+ else :
295
+ assert (
296
+ len (batch_sizes ) == run_option .num_batches
297
+ ), "The length of batch_sizes must match the number of batches."
298
+
285
299
# Create a planner for sharding based on the specified type
286
300
planner = generate_planner (
287
301
planner_type = run_option .planner_type ,
@@ -290,16 +304,15 @@ def runner(
290
304
weighted_tables = weighted_tables ,
291
305
sharding_type = run_option .sharding_type ,
292
306
compute_kernel = run_option .compute_kernel ,
293
- num_batches = run_option .num_batches ,
294
- batch_size = model_config .batch_size ,
307
+ batch_sizes = batch_sizes ,
295
308
pooling_factors = run_option .pooling_factors ,
296
309
num_poolings = run_option .num_poolings ,
297
310
)
298
311
bench_inputs = generate_data (
299
312
tables = tables ,
300
313
weighted_tables = weighted_tables ,
301
314
model_config = model_config ,
302
- num_batches = run_option . num_batches ,
315
+ batch_sizes = batch_sizes ,
303
316
)
304
317
305
318
# Prepare fused_params for sparse optimizer
@@ -329,14 +342,6 @@ def runner(
329
342
dense_weight_decay = run_option .dense_weight_decay ,
330
343
planner = planner ,
331
344
)
332
- pipeline = generate_pipeline (
333
- pipeline_type = pipeline_config .pipeline ,
334
- emb_lookup_stream = pipeline_config .emb_lookup_stream ,
335
- model = sharded_model ,
336
- opt = optimizer ,
337
- device = ctx .device ,
338
- )
339
- pipeline .progress (iter (bench_inputs ))
340
345
341
346
def _func_to_benchmark (
342
347
bench_inputs : List [ModelInput ],
@@ -350,20 +355,47 @@ def _func_to_benchmark(
350
355
except StopIteration :
351
356
break
352
357
353
- result = benchmark_func (
354
- name = type (pipeline ).__name__ ,
355
- bench_inputs = bench_inputs , # pyre-ignore
356
- prof_inputs = bench_inputs , # pyre-ignore
357
- num_benchmarks = 5 ,
358
- num_profiles = 2 ,
359
- profile_dir = run_option .profile ,
360
- world_size = run_option .world_size ,
361
- func_to_benchmark = _func_to_benchmark ,
362
- benchmark_func_kwargs = {"model" : sharded_model , "pipeline" : pipeline },
363
- rank = rank ,
358
+ # Run comparison if apply_jit is True, otherwise run single benchmark
359
+ jit_configs = (
360
+ [(True , "WithJIT" ), (False , "WithoutJIT" )]
361
+ if pipeline_config .apply_jit
362
+ else [(False , "" )]
364
363
)
364
+ results = []
365
+
366
+ for apply_jit , jit_suffix in jit_configs :
367
+ pipeline = generate_pipeline (
368
+ pipeline_type = pipeline_config .pipeline ,
369
+ emb_lookup_stream = pipeline_config .emb_lookup_stream ,
370
+ model = sharded_model ,
371
+ opt = optimizer ,
372
+ device = ctx .device ,
373
+ apply_jit = apply_jit ,
374
+ )
375
+ pipeline .progress (iter (bench_inputs ))
376
+
377
+ name = (
378
+ f"{ type (pipeline ).__name__ } { jit_suffix } "
379
+ if jit_suffix
380
+ else type (pipeline ).__name__
381
+ )
382
+ result = benchmark_func (
383
+ name = name ,
384
+ bench_inputs = bench_inputs , # pyre-ignore
385
+ prof_inputs = bench_inputs , # pyre-ignore
386
+ num_benchmarks = 5 ,
387
+ num_profiles = 2 ,
388
+ profile_dir = run_option .profile ,
389
+ world_size = run_option .world_size ,
390
+ func_to_benchmark = _func_to_benchmark ,
391
+ benchmark_func_kwargs = {"model" : sharded_model , "pipeline" : pipeline },
392
+ rank = rank ,
393
+ )
394
+ results .append (result )
395
+
365
396
if rank == 0 :
366
- print (result )
397
+ for result in results :
398
+ print (result )
367
399
368
400
369
401
if __name__ == "__main__" :
0 commit comments