44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import ctypes
78import importlib
89import os
10+ import signal
911import time
1012from datetime import timedelta
1113from typing import Any , Generator , Iterable , Optional
3234 maybe_enable_profiling ,
3335)
3436
37+ c_globals = ctypes .CDLL (None ) # POSIX
38+
3539
3640class Trainer (torch .distributed .checkpoint .stateful .Stateful ):
41+ torch_profiler : torch .profiler .profile | None = None
42+
3743 # core configs
3844 job_config : JobConfig
3945 parallel_dims : ParallelDims
@@ -580,13 +586,14 @@ def train(self):
580586 if not self .ft_manager .enabled
581587 else f"replica_{ self .ft_manager .replica_id } "
582588 )
589+ self .torch_profiler = maybe_enable_profiling (
590+ job_config .profiling ,
591+ global_step = self .step ,
592+ base_folder = job_config .job .dump_folder ,
593+ leaf_folder = leaf_folder ,
594+ )
595+
583596 with (
584- maybe_enable_profiling (
585- job_config .profiling ,
586- global_step = self .step ,
587- base_folder = job_config .job .dump_folder ,
588- leaf_folder = leaf_folder ,
589- ) as torch_profiler ,
590597 maybe_enable_memory_snapshot (
591598 job_config .profiling ,
592599 global_step = self .step ,
@@ -610,6 +617,15 @@ def train(self):
610617 ),
611618 ),
612619 ):
620+ if self .torch_profiler :
621+
@ctypes.CFUNCTYPE(None, ctypes.c_int)
def sigabrt_handler(signum):
    """C-level SIGABRT callback that exports the profiler trace.

    The callback is registered through the C ``signal`` function and is
    called with the signal number as its only (positional) argument. The
    parameter is named ``signum`` rather than ``signal`` so it does not
    shadow the imported ``signal`` module inside the handler.

    NOTE(review): the log message says the profiler is being stopped, but
    only ``export_chrome_trace`` is called here -- confirm whether the
    profiler has to be stopped before a trace can be exported.
    NOTE(review): logging and file export are presumably not
    async-signal-safe; treat this as a best-effort dump on abort.
    NOTE(review): ctypes does not keep callback objects alive on its own;
    the enclosing scope has to hold a reference to this function for as
    long as it stays registered -- verify against the caller.
    """
    logger.info("SIGABRT received. Stopping profiler")
    # Relative path: the file is written to the current working directory.
    self.torch_profiler.export_chrome_trace("trace.json")
626+
627+ c_globals .signal (signal .SIGABRT , sigabrt_handler )
628+
613629 data_iterator = self .batch_generator (self .dataloader )
614630 while self .should_continue_training ():
615631 self .step += 1
@@ -633,8 +649,8 @@ def train(self):
633649 self .validator .validate (self .model_parts , self .step )
634650
635651 # signal the profiler that the next profiling step has started
636- if torch_profiler :
637- torch_profiler .step ()
652+ if self . torch_profiler :
653+ self . torch_profiler .step ()
638654 if memory_profiler :
639655 memory_profiler .step ()
640656
@@ -692,10 +708,12 @@ def close(self) -> None:
692708 else :
693709 trainer .train ()
694710 except Exception :
711+ logger .info ("Torchtitan training threw an exception" )
695712 if trainer :
696713 trainer .close ()
697714 raise
698715 else :
716+ logger .info ("Torchtitan training completed" )
699717 trainer .close ()
700718 torch .distributed .destroy_process_group ()
701719 logger .info ("Process group destroyed" )
0 commit comments