Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import re
import time
import copy
import json
Expand Down Expand Up @@ -152,13 +153,20 @@ def training_loop(
G_ema = copy.deepcopy(G).eval()

# Resume from existing pickle.
cur_nimg = 0
if (resume_pkl is not None) and (rank == 0):
print(f'Resuming from "{resume_pkl}"')
with dnnlib.util.open_url(resume_pkl) as f:
resume_data = legacy.load_network_pkl(f)
for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

# resume from a snapshot file(`network-snapshot-<INT>.pkl`) continues from
# where it stopped..
match = re.match(r"^.*(network-snapshot-)(\d+)(.pkl)$", resume_pkl, re.IGNORECASE)
if match:
cur_nimg = int(match.group(2)) * 1000

# Print network summary tables.
if rank == 0:
z = torch.empty([batch_gpu, G.z_dim], device=device)
Expand Down Expand Up @@ -245,14 +253,14 @@ def training_loop(
if rank == 0:
print(f'Training for {total_kimg} kimg...')
print()
cur_nimg = 0

cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - start_time
batch_idx = 0
if progress_fn is not None:
progress_fn(0, total_kimg)
progress_fn(cur_nimg // 1000, total_kimg)
while True:

# Fetch training data.
Expand Down