From 59dc807bd67053485729b5bd1efc7e34b660aff2 Mon Sep 17 00:00:00 2001
From: CryptoSalamander
Date: Fri, 21 Nov 2025 08:25:53 +0000
Subject: [PATCH 1/5] Validate tokenizer and model alignment before training

---
 torchtitan/models/utils.py | 42 ++++++++++++++++++++++++++++++++++++++
 torchtitan/train.py        |  3 +++
 2 files changed, 45 insertions(+)

diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py
index addfa17421..6cab8ad056 100644
--- a/torchtitan/models/utils.py
+++ b/torchtitan/models/utils.py
@@ -468,3 +468,45 @@ def get_moe_model_nparams_and_flops(
         nparams = nparams - nparams_embedding
 
     return nparams, num_flops_per_token
+
+
+def validate_tokenizer_model_alignment(
+    tokenizer: "BaseTokenizer | None",
+    model_args: "BaseModelArgs",
+) -> None:
+    """
+    Validate that tokenizer configuration matches model configuration.
+
+    Args:
+        tokenizer: Tokenizer instance to validate. Can be None.
+        model_args: Model arguments object containing configuration to validate against.
+
+    Raises:
+        ValueError: If tokenizer and model configurations don't match.
+    """
+    if tokenizer is None:
+        return
+
+    # Validate vocab_size
+    if hasattr(model_args, "vocab_size"):
+        tokenizer_vocab_size = tokenizer.get_vocab_size()
+        model_vocab_size = model_args.vocab_size
+        if tokenizer_vocab_size != model_vocab_size:
+            raise ValueError(
+                f"Tokenizer vocab_size ({tokenizer_vocab_size}) does not match "
+                f"model vocab_size ({model_vocab_size}). "
+                f"This mismatch will cause training errors. "
+                f"Please ensure the tokenizer and model configuration are aligned."
+            )
+
+    # Validate eos_id
+    if hasattr(model_args, "eos_id"):
+        tokenizer_eos_id = getattr(tokenizer, "eos_id", None)
+        model_eos_id = model_args.eos_id
+        if tokenizer_eos_id is not None and tokenizer_eos_id != model_eos_id:
+            raise ValueError(
+                f"Tokenizer eos_id ({tokenizer_eos_id}) does not match "
+                f"model eos_id ({model_eos_id}). "
+                f"This mismatch may cause training errors. "
+                f"Please ensure the tokenizer and model configuration are aligned."
+            )
diff --git a/torchtitan/train.py b/torchtitan/train.py
index d157a3a307..082b24e55a 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -25,6 +25,7 @@
 )
 from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.models.utils import validate_tokenizer_model_alignment
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
@@ -134,6 +135,8 @@ def __init__(self, job_config: JobConfig):
         model_args.update_from_config(job_config)
         self.model_args = model_args
 
+        validate_tokenizer_model_alignment(self.tokenizer, model_args)
+
         logger.info(
             f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}"
         )

From faa2122b35bafe6f544558821f9a304e8747d57e Mon Sep 17 00:00:00 2001
From: CryptoSalamander
Date: Sat, 22 Nov 2025 02:01:35 +0000
Subject: [PATCH 2/5] Update tokenizer-model alignment validation logic

---
 torchtitan/models/utils.py | 29 +++++++++--------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py
index 6cab8ad056..384669b3a4 100644
--- a/torchtitan/models/utils.py
+++ b/torchtitan/models/utils.py
@@ -475,38 +475,27 @@ def validate_tokenizer_model_alignment(
     model_args: "BaseModelArgs",
 ) -> None:
     """
-    Validate that tokenizer configuration matches model configuration.
+    Validate that tokenizer configuration is compatible with model configuration.
 
     Args:
         tokenizer: Tokenizer instance to validate. Can be None.
         model_args: Model arguments object containing configuration to validate against.
 
     Raises:
-        ValueError: If tokenizer and model configurations don't match.
+        ValueError: If tokenizer vocab_size exceeds model vocab_size, which would
+            cause index out of bounds errors during training.
     """
     if tokenizer is None:
         return
 
-    # Validate vocab_size
     if hasattr(model_args, "vocab_size"):
         tokenizer_vocab_size = tokenizer.get_vocab_size()
         model_vocab_size = model_args.vocab_size
-        if tokenizer_vocab_size != model_vocab_size:
+        if model_vocab_size < tokenizer_vocab_size:
             raise ValueError(
-                f"Tokenizer vocab_size ({tokenizer_vocab_size}) does not match "
-                f"model vocab_size ({model_vocab_size}). "
-                f"This mismatch will cause training errors. "
-                f"Please ensure the tokenizer and model configuration are aligned."
-            )
-
-    # Validate eos_id
-    if hasattr(model_args, "eos_id"):
-        tokenizer_eos_id = getattr(tokenizer, "eos_id", None)
-        model_eos_id = model_args.eos_id
-        if tokenizer_eos_id is not None and tokenizer_eos_id != model_eos_id:
-            raise ValueError(
-                f"Tokenizer eos_id ({tokenizer_eos_id}) does not match "
-                f"model eos_id ({model_eos_id}). "
-                f"This mismatch may cause training errors. "
-                f"Please ensure the tokenizer and model configuration are aligned."
+                f"Model vocab_size ({model_vocab_size}) is smaller than "
+                f"tokenizer vocab_size ({tokenizer_vocab_size}). "
+                f"This will cause index out of bounds errors during training. "
+                f"The model's embedding layer must be at least as large as the "
+                f"tokenizer's vocabulary size."
             )

From 50b10b8b0f22af7ad1db21e7da33087c49aeb526 Mon Sep 17 00:00:00 2001
From: CryptoSalamander
Date: Sat, 22 Nov 2025 02:48:41 +0000
Subject: [PATCH 3/5] Remove unused `eos_id` from model args

---
 .../deterministic_vllm_rl/models/qwen3/model_vllm_compat.py | 1 -
 torchtitan/experiments/deterministic_vllm_rl/simple_rl.py   | 1 -
 torchtitan/experiments/transformers_backend/model/args.py   | 1 -
 torchtitan/models/llama3/model/args.py                      | 1 -
 torchtitan/models/qwen3/model/args.py                       | 1 -
 torchtitan/models/qwen3/model/model.py                      | 1 -
 6 files changed, 6 deletions(-)

diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py
index dd84665091..a24ef640e9 100644
--- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py
+++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py
@@ -288,7 +288,6 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
-        self.eos_id = model_args.eos_id
         self.head_dim = model_args.head_dim
 
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
index ffc7d52eb0..97ebea0743 100644
--- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
+++ b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
@@ -332,7 +332,6 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr
         max_seq_len=getattr(hf_config, "max_position_embeddings", 32768),
         qk_norm=True,
         depth_init=True,
-        eos_id=getattr(hf_config, "eos_token_id", 151645),
     )
 
     # state_dict is in standard TorchTitan format (w1, w2, w3)
diff --git a/torchtitan/experiments/transformers_backend/model/args.py b/torchtitan/experiments/transformers_backend/model/args.py
index 25ab328f15..02fb00cdc1 100644
--- a/torchtitan/experiments/transformers_backend/model/args.py
+++ b/torchtitan/experiments/transformers_backend/model/args.py
@@ -54,7 +54,6 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
             "n_kv_heads": "num_key_value_heads",
             "norm_eps": "rms_norm_eps",
             "max_seq_len": "max_position_embeddings",
-            "eos_id": "eos_token_id",
         }
     }
 
diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py
index d83fb83102..f91414950b 100644
--- a/torchtitan/models/llama3/model/args.py
+++ b/torchtitan/models/llama3/model/args.py
@@ -45,7 +45,6 @@ class TransformerModelArgs(BaseModelArgs):
 
     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
-    eos_id: int = 0
 
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index 0c700ce2e0..a80bfdb26f 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -38,7 +38,6 @@ class Qwen3ModelArgs(BaseModelArgs):
 
     use_flex_attn: bool = False
    attn_mask_type: str = "causal"
-    eos_id: int = 151645
 
     enable_weight_tying: bool = False
 
diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py
index a4f0a59844..775ba8ea1c 100644
--- a/torchtitan/models/qwen3/model/model.py
+++ b/torchtitan/models/qwen3/model/model.py
@@ -384,7 +384,6 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
-        self.eos_id = model_args.eos_id
         self.head_dim = model_args.head_dim
 
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

From 883c281b5bb27684a1ef1dbe1046e7926a925d9b Mon Sep 17 00:00:00 2001
From: CryptoSalamander
Date: Thu, 4 Dec 2025 17:21:36 +0900
Subject: [PATCH 4/5] Revert removal of `eos_id` from model args

---
 .../deterministic_vllm_rl/models/qwen3/model_vllm_compat.py | 1 +
 torchtitan/experiments/deterministic_vllm_rl/simple_rl.py   | 1 +
 torchtitan/experiments/transformers_backend/model/args.py   | 1 +
 torchtitan/models/llama3/model/args.py                      | 1 +
 torchtitan/models/qwen3/model/args.py                       | 1 +
 torchtitan/models/qwen3/model/model.py                      | 1 +
 6 files changed, 6 insertions(+)

diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py
index a24ef640e9..dd84665091 100644
--- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py
+++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py
@@ -288,6 +288,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.eos_id = model_args.eos_id
         self.head_dim = model_args.head_dim
 
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
index 97ebea0743..ffc7d52eb0 100644
--- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
+++ b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
@@ -332,6 +332,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr
         max_seq_len=getattr(hf_config, "max_position_embeddings", 32768),
         qk_norm=True,
         depth_init=True,
+        eos_id=getattr(hf_config, "eos_token_id", 151645),
     )
 
     # state_dict is in standard TorchTitan format (w1, w2, w3)
diff --git a/torchtitan/experiments/transformers_backend/model/args.py b/torchtitan/experiments/transformers_backend/model/args.py
index 02fb00cdc1..25ab328f15 100644
--- a/torchtitan/experiments/transformers_backend/model/args.py
+++ b/torchtitan/experiments/transformers_backend/model/args.py
@@ -54,6 +54,7 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
             "n_kv_heads": "num_key_value_heads",
             "norm_eps": "rms_norm_eps",
             "max_seq_len": "max_position_embeddings",
+            "eos_id": "eos_token_id",
         }
     }
 
diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py
index f91414950b..d83fb83102 100644
--- a/torchtitan/models/llama3/model/args.py
+++ b/torchtitan/models/llama3/model/args.py
@@ -45,6 +45,7 @@ class TransformerModelArgs(BaseModelArgs):
 
     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
+    eos_id: int = 0
 
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index a80bfdb26f..0c700ce2e0 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -38,6 +38,7 @@ class Qwen3ModelArgs(BaseModelArgs):
 
     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
+    eos_id: int = 151645
 
     enable_weight_tying: bool = False
 
diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py
index 775ba8ea1c..a4f0a59844 100644
--- a/torchtitan/models/qwen3/model/model.py
+++ b/torchtitan/models/qwen3/model/model.py
@@ -384,6 +384,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.eos_id = model_args.eos_id
         self.head_dim = model_args.head_dim
 
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

From cdc1f1b3f31e6568de9ce50bbc6a48135a09916d Mon Sep 17 00:00:00 2001
From: CryptoSalamander
Date: Thu, 4 Dec 2025 17:26:52 +0900
Subject: [PATCH 5/5] Rename validate_tokenizer_model_alignment to
 validate_tokenizer_model_compatibility

---
 torchtitan/models/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py
index 384669b3a4..053eda6c15 100644
--- a/torchtitan/models/utils.py
+++ b/torchtitan/models/utils.py
@@ -470,7 +470,7 @@ def get_moe_model_nparams_and_flops(
     return nparams, num_flops_per_token
 
 
-def validate_tokenizer_model_alignment(
+def validate_tokenizer_model_compatibility(
     tokenizer: "BaseTokenizer | None",
     model_args: "BaseModelArgs",
 ) -> None:
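
Note: PATCH 5 renames the function only in torchtitan/models/utils.py, while the
import and call site that PATCH 1 added to torchtitan/train.py still reference
validate_tokenizer_model_alignment, so the series as posted would fail at import
time. A follow-up hunk along the lines below would be needed; this is a sketch
only (blob hashes omitted, hunk positions approximate), assuming train.py is
otherwise unchanged since PATCH 1.

diff --git a/torchtitan/train.py b/torchtitan/train.py
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -25,7 +25,7 @@
 )
 from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.models.utils import validate_tokenizer_model_alignment
+from torchtitan.models.utils import validate_tokenizer_model_compatibility
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
@@ -136,5 +136,5 @@ def __init__(self, job_config: JobConfig):
         self.model_args = model_args
 
-        validate_tokenizer_model_alignment(self.tokenizer, model_args)
+        validate_tokenizer_model_compatibility(self.tokenizer, model_args)
 
         logger.info(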