7 changes: 3 additions & 4 deletions scripts/generate/test_generate.py
@@ -113,7 +113,7 @@ def test_generate(
logger.info(f"Init model on init_device: {init_device}")
model = train_spec.model_cls(model_args)

-world_mesh = None
+parallel_dims = None
# Init distributed env
if world_size > 1:
dist_utils.init_distributed(config.comm)
@@ -127,15 +127,14 @@ def test_generate(
etp=1,
world_size=world_size,
)
-world_mesh = parallel_dims.world_mesh

# apply_tp (with Sequence Parallel) on unevenly sharded
# sequences would require https://github.com/pytorch/torchtitan/pull/686
-apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])
+apply_tp_minus_sp(model, parallel_dims.get_mesh("tp"))

debug_config = DebugConfig(seed=seed, deterministic=deterministic)
dist_utils.set_determinism(
-world_mesh=world_mesh,
+parallel_dims=parallel_dims,
device=device,
debug_config=debug_config,
distinct_seed_mesh_dims=["pp"],
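Below is a minimal, self-contained sketch of the pattern this diff moves toward: call sites keep a ParallelDims-style container and request named sub-meshes via get_mesh(...), instead of holding a separate world_mesh handle and passing it around. The FakeParallelDims class and the set_determinism stub here are illustrative assumptions for this sketch only, not torchtitan's real ParallelDims or dist_utils.set_determinism.

from dataclasses import dataclass, field


@dataclass
class FakeParallelDims:
    # Illustrative stand-in: the real ParallelDims builds device meshes from
    # the parallelism degrees (tp, pp, etp, ...) given to its constructor.
    meshes: dict = field(default_factory=dict)

    def get_mesh(self, name: str):
        # Hand out the named sub-mesh (e.g. "tp" or "pp") on demand.
        return self.meshes.get(name)


def set_determinism(parallel_dims=None, seed=0, distinct_seed_mesh_dims=()):
    # After the refactor, helpers receive the container and look up whichever
    # meshes they need, rather than being handed world_mesh directly.
    for dim in distinct_seed_mesh_dims:
        mesh = parallel_dims.get_mesh(dim) if parallel_dims else None
        print(f"seed={seed}, distinct seed on {dim!r} mesh={mesh!r}")


if __name__ == "__main__":
    pd = FakeParallelDims(meshes={"tp": "tp-mesh", "pp": "pp-mesh"})
    set_determinism(parallel_dims=pd, seed=42, distinct_seed_mesh_dims=["pp"])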