Skip to content

Commit 53f6642

Browse files
authored
Validator integration with current metrics processor for logging (#1395)
Integrated the validator together with metrics processor for better metrics logging. Key changes: - Metrics processor is passed to validator within training loop - Validator can reuse metrics processor's built-in functionalities such as memory profiling, throughput tracking, and tensorboard/wandb logging This is how the new logging looks from terminal: <img width="959" height="374" alt="Screenshot 2025-07-14 at 3 22 56 PM" src="https://github.com/user-attachments/assets/b16a9e00-3ab2-46ed-a42a-0c92d13697cb" />
1 parent 27e3ad8 commit 53f6642

File tree

4 files changed

+46
-7
lines changed

4 files changed

+46
-7
lines changed

torchtitan/components/metrics.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def log(
403403
f"{color.red}step: {step:2} "
404404
f"{color.green}loss: {global_avg_loss:7.4f} "
405405
f"{color.orange}grad_norm: {grad_norm:7.4f} "
406-
f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
406+
f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
407407
f"({device_mem_stats.max_reserved_pct:.2f}%) "
408408
f"{color.blue}tps: {round(tps):,} "
409409
f"{color.cyan}tflops: {tflops:,.2f} "
@@ -415,6 +415,39 @@ def log(
415415
self.time_last_log = time.perf_counter()
416416
self.device_memory_monitor.reset_peak_stats()
417417

418+
def log_validation(self, loss: float, step: int):
419+
time_delta = time.perf_counter() - self.time_last_log
420+
421+
device_mem_stats = self.device_memory_monitor.get_peak_stats()
422+
423+
# tokens per second per device, abbreviated as tps
424+
tps = self.ntokens_since_last_log / (
425+
time_delta * self.parallel_dims.non_data_parallel_size
426+
)
427+
428+
metrics = {
429+
"validation_metrics/loss": loss,
430+
"validation_metrics/throughput(tps)": tps,
431+
"validation_metrics/memory/max_active(GiB)": device_mem_stats.max_active_gib,
432+
"validation_metrics/memory/max_active(%)": device_mem_stats.max_active_pct,
433+
"validation_metrics/memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
434+
"validation_metrics/memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
435+
}
436+
self.logger.log(metrics, step)
437+
438+
color = self.color
439+
logger.info(
440+
f"{color.yellow}validate step: {step:2} "
441+
f"{color.green}loss: {loss:7.4f} "
442+
f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
443+
f"({device_mem_stats.max_reserved_pct:.2f}%) "
444+
f"{color.blue}tps: {round(tps):,}{color.reset}"
445+
)
446+
447+
self.ntokens_since_last_log = 0
448+
self.time_last_log = time.perf_counter()
449+
self.device_memory_monitor.reset_peak_stats()
450+
418451
def close(self):
419452
self.logger.close()
420453

torchtitan/components/validate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from torch.distributed.fsdp import FSDPModule
1212
from torchtitan.components.dataloader import BaseDataLoader
1313
from torchtitan.components.loss import LossFunction
14+
from torchtitan.components.metrics import MetricsProcessor
1415
from torchtitan.components.tokenizer import BaseTokenizer
1516
from torchtitan.config_manager import JobConfig
1617
from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
1718
from torchtitan.distributed import ParallelDims, utils as dist_utils
1819
from torchtitan.tools import utils
19-
from torchtitan.tools.logging import logger
2020

2121

2222
class BaseValidator:
@@ -53,6 +53,7 @@ def __init__(
5353
loss_fn: LossFunction,
5454
validation_context: Generator[None, None, None],
5555
maybe_enable_amp: Generator[None, None, None],
56+
metrics_processor: MetricsProcessor,
5657
):
5758
self.job_config = job_config
5859
self.parallel_dims = parallel_dims
@@ -65,11 +66,13 @@ def __init__(
6566
)
6667
self.validation_context = validation_context
6768
self.maybe_enable_amp = maybe_enable_amp
69+
self.metrics_processor = metrics_processor
6870

6971
@torch.no_grad()
7072
def validate(
7173
self,
7274
model_parts: list[nn.Module],
75+
step: int,
7376
) -> dict[str, float]:
7477
# Set model to eval mode
7578
# TODO: currently does not support pipeline parallelism
@@ -89,6 +92,7 @@ def validate(
8992
):
9093
break
9194

95+
self.metrics_processor.ntokens_since_last_log += labels.numel()
9296
for k, v in input_dict.items():
9397
input_dict[k] = v.to(device_type)
9498
inputs = input_dict["input"]
@@ -124,11 +128,9 @@ def validate(
124128
loss, parallel_dims.world_mesh["dp_cp"]
125129
)
126130
else:
127-
global_avg_loss = loss
131+
global_avg_loss = loss.item()
128132

129-
logger.info(
130-
f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_steps} batches"
131-
)
133+
self.metrics_processor.log_validation(loss=global_avg_loss, step=step)
132134

133135
# Reshard after run forward pass
134136
# This is to ensure the model weights are sharded the same way for checkpoint saving.
@@ -149,6 +151,7 @@ def build_validator(
149151
loss_fn: LossFunction,
150152
validation_context: Generator[None, None, None],
151153
maybe_enable_amp: Generator[None, None, None],
154+
metrics_processor: MetricsProcessor | None = None,
152155
) -> BaseValidator:
153156
"""Build a simple validator focused on correctness."""
154157
return Validator(
@@ -160,4 +163,5 @@ def build_validator(
160163
loss_fn=loss_fn,
161164
validation_context=validation_context,
162165
maybe_enable_amp=maybe_enable_amp,
166+
metrics_processor=metrics_processor,
163167
)

torchtitan/tools/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class Color:
134134
white = "\033[37m"
135135
reset = "\033[39m"
136136
orange = "\033[38;2;180;60;0m"
137+
turquoise = "\033[38;2;54;234;195m"
137138

138139

139140
@dataclass(frozen=True)

torchtitan/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def __init__(self, job_config: JobConfig):
336336
loss_fn=self.train_spec.build_loss_fn(job_config),
337337
validation_context=self.train_context,
338338
maybe_enable_amp=self.maybe_enable_amp,
339+
metrics_processor=self.metrics_processor,
339340
)
340341

341342
logger.info(
@@ -530,7 +531,7 @@ def train(self):
530531
self.job_config.validation.enabled
531532
and self.validator.should_validate(self.step)
532533
):
533-
self.validator.validate(self.model_parts)
534+
self.validator.validate(self.model_parts, self.step)
534535

535536
self.checkpointer.save(
536537
self.step, last_step=(self.step == job_config.training.steps)

0 commit comments

Comments
 (0)