File tree Expand file tree Collapse file tree 2 files changed +13
-0
lines changed Expand file tree Collapse file tree 2 files changed +13
-0
lines changed Original file line number Diff line number Diff line change 22import logging
33import math
44import pathlib
5+ import time
56import typing
67import warnings
78
@@ -420,6 +421,7 @@ def __getitem__(self, index: int) -> typing.Any:
420421 The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
421422 """
422423 self ._lazy_load ()
424+ start_time = time .perf_counter ()
423425
424426 if self ._parameters .use_preference_loss_spans :
425427 if index < self ._unshuffled_documents :
@@ -649,6 +651,13 @@ def __getitem__(self, index: int) -> typing.Any:
649651 image_positions = np .array (image_positions ) if image_positions else None
650652 Assert .eq (len (token_ids ), self ._parameters .sequence_length + self ._parameters .extra_tokens )
651653
654+ data_time = (time .perf_counter () - start_time ) * 1000
655+ if data_time > 100 :
656+ logger .warning (
657+ f"Data loading took { data_time :,.2f} ms for { image_tokens_added } image tokens and "
658+ f"{ text_tokens_added } text tokens. { len (images ) if images else 0 } images and { len (token_ids )} total tokens."
659+ )
660+
652661 return GPTSample (
653662 token_ids = token_ids ,
654663 loss_masking_spans = loss_masking_spans ,
Original file line number Diff line number Diff line change 11import logging
2+ import time
23import typing
34
45import torch
@@ -332,11 +333,14 @@ def preprocess(
332333 batch , reference_preprocessed_meta , phase = PhaseType .inference , iteration = iteration
333334 )
334335
336+ start_time = time .perf_counter ()
335337 # TODO: Do things work with >1?
336338 Assert .eq (len (reference_batch ), len (preprocessed_meta ), 1 )
337339 for i , (reference_tokens , reference_kwargs ) in enumerate (reference_batch ):
338340 reference_model .forward (reference_tokens , reference_kwargs , iteration = iteration )
339341 reference_logits [i ][f"{ name } _logits" ] = reference_kwargs ["logits" ]
342+ elapsed_time = (time .perf_counter () - start_time ) * 1000
343+ logger .info (f"Ref model { name } took { elapsed_time :.2f} ms for { len (reference_batch )} sequences." )
340344
341345 token_ids = batch .token_ids
342346 if sequence_first :
You can’t perform that action at this time.
0 commit comments