
Commit da486e3

SSYernar authored and facebook-github-bot committed
Add JIT and Variable Batch Support to Benchmark (#3131)
Summary:
Pull Request resolved: #3131

This update introduces an option to apply Just-In-Time (JIT) compilation in the training pipeline configuration, so runs with and without JIT can be compared. It also adds support for variable batch sizes, including the generation of Variable Batch KeyedJaggedTensor (VB-KJT) inputs.

Reviewed By: aliafzal

Differential Revision: D76928208

fbshipit-source-id: 0967e04d8d671c099e8f1ac9585034e8c8b124f9
1 parent 3790c26 commit da486e3

File tree: 2 files changed, +78 −37 lines changed

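Before the diffs, a minimal sketch of how the two new knobs fit together. The dataclass fields mirror PipelineConfig and ModelSelectionConfig from this commit, but the classes below are simplified stand-ins rather than torchrec imports, so treat this as an illustration of the intended configuration, not the benchmark's actual entry point.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PipelineConfig:  # simplified stand-in for the real config
    pipeline: str = "base"
    emb_lookup_stream: str = "data_dist"
    apply_jit: bool = False  # new: benchmark with and without JIT


@dataclass
class ModelSelectionConfig:  # simplified stand-in for the real config
    batch_size: int = 8192
    batch_sizes: Optional[List[int]] = None  # new: per-batch sizes for variable-batch runs


# Fixed-size run with the JIT comparison enabled:
pipeline_cfg = PipelineConfig(pipeline="sparse", apply_jit=True)
model_cfg = ModelSelectionConfig(batch_size=4096)

# Variable-batch run: one entry per benchmark batch (must match num_batches).
vb_model_cfg = ModelSelectionConfig(batch_sizes=[8192, 4096, 2048, 1024, 512])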

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 23 additions & 14 deletions
@@ -62,6 +62,7 @@ class BaseModelConfig(ABC):

     # Common parameters for all model types
     batch_size: int
+    batch_sizes: Optional[List[int]]
     num_float_features: int
     feature_pooling_avg: int
     use_offsets: bool
@@ -283,6 +284,7 @@ def generate_pipeline(
     model: nn.Module,
     opt: torch.optim.Optimizer,
     device: torch.device,
+    apply_jit: bool = False,
 ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
     """
     Generate a training pipeline instance based on the configuration.
@@ -303,6 +305,8 @@
         model (nn.Module): The model to be trained.
         opt (torch.optim.Optimizer): The optimizer to use for training.
         device (torch.device): The device to run the training on.
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.

     Returns:
         Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
@@ -324,20 +328,28 @@

     if pipeline_type == "semi":
         return TrainPipelineSemiSync(
-            model=model, optimizer=opt, device=device, start_batch=0
+            model=model,
+            optimizer=opt,
+            device=device,
+            start_batch=0,
+            apply_jit=apply_jit,
         )
     elif pipeline_type == "fused":
         return TrainPipelineFusedSparseDist(
             model=model,
             optimizer=opt,
             device=device,
             emb_lookup_stream=emb_lookup_stream,
+            apply_jit=apply_jit,
         )
-    elif pipeline_type in _pipeline_cls:
-        Pipeline = _pipeline_cls[pipeline_type]
-        return Pipeline(model=model, optimizer=opt, device=device)
+    elif pipeline_type == "base":
+        assert apply_jit is False, "JIT is not supported for base pipeline"
+
+        return TrainPipelineBase(model=model, optimizer=opt, device=device)
     else:
-        raise RuntimeError(f"unknown pipeline option {pipeline_type}")
+        Pipeline = _pipeline_cls[pipeline_type]
+        # pyre-ignore[28]
+        return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)


 def generate_planner(
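The reshuffled dispatch above reads as: explicit branches for pipeline types that take extra arguments ("semi", "fused"), a guarded "base" branch that rejects JIT, and a class-table fallback for everything else. A self-contained sketch of that control flow follows; the stub classes and table keys are illustrative assumptions, and note that an unknown pipeline type now surfaces as a KeyError from the table lookup rather than the old explicit RuntimeError.

from typing import Dict, Type


class StubPipeline:
    """Stand-in for the table-dispatched pipelines, which accept apply_jit."""

    def __init__(self, apply_jit: bool = False) -> None:
        self.apply_jit = apply_jit


class StubBasePipeline:
    """Stand-in for TrainPipelineBase, which has no JIT code path."""


_pipeline_cls: Dict[str, Type[StubPipeline]] = {
    "sparse": StubPipeline,  # illustrative keys, not the real table
    "prefetch": StubPipeline,
}


def generate_pipeline_sketch(pipeline_type: str, apply_jit: bool = False):
    if pipeline_type == "base":
        # Mirrors the new assert: the base pipeline cannot be JIT-compiled.
        assert apply_jit is False, "JIT is not supported for base pipeline"
        return StubBasePipeline()
    else:
        Pipeline = _pipeline_cls[pipeline_type]  # unknown types raise KeyError
        return Pipeline(apply_jit=apply_jit)


print(type(generate_pipeline_sketch("sparse", apply_jit=True)).__name__)  # StubPipeline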
@@ -347,8 +359,7 @@ def generate_planner(
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
     compute_kernel: EmbeddingComputeKernel,
-    num_batches: int,
-    batch_size: int,
+    batch_sizes: List[int],
     pooling_factors: Optional[List[float]],
     num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
@@ -362,8 +373,7 @@
         weighted_tables: List of weighted embedding tables
         sharding_type: Strategy for sharding embedding tables
         compute_kernel: Compute kernel to use for embedding tables
-        num_batches: Number of batches to process
-        batch_size: Size of each batch
+        batch_sizes: Sizes of each batch
         pooling_factors: Pooling factors for each feature of the table
         num_poolings: Number of poolings for each feature of the table

@@ -375,15 +385,14 @@
     """
     # Create parameter constraints for tables
     constraints = {}
+    num_batches = len(batch_sizes)

     if pooling_factors is None:
         pooling_factors = [POOLING_FACTOR] * num_batches

     if num_poolings is None:
         num_poolings = [NUM_POOLINGS] * num_batches

-    batch_sizes = [batch_size] * num_batches
-
     assert (
         len(pooling_factors) == num_batches and len(num_poolings) == num_batches
     ), "The length of pooling_factors and num_poolings must match the number of batches."
@@ -497,7 +506,7 @@ def generate_data(
     tables: List[EmbeddingBagConfig],
     weighted_tables: List[EmbeddingBagConfig],
    model_config: BaseModelConfig,
-    num_batches: int,
+    batch_sizes: List[int],
 ) -> List[ModelInput]:
     """
     Generate model input data for benchmarking.
@@ -515,7 +524,7 @@

     return [
         ModelInput.generate(
-            batch_size=model_config.batch_size,
+            batch_size=batch_size,
             tables=tables,
             weighted_tables=weighted_tables,
             num_float_features=model_config.num_float_features,
@@ -533,5 +542,5 @@ def generate_data(
             ),
             pin_memory=model_config.pin_memory,
         )
-        for _ in range(num_batches)
+        for batch_size in batch_sizes
     ]
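The generate_data hunks above replace `for _ in range(num_batches)` with `for batch_size in batch_sizes`, so every generated batch can have its own size; this is what feeds the variable-batch (VB-KJT) runs mentioned in the summary. A toy illustration, with DummyBatch standing in for torchrec's ModelInput (which carries KeyedJaggedTensors rather than a plain tensor):

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class DummyBatch:  # stand-in for ModelInput
    float_features: torch.Tensor  # shape: (batch_size, num_float_features)


def generate_data_sketch(
    batch_sizes: List[int], num_float_features: int = 10
) -> List[DummyBatch]:
    # One batch per entry, sized by that entry, mirroring the new comprehension.
    return [
        DummyBatch(float_features=torch.rand(batch_size, num_float_features))
        for batch_size in batch_sizes
    ]


batches = generate_data_sketch([4, 2, 3])
print([b.float_features.shape[0] for b in batches])  # [4, 2, 3]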

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 55 additions & 23 deletions
@@ -152,10 +152,13 @@ class PipelineConfig:
         emb_lookup_stream (str): The stream to use for embedding lookups.
             Only used by certain pipeline types (e.g., "fused").
             Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
     """

     pipeline: str = "base"
     emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False


 @dataclass
@@ -164,6 +167,7 @@ class ModelSelectionConfig:

     # Common config for all model types
     batch_size: int = 8192
+    batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
     use_offsets: bool = False
@@ -216,6 +220,7 @@ def main(
     model_config = create_model_config(
         model_name=model_selection.model_name,
         batch_size=model_selection.batch_size,
+        batch_sizes=model_selection.batch_sizes,
         num_float_features=model_selection.num_float_features,
         feature_pooling_avg=model_selection.feature_pooling_avg,
         use_offsets=model_selection.use_offsets,
@@ -282,6 +287,15 @@ def runner(
         compute_device=ctx.device.type,
     )

+    batch_sizes = model_config.batch_sizes
+
+    if batch_sizes is None:
+        batch_sizes = [model_config.batch_size] * run_option.num_batches
+    else:
+        assert (
+            len(batch_sizes) == run_option.num_batches
+        ), "The length of batch_sizes must match the number of batches."
+
     # Create a planner for sharding based on the specified type
     planner = generate_planner(
         planner_type=run_option.planner_type,
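This hunk keeps the old behavior as a fallback: when no batch_sizes list is configured, the runner replicates the fixed batch_size num_batches times; when one is given, its length must match. A minimal sketch of that resolution step, with run_option_num_batches standing in for run_option.num_batches:

from typing import List, Optional


def resolve_batch_sizes(
    batch_sizes: Optional[List[int]],
    batch_size: int,
    run_option_num_batches: int,
) -> List[int]:
    if batch_sizes is None:
        # Fixed-size fallback: identical to the pre-change behavior.
        return [batch_size] * run_option_num_batches
    assert (
        len(batch_sizes) == run_option_num_batches
    ), "The length of batch_sizes must match the number of batches."
    return batch_sizes


print(resolve_batch_sizes(None, 8192, 3))  # [8192, 8192, 8192]
print(resolve_batch_sizes([64, 32, 16], 8192, 3))  # [64, 32, 16]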
@@ -290,16 +304,15 @@
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
-        num_batches=run_option.num_batches,
-        batch_size=model_config.batch_size,
+        batch_sizes=batch_sizes,
         pooling_factors=run_option.pooling_factors,
         num_poolings=run_option.num_poolings,
     )
     bench_inputs = generate_data(
         tables=tables,
         weighted_tables=weighted_tables,
         model_config=model_config,
-        num_batches=run_option.num_batches,
+        batch_sizes=batch_sizes,
     )

     # Prepare fused_params for sparse optimizer
@@ -329,14 +342,6 @@ def runner(
         dense_weight_decay=run_option.dense_weight_decay,
         planner=planner,
     )
-    pipeline = generate_pipeline(
-        pipeline_type=pipeline_config.pipeline,
-        emb_lookup_stream=pipeline_config.emb_lookup_stream,
-        model=sharded_model,
-        opt=optimizer,
-        device=ctx.device,
-    )
-    pipeline.progress(iter(bench_inputs))

     def _func_to_benchmark(
         bench_inputs: List[ModelInput],
@@ -350,20 +355,47 @@ def _func_to_benchmark(
             except StopIteration:
                 break

-    result = benchmark_func(
-        name=type(pipeline).__name__,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
-        num_benchmarks=5,
-        num_profiles=2,
-        profile_dir=run_option.profile,
-        world_size=run_option.world_size,
-        func_to_benchmark=_func_to_benchmark,
-        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-        rank=rank,
+    # Run comparison if apply_jit is True, otherwise run single benchmark
+    jit_configs = (
+        [(True, "WithJIT"), (False, "WithoutJIT")]
+        if pipeline_config.apply_jit
+        else [(False, "")]
     )
+    results = []
+
+    for apply_jit, jit_suffix in jit_configs:
+        pipeline = generate_pipeline(
+            pipeline_type=pipeline_config.pipeline,
+            emb_lookup_stream=pipeline_config.emb_lookup_stream,
+            model=sharded_model,
+            opt=optimizer,
+            device=ctx.device,
+            apply_jit=apply_jit,
+        )
+        pipeline.progress(iter(bench_inputs))
+
+        name = (
+            f"{type(pipeline).__name__}{jit_suffix}"
+            if jit_suffix
+            else type(pipeline).__name__
+        )
+        result = benchmark_func(
+            name=name,
+            bench_inputs=bench_inputs,  # pyre-ignore
+            prof_inputs=bench_inputs,  # pyre-ignore
+            num_benchmarks=5,
+            num_profiles=2,
+            profile_dir=run_option.profile,
+            world_size=run_option.world_size,
+            func_to_benchmark=_func_to_benchmark,
+            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+            rank=rank,
+        )
+        results.append(result)
+
     if rank == 0:
-        print(result)
+        for result in results:
+            print(result)


 if __name__ == "__main__":
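The net effect of the last hunk is that the benchmark body becomes a loop: with apply_jit enabled, the pipeline is built and measured twice, once with JIT and once without, each result is tagged with a WithJIT/WithoutJIT suffix, and rank 0 prints all results. A generic sketch of that compare-mode control flow, with a trivial timed placeholder instead of torchrec's benchmark_func and pipeline:

import time
from typing import List, Tuple

apply_jit_requested = True  # stands in for pipeline_config.apply_jit
rank = 0

# Two configs when comparing, one unnamed config otherwise.
jit_configs: List[Tuple[bool, str]] = (
    [(True, "WithJIT"), (False, "WithoutJIT")]
    if apply_jit_requested
    else [(False, "")]
)

results = []
for apply_jit, jit_suffix in jit_configs:
    # In the real runner, apply_jit is forwarded to generate_pipeline(...).
    name = f"StubPipeline{jit_suffix}" if jit_suffix else "StubPipeline"
    start = time.perf_counter()
    sum(i * i for i in range(100_000))  # placeholder for pipeline.progress(...)
    results.append((name, time.perf_counter() - start))

if rank == 0:
    for result in results:
        print(result)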
