Skip to content

Commit 874cb2a

Browse files
committed
debug logs
1 parent 6c66033 commit 874cb2a

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import logging
33
import math
44
import pathlib
5+
import time
56
import typing
67
import warnings
78

@@ -420,6 +421,7 @@ def __getitem__(self, index: int) -> typing.Any:
420421
The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
421422
"""
422423
self._lazy_load()
424+
start_time = time.perf_counter()
423425

424426
if self._parameters.use_preference_loss_spans:
425427
if index < self._unshuffled_documents:
@@ -649,6 +651,13 @@ def __getitem__(self, index: int) -> typing.Any:
649651
image_positions = np.array(image_positions) if image_positions else None
650652
Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens)
651653

654+
data_time = (time.perf_counter() - start_time) * 1000
655+
if data_time > 100:
656+
logger.warning(
657+
f"Data loading took {data_time:,.2f} ms for {image_tokens_added} image tokens and "
658+
f"{text_tokens_added} text tokens. {len(images) if images else 0} images and {len(token_ids)} total tokens."
659+
)
660+
652661
return GPTSample(
653662
token_ids=token_ids,
654663
loss_masking_spans=loss_masking_spans,

fast_llm/models/gpt/model.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import time
23
import typing
34

45
import torch
@@ -332,11 +333,14 @@ def preprocess(
332333
batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
333334
)
334335

336+
start_time = time.perf_counter()
335337
# TODO: Do things work with >1?
336338
Assert.eq(len(reference_batch), len(preprocessed_meta), 1)
337339
for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
338340
reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
339341
reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
342+
elapsed_time = (time.perf_counter() - start_time) * 1000
343+
logger.info(f"Ref model {name} took {elapsed_time:.2f} ms for {len(reference_batch)} sequences.")
340344

341345
token_ids = batch.token_ids
342346
if sequence_first:

0 commit comments

Comments (0)