3131
3232def run_lr_finder (cfg , datamanager , model , optimizer , scheduler , classes ,
3333 rebuild_model = True , gpu_num = 1 , split_models = False ):
34- if rebuild_model :
35- tmp_model = model
36- else :
37- tmp_model = deepcopy (model )
34+ if not rebuild_model :
35+ backup_model = deepcopy (model )
3836
39- engine = build_engine (cfg , datamanager , tmp_model , optimizer , scheduler , initial_lr = cfg .train .lr )
37+ engine = build_engine (cfg , datamanager , model , optimizer , scheduler , initial_lr = cfg .train .lr )
4038 lr_finder = LrFinder (engine = engine , ** lr_finder_run_kwargs (cfg ))
4139 aux_lr = lr_finder .process ()
4240
@@ -54,16 +52,18 @@ def run_lr_finder(cfg, datamanager, model, optimizer, scheduler, classes,
5452 set_random_seed (cfg .train .seed , cfg .train .deterministic )
5553 datamanager = build_datamanager (cfg , classes )
5654 num_train_classes = datamanager .num_train_pids
55+
5756 if rebuild_model :
58- model = torchreid .models .build_model (** model_kwargs (cfg , num_train_classes ))
57+ backup_model = torchreid .models .build_model (** model_kwargs (cfg , num_train_classes ))
5958 num_aux_models = len (cfg .mutual_learning .aux_configs )
60- model , _ = put_main_model_on_the_device (model , cfg .use_gpu , gpu_num , num_aux_models , split_models )
61- optimizer = torchreid .optim .build_optimizer (model , ** optimizer_kwargs (cfg ))
59+ backup_model , _ = put_main_model_on_the_device (backup_model , cfg .use_gpu , gpu_num , num_aux_models , split_models )
60+
61+ optimizer = torchreid .optim .build_optimizer (backup_model , ** optimizer_kwargs (cfg ))
6262 scheduler = torchreid .optim .build_lr_scheduler (optimizer = optimizer ,
6363 num_iter = datamanager .num_iter ,
6464 ** lr_scheduler_kwargs (cfg ))
6565
66- return cfg .train .lr
66+ return cfg .train .lr , backup_model , optimizer , scheduler
6767
6868
6969def run_training (cfg , datamanager , model , optimizer , scheduler , extra_device_ids , init_lr ,
0 commit comments