
tf.data.experimental.sample_from_datasets non-deterministic in multi-GPU #39

@yanniskar

Description

Problem Overview

I train my model on the same dataset in two different setups: (A) single-GPU and (B) multi-GPU. The former produces deterministic results; the latter does not. The moment I replace the tf.data.experimental.sample_from_datasets API call with a direct tf.data.Dataset pipeline, B also becomes deterministic.
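
For reference, the failure pattern boils down to the following minimal sketch (toy datasets and weights, not my real pipeline); on a single GPU the printed order is stable across runs, while in the multi-GPU setup it is not:

import tensorflow as tf

# Toy stand-ins for the per-class TFRecord datasets.
ds_a = tf.data.Dataset.range(0, 100)
ds_b = tf.data.Dataset.range(100, 200)

# Seeded sampling; per the docs, the seed should fix the element order.
mixed = tf.data.experimental.sample_from_datasets(
    [ds_a.repeat(), ds_b.repeat()], weights=[0.5, 0.5], seed=42)

print(list(mixed.take(10).as_numpy_iterator()))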

Environment

Python: 3.7
CUDA: 11.2
TensorFlow: 2.4.1

Code

Relevant API: https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/data/experimental/sample_from_datasets

def load_resampled_data(dataset_dir: str, split: str, batch_size: int,
                        prepare_example: Callable,
                        distribution=Distribution.DEFAULT
                        ) -> Tuple[tf.data.Dataset, int]:
    """
    Load the samples in dataset_dir in a shuffled order at
    per-class sampling rates determined by distribution.

    :param dataset_dir: Path to the dataset directory.
    :param split: One of the values in constants.VALID_SPLITS.
    :param batch_size: Number of samples per batch.
    :param prepare_example: Function to apply to every sample.
    :param distribution: Distribution enum indicating the
        distribution over label classes for each epoch.
    :return: The batched dataset and its per-epoch cardinality.
    """
    assert split in constants.VALID_SPLITS
    class_datasets = []
    tf_records_dir = os.path.join(dataset_dir, split)

    # Load class cardinality information.
    class_cardinality_json = os.path.join(tf_records_dir,
        constants.CLASS_CARDINALITY_JSON)
    with file_py.open(class_cardinality_json, 'r') as f:
        class_cardinality = json.load(f)

    # Determine train-time distribution.
    class_distribution = _get_class_distribution(class_cardinality,
        distribution)
    assert round(sum(class_distribution.values()), 2) == 1.0
    print("Train-time class distribution:", class_distribution)

    # Load class-based tf records with re-sampling.
    resampled_distribution = []
    for class_name, class_weight in class_distribution.items():
        tf_record = os.path.join(tf_records_dir, f"{class_name}.tf_record")
        class_dataset = tf.data.TFRecordDataset(tf_record)
        assert class_cardinality[class_name] > 0, class_cardinality
        class_dataset = class_dataset.shuffle(
            min(class_cardinality[class_name], MAX_SHUFFLE_BUFFER_SIZE),
            seed=constants.SEED,
            reshuffle_each_iteration=False)
        class_datasets.append(class_dataset.repeat())
        resampled_distribution.append(class_weight)
    dataset_cardinality = int(class_cardinality[REFERENCE_CLASS] /
        class_distribution[REFERENCE_CLASS])
    dataset = tf.data.experimental.sample_from_datasets(
        class_datasets, resampled_distribution, seed=constants.SEED)

    # Elements cannot be processed in parallel because
    # of the stateful non-determinism in the data augmentations.
    dataset = dataset.map(
        prepare_example, num_parallel_calls=1, deterministic=True)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    return dataset.prefetch(1), dataset_cardinality
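
To detect the divergence, I compare a fingerprint of the first few batches across runs (and across workers). A rough sketch; pipeline_fingerprint is a debugging helper I wrote for this, not part of the pipeline above, and dataset_dir / prepare_example stand in for my real arguments:

import hashlib

import tensorflow as tf

def pipeline_fingerprint(dataset: tf.data.Dataset, num_batches: int = 5) -> str:
    """Hash the raw bytes of the first few batches so separate runs
    (or workers) can be compared for element-order determinism."""
    digest = hashlib.sha256()
    for batch in dataset.take(num_batches):
        for tensor in tf.nest.flatten(batch):
            digest.update(tf.io.serialize_tensor(tensor).numpy())
    return digest.hexdigest()

# Printed at startup on each run: in setup A the value is stable across
# runs; in setup B it changes unless sample_from_datasets is removed.
dataset, _ = load_resampled_data(dataset_dir, "train", batch_size=32,
                                 prepare_example=prepare_example)
print("data fingerprint:", pipeline_fingerprint(dataset))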

I cannot provide the full code because it is proprietary, but the data-loading portion is above. If more information is needed to root-cause this, let me know and I will see what I can do to provide it. FYI, the main code sets all the seeds correctly and disables Horovod fusion as suggested by the repo README; a sketch of that setup follows.
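
For completeness, the seed and fusion setup looks roughly like this (SEED stands in for constants.SEED; TF_DETERMINISTIC_OPS is the determinism flag for my TF 2.4 environment):

import os
import random

import numpy as np
import tensorflow as tf

SEED = 42  # placeholder; the real value lives in constants.SEED

# Request deterministic op kernels (TF 2.4 reads this env var).
os.environ['TF_DETERMINISTIC_OPS'] = '1'
# Disable Horovod Tensor Fusion, per the repo README, so the gradient
# all-reduce order is reproducible across workers.
os.environ['HOROVOD_FUSION_THRESHOLD'] = '0'

random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)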

Thanks a lot for the great work on making TensorFlow deterministic. It, along with the documentation provided, has been incredibly useful in my day-to-day work.
