Description
Problem Overview
I train my model on the same dataset in two different setups: A) single-GPU, B) multi-GPU. The former produces deterministic results; the latter does not. The moment I replace the tf.data.experimental.sample_from_datasets API call with a direct call to tf.data.Dataset, B also becomes deterministic.
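For reference, this is the kind of minimal check I have been using to reason about the sampling op in isolation (toy data with made-up weights, not my actual pipeline):

    import tensorflow as tf

    # Toy stand-ins for the real per-class TFRecord datasets.
    ds_a = tf.data.Dataset.from_tensor_slices([0] * 100).repeat()
    ds_b = tf.data.Dataset.from_tensor_slices([1] * 100).repeat()

    def sampled_sequence(n=20):
        mixed = tf.data.experimental.sample_from_datasets(
            [ds_a, ds_b], weights=[0.5, 0.5], seed=42)
        return [int(x) for x in mixed.take(n)]

    # With a fixed seed, two freshly built pipelines should agree, and
    # they do in setup A (single GPU). In setup B (multi-GPU) the
    # sequences diverge across runs.
    assert sampled_sequence() == sampled_sequence()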
Environment
Python: 3.7
CUDA: 11.2
TensorFlow: 2.4.1
Code
Relevant API: https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/data/experimental/sample_from_datasets
import json
import os
from typing import Callable, Tuple

import tensorflow as tf

# file_py, constants, Distribution, _get_class_distribution,
# REFERENCE_CLASS and MAX_SHUFFLE_BUFFER_SIZE come from our internal codebase.


def load_resampled_data(dataset_dir: str, split: str, batch_size: int,
                        prepare_example: Callable,
                        distribution=Distribution.DEFAULT
                        ) -> Tuple[tf.data.Dataset, int]:
    """
    Load the samples in dataset_dir in a shuffled order
    at per-class sampling rates determined by distribution.

    :param dataset_dir: Path to the dataset directory.
    :param split: One of the values in constants.VALID_SPLITS.
    :param batch_size: Number of samples per batch.
    :param prepare_example: Function to apply to every sample.
    :param distribution: Distribution enum indicating the
        distribution over label classes for each epoch.
    """
    assert split in constants.VALID_SPLITS
    class_datasets = []
    tf_records_dir = os.path.join(dataset_dir, split)
    # Load class cardinality information.
    class_cardinality_json = os.path.join(tf_records_dir,
                                          constants.CLASS_CARDINALITY_JSON)
    with file_py.open(class_cardinality_json, 'r') as f:
        class_cardinality = json.load(f)
    # Determine train-time distribution.
    class_distribution = _get_class_distribution(class_cardinality,
                                                 distribution)
    assert round(sum(class_distribution.values()), 2) == 1.0
    print("Train-time class distribution:", class_distribution)
    # Load class-based tf records with re-sampling.
    resampled_distribution = []
    for class_name, class_weight in class_distribution.items():
        tf_record = os.path.join(tf_records_dir, f"{class_name}.tf_record")
        class_dataset = tf.data.TFRecordDataset(tf_record)
        assert class_cardinality[class_name] > 0, class_cardinality
        class_dataset = class_dataset.shuffle(
            min(class_cardinality[class_name], MAX_SHUFFLE_BUFFER_SIZE),
            seed=constants.SEED,
            reshuffle_each_iteration=False)
        class_datasets.append(class_dataset.repeat())
        resampled_distribution.append(class_weight)
    dataset_cardinality = int(class_cardinality[REFERENCE_CLASS] /
                              class_distribution[REFERENCE_CLASS])
    dataset = tf.data.experimental.sample_from_datasets(
        class_datasets, resampled_distribution, seed=constants.SEED)
    # Elements cannot be processed in parallel because
    # of the stateful non-determinism in the data augmentations.
    dataset = dataset.map(
        prepare_example, num_parallel_calls=1, deterministic=True)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(1), dataset_cardinality
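For completeness, the loader is consumed roughly as follows; the path and prepare function here are placeholders rather than my real ones:

    dataset, cardinality = load_resampled_data(
        dataset_dir="/data/my_dataset",   # placeholder path
        split="train",
        batch_size=32,
        prepare_example=prepare_example)  # our per-example parse/augment fn
    steps_per_epoch = cardinality // 32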
I cannot share the full training code because it is proprietary, but the data loading portion is above. If more information is needed to root-cause this, let me know and I will see what I can provide. FYI, the main code sets all the seeds correctly and disables Horovod fusion, as suggested by the repo README.
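Concretely, "sets all the seeds and disables Horovod fusion" amounts to the following sketch; the seed constant is simplified here:

    import os
    import random

    import numpy as np
    import tensorflow as tf

    SEED = 42  # stands in for constants.SEED

    # Seed every RNG the pipeline touches.
    os.environ['PYTHONHASHSEED'] = str(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # Enable deterministic TF ops (TF 2.4 reads this env var).
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

    # Disable Horovod tensor fusion, per the README.
    os.environ['HOROVOD_FUSION_THRESHOLD'] = '0'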
Thanks a lot for the great work on making TensorFlow deterministic. It, along with the documentation provided, has been incredibly useful in my day-to-day work.