From 84ff6748f6ac97a0a2e4ad20c9d07ba460cb0074 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Mon, 15 Sep 2025 10:38:55 +0400 Subject: [PATCH] Add script to download the data --- .../cpp_cuda_graphs/CMakeLists.txt | 2 +- advanced_source/cpp_cuda_graphs/README.md | 6 +- .../cpp_cuda_graphs/download_mnist.py | 89 +++++++++++++++++++ 3 files changed, 94 insertions(+), 3 deletions(-) create mode 100644 advanced_source/cpp_cuda_graphs/download_mnist.py diff --git a/advanced_source/cpp_cuda_graphs/CMakeLists.txt b/advanced_source/cpp_cuda_graphs/CMakeLists.txt index 76fc5bc6762..6cd44b84426 100644 --- a/advanced_source/cpp_cuda_graphs/CMakeLists.txt +++ b/advanced_source/cpp_cuda_graphs/CMakeLists.txt @@ -9,7 +9,7 @@ option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON) if (DOWNLOAD_MNIST) message(STATUS "Downloading MNIST dataset") execute_process( - COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py + COMMAND python ${CMAKE_CURRENT_LIST_DIR}/download_mnist.py -d ${CMAKE_BINARY_DIR}/data ERROR_VARIABLE DOWNLOAD_ERROR) if (DOWNLOAD_ERROR) diff --git a/advanced_source/cpp_cuda_graphs/README.md b/advanced_source/cpp_cuda_graphs/README.md index cbe368d1e90..68e2b556afc 100644 --- a/advanced_source/cpp_cuda_graphs/README.md +++ b/advanced_source/cpp_cuda_graphs/README.md @@ -16,8 +16,10 @@ $ make ``` where `/path/to/libtorch` should be the path to the unzipped _LibTorch_ -distribution, which you can get from the [PyTorch -homepage](https://pytorch.org/get-started/locally/). +distribution or PyTorch's CMake prefix path +`python -c "import torch; print(torch.utils.cmake_prefix_path)"`. +Please see [PyTorch homepage](https://pytorch.org/get-started/locally/) +for installation instructions. Execute the compiled binary to train the model: diff --git a/advanced_source/cpp_cuda_graphs/download_mnist.py b/advanced_source/cpp_cuda_graphs/download_mnist.py new file mode 100644 index 00000000000..0c47994c37a --- /dev/null +++ b/advanced_source/cpp_cuda_graphs/download_mnist.py @@ -0,0 +1,89 @@ +from __future__ import division +from __future__ import print_function + +import argparse +import gzip +import os +import sys +import urllib + +try: + from urllib.error import URLError + from urllib.request import urlretrieve +except ImportError: + from urllib2 import URLError + from urllib import urlretrieve + +RESOURCES = [ + 'train-images-idx3-ubyte.gz', + 'train-labels-idx1-ubyte.gz', + 't10k-images-idx3-ubyte.gz', + 't10k-labels-idx1-ubyte.gz', +] + + +def report_download_progress(chunk_number, chunk_size, file_size): + if file_size != -1: + percent = min(1, (chunk_number * chunk_size) / file_size) + bar = '#' * int(64 * percent) + sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100))) + + +def download(destination_path, url, quiet): + if os.path.exists(destination_path): + if not quiet: + print('{} already exists, skipping ...'.format(destination_path)) + else: + print('Downloading {} ...'.format(url)) + try: + hook = None if quiet else report_download_progress + urlretrieve(url, destination_path, reporthook=hook) + except URLError: + raise RuntimeError('Error downloading resource!') + finally: + if not quiet: + # Just a newline. + print() + + +def unzip(zipped_path, quiet): + unzipped_path = os.path.splitext(zipped_path)[0] + if os.path.exists(unzipped_path): + if not quiet: + print('{} already exists, skipping ... '.format(unzipped_path)) + return + with gzip.open(zipped_path, 'rb') as zipped_file: + with open(unzipped_path, 'wb') as unzipped_file: + unzipped_file.write(zipped_file.read()) + if not quiet: + print('Unzipped {} ...'.format(zipped_path)) + + +def main(): + parser = argparse.ArgumentParser( + description='Download the MNIST dataset from the internet') + parser.add_argument( + '-d', '--destination', default='.', help='Destination directory') + parser.add_argument( + '-q', + '--quiet', + action='store_true', + help="Don't report about progress") + options = parser.parse_args() + + if not os.path.exists(options.destination): + os.makedirs(options.destination) + + try: + for resource in RESOURCES: + path = os.path.join(options.destination, resource) + # url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource) + url = 'https://ossci-datasets.s3.amazonaws.com/mnist/{}'.format(resource) + download(path, url, options.quiet) + unzip(path, options.quiet) + except KeyboardInterrupt: + print('Interrupted') + + +if __name__ == '__main__': + main()