40 changes: 40 additions & 0 deletions torchtitan/components/loss.py
@@ -78,3 +78,43 @@ def build_mse_loss(job_config: JobConfig, **kwargs):
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
return loss_fn


def moe_loss(
pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
labels: torch.Tensor,
loss_fn: LossFunction,
) -> torch.Tensor:
"""Sequence-wise (or batch-wise) auxiliary load balance loss function for MoE
model training.
"""
if isinstance(pred, tuple):
pred, load_balance_loss = pred
loss = loss_fn(pred, labels)
# Add the auxiliary loss to the computation graph so it contributes gradients in the
# backward pass, but cancel out its numeric value so the forward pass reports only
# the language-model task loss.
loss = loss + (load_balance_loss - load_balance_loss.detach())
else:
loss = loss_fn(pred, labels)
return loss


def moe_loss_wrap(original_build_fn):
"""
Wraps a loss builder function. It builds the base loss function first,
then wraps it with the 'moe_loss' logic before returning it.
"""

@functools.wraps(original_build_fn) # Preserves name/docstring of original
def wrapper(job_config, **kwargs):
# 1. Create the base loss function (e.g., standard CrossEntropy)
# We pass job_config and kwargs through exactly as the original expects
base_loss_fn = original_build_fn(job_config, **kwargs)

# 2. Apply the MoE wrapper immediately
# This binds 'base_loss_fn' to the 'loss_fn' argument of 'moe_loss'
final_loss_fn = functools.partial(moe_loss, loss_fn=base_loss_fn)

return final_loss_fn

return wrapper
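
For context on the `loss + (load_balance_loss - load_balance_loss.detach())` line above: the surrogate term is identically zero in the forward pass, so the logged loss stays equal to the task loss, while its gradient is exactly the gradient of the auxiliary loss. A standalone sketch (toy tensors, not part of this PR) that checks both properties:

```python
# Standalone sketch (not from this PR): the surrogate term adds nothing to the
# forward value but still routes gradients through the auxiliary loss.
import torch
import torch.nn.functional as F

pred = torch.randn(4, 8, requires_grad=True)
labels = torch.randint(0, 8, (4,))
task_loss = F.cross_entropy(pred, labels)

aux = torch.randn((), requires_grad=True)   # stand-in for load_balance_loss
total = task_loss + (aux - aux.detach())

print(torch.allclose(total, task_loss))     # True: forward value unchanged
total.backward()
print(aux.grad)                             # tensor(1.): aux still receives gradients
```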
8 changes: 4 additions & 4 deletions torchtitan/components/optimizer.py
@@ -344,8 +344,8 @@ def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
for model_part in model_parts:
for transformer_block in model_part.layers.values():
if transformer_block.moe_enabled:
# Assumption: load_balance_coeff is set universally on all moe blocks.
return bool(transformer_block.moe.load_balance_coeff)
# Assumption: moe_aux_loss_free_bias_coeff is set universally on all moe blocks.
return bool(transformer_block.moe.moe_aux_loss_free_bias_coeff)
return False

# for MoE auxiliary-loss-free load balancing
@@ -366,7 +366,7 @@ def _update_expert_bias(
for transformer_block in model_part.layers.values():
if not transformer_block.moe_enabled:
continue
if transformer_block.moe.load_balance_coeff is None:
if transformer_block.moe.moe_aux_loss_free_bias_coeff is None:
return
tokens_per_expert = transformer_block.moe.tokens_per_expert
if _is_recomputation_enabled(transformer_block):
@@ -401,7 +401,7 @@ def _update_expert_bias(

# update the expert bias
# this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed
expert_bias_delta = moe.load_balance_coeff * torch.sign(
expert_bias_delta = moe.moe_aux_loss_free_bias_coeff * torch.sign(
tokens_per_expert.mean() - tokens_per_expert
)
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
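
The update rule above nudges the routing bias up for under-loaded experts and down for over-loaded ones, then re-centers the deltas; note that with `moe_aux_loss_free_bias_coeff` unset (None or 0), `_should_register_moe_balancing_hook` returns False and the hook is never registered. A minimal numeric sketch with made-up token counts:

```python
# Minimal numeric sketch of the sign-based bias update (expert counts are made up).
import torch

tokens_per_expert = torch.tensor([10.0, 30.0, 20.0, 20.0])
coeff = 1e-3                                 # moe_aux_loss_free_bias_coeff

delta = coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
delta = delta - delta.mean()                 # keep the bias zero-centered

print(delta)  # tensor([0.0010, -0.0010, 0.0000, 0.0000]):
              # under-loaded expert 0 is pushed up, over-loaded expert 1 pushed down
```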
15 changes: 15 additions & 0 deletions torchtitan/config/job_config.py
@@ -198,6 +198,18 @@ class LRScheduler:
"""


@dataclass
class ExtraLosses:
load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
"""Type of load balance loss to use"""

load_balance_loss_weight: float = 0
"""Weight of load balance loss"""

moe_aux_loss_free_bias_coeff: float | None = 1e-3
"""The coefficient of the bias update for aux-loss-free load balancing"""


@dataclass
class Training:
dataset: str = "c4_test"
@@ -226,6 +238,9 @@ class Training:
steps: int = 10000
"""How many train steps to run"""

extra_losses: ExtraLosses = field(default_factory=ExtraLosses)
"""If we have multiple of losses, we can configure their weights here"""

enable_cpu_offload: bool = False
"""
Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
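
For context on what `load_balance_loss_type` and `load_balance_loss_weight` control: the MoE module (not shown in this diff) computes an auxiliary balance loss either per sequence or over the whole batch. Below is a hedged sketch of one common sequence-wise formulation (DeepSeek-V3 style); the function name, shapes, and exact scaling are assumptions for illustration, not torchtitan's implementation:

```python
# Hedged sketch of a sequence-wise load-balance loss (DeepSeek-V3 style).
# Not torchtitan's implementation; names and scaling are assumptions.
import torch

def sequence_wise_balance_loss(
    router_probs: torch.Tensor,  # [T, E] softmax router scores for one sequence
    topk_idx: torch.Tensor,      # [T, K] expert indices chosen per token
    weight: float,               # load_balance_loss_weight
) -> torch.Tensor:
    T, E = router_probs.shape
    K = topk_idx.shape[1]
    counts = torch.zeros(E).scatter_add_(
        0, topk_idx.reshape(-1), torch.ones(T * K)
    )
    f = counts * E / (K * T)          # scaled fraction of tokens routed to each expert
    p = router_probs.mean(dim=0)      # mean router probability per expert
    return weight * torch.sum(f * p)  # minimized when the load is uniform

# The "batch_wise" variant would compute the same statistics over all sequences
# in the batch instead of per sequence.
```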
6 changes: 3 additions & 3 deletions torchtitan/experiments/gpt_oss/__init__.py
@@ -38,7 +38,7 @@
score_before_experts=False,
top_k=4,
use_grouped_mm=True,
load_balance_coeff=1e-3,
moe_aux_loss_free_bias_coeff=1e-3,
),
attn_mask_type="causal",
),
@@ -53,7 +53,7 @@
score_before_experts=False,
top_k=4,
use_grouped_mm=True,
load_balance_coeff=1e-3,
moe_aux_loss_free_bias_coeff=1e-3,
),
),
"120b": GptOssModelArgs(
@@ -67,7 +67,7 @@
score_before_experts=False,
top_k=4,
use_grouped_mm=True,
load_balance_coeff=1e-3,
moe_aux_loss_free_bias_coeff=1e-3,
),
),
}
4 changes: 2 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
@@ -167,6 +167,6 @@ def get_train_spec() -> TrainSpec:
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_text_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_loss_fn=moe_loss_wrap(build_cross_entropy_loss),
state_dict_adapter=DeepSeekV3StateDictAdapter,
)
7 changes: 7 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -95,6 +95,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

losses_config = job_config.training.extra_losses
self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
self.moe_args.moe_aux_loss_free_bias_coeff = (
losses_config.moe_aux_loss_free_bias_coeff
)

if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
25 changes: 20 additions & 5 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -309,6 +309,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
accumulated_load_balance_loss: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
@@ -323,10 +324,15 @@
"""
x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x))
accumulated_load_balance_loss = (
accumulated_load_balance_loss + load_balance_loss
)
else:
x = x + self.feed_forward(self.ffn_norm(x))
return x
ffn_moe_output = self.feed_forward(self.ffn_norm(x))

x = x + ffn_moe_output
return x, accumulated_load_balance_loss

def init_weights(self, buffer_device: torch.device):
for norm in (self.attention_norm, self.ffn_norm):
@@ -410,6 +416,7 @@ def get_attention_masks(
def forward(
self,
tokens: torch.Tensor,
accumulated_load_balance_loss: torch.Tensor | None = None,
attention_masks: AttentionMasksType | None = None,
):
"""
@@ -427,8 +434,16 @@

h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

accumulated_load_balance_loss = (
torch.zeros((), device=h.device, dtype=torch.float32)
if accumulated_load_balance_loss is None
else accumulated_load_balance_loss
)

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks)
h, accumulated_load_balance_loss = layer(
h, self.freqs_cis, accumulated_load_balance_loss, attention_masks
)
h = self.norm(h) if self.norm is not None else h
output = self.output(h) if self.output is not None else h
return output
return output, accumulated_load_balance_loss
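
Tying this to the loss changes earlier in the diff (a hedged, self-contained sketch with a toy model, not the torchtitan trainer): the model forward now returns `(logits, accumulated_load_balance_loss)`, and the loss function produced by `moe_loss_wrap(build_cross_entropy_loss)` consumes that tuple directly, so the training loop's `loss_fn(model_output, labels)` call pattern is unchanged.

```python
# Hedged end-to-end sketch (toy model and loss, not torchtitan's trainer):
# the model returns (logits, aux_loss) and the wrapped loss accepts the tuple.
import functools
import torch
import torch.nn.functional as F

def base_ce(pred, labels):
    return F.cross_entropy(pred.flatten(0, 1), labels.flatten())

def moe_loss(pred, labels, loss_fn):          # same shape as the function added above
    if isinstance(pred, tuple):
        pred, aux = pred
        return loss_fn(pred, labels) + (aux - aux.detach())
    return loss_fn(pred, labels)

loss_fn = functools.partial(moe_loss, loss_fn=base_ce)   # what moe_loss_wrap returns

class ToyMoE(torch.nn.Module):
    def __init__(self, vocab: int = 64, dim: int = 16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)
        self.head = torch.nn.Linear(dim, vocab)

    def forward(self, tokens):
        h = self.emb(tokens)
        aux = h.square().mean()               # stand-in for the accumulated balance loss
        return self.head(h), aux

model = ToyMoE()
tokens = torch.randint(0, 64, (2, 8))
labels = torch.randint(0, 64, (2, 8))
loss = loss_fn(model(tokens), labels)         # tuple output handled transparently
loss.backward()
```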
4 changes: 2 additions & 2 deletions torchtitan/models/llama4/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
@@ -112,7 +112,7 @@ def get_train_spec() -> TrainSpec:
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_text_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_loss_fn=moe_loss_wrap(build_cross_entropy_loss),
build_validator_fn=build_validator,
state_dict_adapter=Llama4StateDictAdapter,
)
7 changes: 7 additions & 0 deletions torchtitan/models/llama4/model/args.py
@@ -70,6 +70,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

losses_config = job_config.training.extra_losses
self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
self.moe_args.moe_aux_loss_free_bias_coeff = (
losses_config.moe_aux_loss_free_bias_coeff
)

if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",