From 0e9b616cebaf8043a684462df624a10b5fc6f11b Mon Sep 17 00:00:00 2001 From: mbort Date: Thu, 1 Apr 2021 15:23:54 +0200 Subject: [PATCH 1/2] fix: 'resume' from a snapshot continue from latest trained kimg. --- training/training_loop.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/training/training_loop.py b/training/training_loop.py index 14836ad2e..5f8370410 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -7,6 +7,7 @@ # license agreement from NVIDIA CORPORATION is strictly prohibited. import os +import re import time import copy import json @@ -152,6 +153,7 @@ def training_loop( G_ema = copy.deepcopy(G).eval() # Resume from existing pickle. + cur_nimg = 0 if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: @@ -159,6 +161,12 @@ def training_loop( for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: misc.copy_params_and_buffers(resume_data[name], module, require_all=False) + # resume from a snapshot file(`network-snapshot-.pkl`) continues from + # where it stopped.. + match = re.match(r"^.*(network-snapshot-)(\d+)(.pkl)$", resume_pkl, re.IGNORECASE) + if match: + cur_nimg = int(match.group(2)) * 1000 + # Print network summary tables. if rank == 0: z = torch.empty([batch_gpu, G.z_dim], device=device) @@ -245,7 +253,7 @@ def training_loop( if rank == 0: print(f'Training for {total_kimg} kimg...') print() - cur_nimg = 0 + cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() From 464cef9a2744dedf7c880e25799d28520d258462 Mon Sep 17 00:00:00 2001 From: mbort Date: Tue, 27 Apr 2021 19:39:57 +0200 Subject: [PATCH 2/2] fix: initial progress when resuming a training. --- training/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/training_loop.py b/training/training_loop.py index 5f8370410..f83f04a6e 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -260,7 +260,7 @@ def training_loop( maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: - progress_fn(0, total_kimg) + progress_fn(cur_nimg // 1000, total_kimg) while True: # Fetch training data.