26 changes: 24 additions & 2 deletions onmt/inputters/dynamic_iterator.py
@@ -129,6 +129,7 @@ def __init__(
batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info={},
data_type="text",
bucket_size=2048,
bucket_size_init=-1,
@@ -144,6 +145,7 @@ def __init__(
self.transforms = transforms
self.vocabs = vocabs
self.corpora_info = corpora_info
self.resume_corpora_info = resume_corpora_info
self.task = task
self.init_iterators = False
self.batch_size = batch_size
@@ -171,7 +173,17 @@ def __init__(

@classmethod
def from_opt(
cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
cls,
corpora,
transforms,
vocabs,
opt,
task,
copy,
device,
resume_corpora_info={},
stride=1,
offset=0,
):
"""Initilize `DynamicDatasetIter` with options parsed from `opt`."""
corpora_info = {}
@@ -206,6 +218,7 @@ def from_opt(
opt.batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info=resume_corpora_info,
data_type=opt.data_type,
bucket_size=bucket_size,
bucket_size_init=bucket_size_init,
@@ -388,6 +401,7 @@ def build_dynamic_dataset_iter(
vocabs,
copy=False,
task=CorpusTask.TRAIN,
resume_corpora_info={},
stride=1,
offset=0,
src=None,
@@ -412,7 +426,14 @@
advance to avoid the GPU waiting during the refilling of the bucket.
"""
transforms = make_transforms(opt, transforms_cls, vocabs)
corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
corpora = get_corpora(
opt,
task,
src=src,
tgt=tgt,
align=align,
resume_corpora_info=resume_corpora_info,
)
if corpora is None:
assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
return None
@@ -442,6 +463,7 @@
vocabs,
opt,
task,
resume_corpora_info=resume_corpora_info,
copy=copy,
stride=stride,
offset=offset,
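
A hedged sketch of how a caller might thread the checkpoint's corpus info into the new parameter. The checkpoint layout follows this PR (a "corpus_info" entry mapping corpus id to the next line to read); `opt`, `transforms_cls` and `vocabs` are assumed to exist as in the surrounding training code, and the file name is a placeholder:

    import torch

    # Load the saved corpus positions from an existing checkpoint.
    checkpoint = torch.load("model_step_1000.pt", map_location="cpu")
    resume_corpora_info = checkpoint.get("corpus_info", {})

    # Build the training iterator so each corpus skips ahead to its saved line.
    train_iter = build_dynamic_dataset_iter(
        opt,
        transforms_cls,
        vocabs,
        task=CorpusTask.TRAIN,
        resume_corpora_info=resume_corpora_info,
    )
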
46 changes: 38 additions & 8 deletions onmt/inputters/text_corpus.py
@@ -99,7 +99,14 @@ class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(
self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
self,
name,
src,
tgt,
align=None,
n_src_feats=0,
src_feats_defaults=None,
line_number_to_resume=0,
):
"""Initialize src & tgt side file path."""
self.id = name
@@ -108,6 +115,12 @@ def __init__(
self.align = align
self.n_src_feats = n_src_feats
self.src_feats_defaults = src_feats_defaults
self.line_number_to_resume = line_number_to_resume
self.can_read_file = False

def activate_reading_mode(self, line_number):
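# Lines are skipped until the iterator reaches the resume line recorded
# in the checkpoint; from then on every line is read.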
if line_number >= self.line_number_to_resume:
self.can_read_file = True

def load(self, offset=0, stride=1):
"""
@@ -116,7 +129,7 @@ def load(self, offset=0, stride=1):
`stride` example, starting from `offset`.
"""

def make_ex(sline, tline, align):
def make_ex(sline, tline, align, line_number):
sline, sfeats = parse_features(
sline,
n_feats=self.n_src_feats,
@@ -131,6 +144,7 @@ def make_ex(sline, tline, align):
"tgt": tline,
"src_original": sline,
"tgt_original": tline,
"cid_line_number": line_number,
}
if align is not None:
example["align"] = align
@@ -145,19 +159,25 @@ def make_ex(sline, tline, align):
for i, (sline, tline, align) in enumerate(
itertools.zip_longest(fs, ft, fa)
):
self.activate_reading_mode(line_number=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
yield make_ex(sline, tline, align)
yield make_ex(sline, tline, align, i)
else:
with exfile_open(self.src, mode="rb") as fs, exfile_open(
self.tgt, mode="rb"
) as ft, exfile_open(self.align, mode="rb") as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
self.activate_reading_mode(line_number=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
if tline is not None:
tline = tline.decode("utf-8")
if align is not None:
align = align.decode("utf-8")
yield make_ex(sline.decode("utf-8"), tline, align)
yield make_ex(sline.decode("utf-8"), tline, align, i)

def __str__(self):
cls_name = type(self).__name__
@@ -169,19 +189,25 @@ def __str__(self):
)


def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
def get_corpora(
opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None, resume_corpora_info={}
):
corpora_dict = {}
if task == CorpusTask.TRAIN:
for corpus_id, corpus_dict in opts.data.items():
if corpus_id != CorpusName.VALID:
if corpus_dict.get("path_txt", None) is None:
resume_line = 0
if corpus_id in resume_corpora_info:
resume_line = resume_corpora_info[corpus_id]
corpora_dict[corpus_id] = ParallelCorpus(
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults,
line_number_to_resume=resume_line,
)
else:
corpora_dict[corpus_id] = BlockwiseCorpus(
@@ -244,8 +270,6 @@ def _process(self, stream):
example["src_feats"] = [
feat.strip().split(" ") for feat in example["src_feats"]
]
line_number = i * self.stride + self.offset
example["cid_line_number"] = line_number
example["cid"] = self.cid
if "align" in example:
example["align"] = example["align"].strip().split(" ")
@@ -258,6 +282,7 @@
or ("align" in example and example["align"] == 0)
):
# empty example: skip
line_number = example["cid_line_number"]
empty_msg = f"Empty line in {self.cid}#{line_number}."
if self.skip_empty_level == "error":
raise IOError(empty_msg)
@@ -282,7 +307,12 @@ def __iter__(self):


def build_corpora_iters(
corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
corpora,
transforms,
corpora_info,
skip_empty_level="warning",
stride=1,
offset=0,
):
"""Return `ParallelCorpusIterator` for all corpora defined in opts."""
corpora_iters = dict()
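
A minimal usage sketch of the resume behavior added above (file paths and the resume line are made-up placeholders): a corpus built with `line_number_to_resume=N` silently skips examples until source line N is reached, then yields normally, and every yielded example now carries its `cid_line_number`:

    corpus = ParallelCorpus(
        "demo",
        "data/train.src",
        "data/train.tgt",
        line_number_to_resume=1000,
    )
    for ex in corpus.load(stride=1, offset=0):
        # The first example yielded comes from line 1000.
        print(ex["cid_line_number"])
        break
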
92 changes: 90 additions & 2 deletions onmt/models/model_saver.py
@@ -1,13 +1,17 @@
import os
import torch
import re
import subprocess
from collections import deque
import onmt.utils
from onmt.utils.logging import logger
from onmt.inputters.inputter import vocabs_to_dict
from onmt.modules.lora import lora_state_dict


def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
def build_model_saver(
model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id
):
# _check_save_model_path
save_model_path = os.path.abspath(opt.save_model)
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
@@ -20,6 +24,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
optim,
opt.keep_checkpoint,
opt.save_format,
resume_corpora_info,
device_id,
)
return model_saver
@@ -81,6 +86,65 @@ def fix_key(s):
return checkpoint


def load_corpora_info(opts, checkpoint):
message_resume_from_beginning = (
"The training will resume from the beginning of each corpus."
)
# Check if resume_from_corpora is True
if not opts.resume_from_corpora:
logger.info(
"No resume from corpora is specified. " + message_resume_from_beginning
)
return {}

# Check that the checkpoint saved corpus info.
checkpoint_corpora = checkpoint.get("corpus_info", None)
if checkpoint_corpora is None:
logger.info(
"No corpus info was found in the checkpoint. "
+ message_resume_from_beginning
)
return {}

# Check that the corpus lists from the last training
# and from the new training are identical.
if set(checkpoint_corpora) != set(opts.data):
logger.info(
"Incoherent info: Some corpora in the last training "
+ "and in the new list do not match. "
+ message_resume_from_beginning
)
return {}

# For each corpus, check that the line number to resume from
# does not exceed the number of lines in the file.
message_incoherent_line_number = (
"Incoherent info: the line number to resume from "
+ "exceeds the total number of lines in some corpora. "
+ message_resume_from_beginning
)
for c_name in checkpoint_corpora:
number_of_text_lines = int(
subprocess.getoutput(
"wc -l " + opts.data[c_name]["path_src"] + " | awk '{print $1}'"
)
)
if checkpoint_corpora[c_name] > number_of_text_lines - 1:
logger.info(message_incoherent_line_number)
return {}

# Resume from the line after the last processed one, wrapping
# back to the beginning of the corpus if the end was reached.
checkpoint_corpora[c_name] = (
checkpoint_corpora[c_name] + 1
) % number_of_text_lines

logger.info("The training will resume from the saved text line in each corpus.")
return checkpoint_corpora
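
The `wc -l` call above assumes a Unix shell and an unquoted path; a portable pure-Python equivalent (not part of this PR, shown only to illustrate what the shell pipeline computes) could be:

    def count_lines(path):
        # Counts newline-terminated lines, matching `wc -l`.
        with open(path, "rb") as f:
            return sum(1 for _ in f)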


class ModelSaverBase(object):
"""Base class for model saving operations

@@ -98,6 +162,7 @@ def __init__(
optim,
keep_checkpoint=-1,
save_format="pytorch",
resume_corpora_info={},
device_id=0,
):
self.base_path = base_path
@@ -108,14 +173,35 @@
self.last_saved_step = None
self.keep_checkpoint = keep_checkpoint
self.save_format = save_format
self.corpus_info = resume_corpora_info
self.device_id = device_id

if keep_checkpoint > 0:
self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
if save_format == "safetensors":
self.model_queue = deque([], maxlen=keep_checkpoint)

def save(self, step, moving_average=None):
def update_corpus_info_from_batches(self, batches, distributed=False):
# Update the last text line of each corpus
if batches is not None:
# Gather corpus line numbers to save to checkpoints
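# sum(list_of_lists, []) flattens the per-batch lists into two
# flat, index-aligned lists of corpus ids and line numbers.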
batch_cids = sum([batch["cid"] for batch in batches], [])
batch_cid_line_numbers = sum(
[batch["cid_line_number"] for batch in batches], []
)
if distributed:
batch_cids = sum(onmt.utils.distributed.all_gather_list(batch_cids), [])
batch_cid_line_numbers = sum(
onmt.utils.distributed.all_gather_list(batch_cid_line_numbers), []
)
# Save the last processed line number of each corpus
new_corpus_info = {
c_name: cid_line_number
for c_name, cid_line_number in zip(batch_cids, batch_cid_line_numbers)
}
self.corpus_info = {**self.corpus_info, **new_corpus_info}

def save(self, step, moving_average=None, batches=None, distributed=False):
"""Main entry point for model saver

It wraps the `_save` method with checks and applies `keep_checkpoint`
@@ -266,6 +352,7 @@ def _save(self, step, model):
"vocab": vocabs_to_dict(self.vocabs),
"opt": self.model_opt,
"optim": self.optim.state_dict(),
"corpus_info": self.corpus_info,
}
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
Expand Down Expand Up @@ -355,6 +442,7 @@ def _st_save(self, step, model):
"vocab": vocabs_to_dict(self.vocabs),
"opt": self.model_opt,
"optim": self.optim.state_dict(),
"corpus_info": self.corpus_info,
}

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
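
A hedged trainer-side sketch (simplified; `last_batches` and `world_size` are assumed caller names, not defined by this PR) of how the extended `save` entry point might be invoked so the checkpoint records the last processed line of each corpus:

    model_saver.save(
        step,
        moving_average=moving_average,
        batches=last_batches,        # batch dicts carrying "cid" and "cid_line_number"
        distributed=world_size > 1,  # all-gather line numbers across ranks first
    )
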
7 changes: 7 additions & 0 deletions onmt/opts.py
@@ -1263,6 +1263,13 @@ def _add_train_general_opts(parser):
help="If training from a checkpoint then this is the "
"path to the pretrained model's state_dict.",
)
group.add(
"--resume_from_corpora",
"-resume_from_corpora",
action="store_true",
help="If training from a checkpoint and this is set to True "
" then the data generator will resume from the last line of each corpora.",
)
group.add(
"--reset_optim",
"-reset_optim",
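
Assuming a typical training setup (config and checkpoint names are placeholders), resuming with the new flag could look like:

    onmt_train -config train.yaml -train_from model_step_10000.pt -resume_from_corpora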