2 changes: 1 addition & 1 deletion advanced_source/cpp_cuda_graphs/CMakeLists.txt
@@ -9,7 +9,7 @@ option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
 if (DOWNLOAD_MNIST)
   message(STATUS "Downloading MNIST dataset")
   execute_process(
-    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
+    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/download_mnist.py
       -d ${CMAKE_BINARY_DIR}/data
     ERROR_VARIABLE DOWNLOAD_ERROR)
   if (DOWNLOAD_ERROR)
6 changes: 4 additions & 2 deletions advanced_source/cpp_cuda_graphs/README.md
@@ -16,8 +16,10 @@ $ make
 ```
 
 where `/path/to/libtorch` should be the path to the unzipped _LibTorch_
-distribution, which you can get from the [PyTorch
-homepage](https://pytorch.org/get-started/locally/).
+distribution, or PyTorch's CMake prefix path, which you can print with
+`python -c "import torch; print(torch.utils.cmake_prefix_path)"`.
+Please see the [PyTorch homepage](https://pytorch.org/get-started/locally/)
+for installation instructions.
 
 Execute the compiled binary to train the model:
 
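To illustrate the prefix-path option described in the README change above: the printed path can be passed directly to CMake when configuring the example. A minimal sketch, assuming an out-of-source `build` directory and that you configure from inside it:

```
$ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')" ..
$ make
```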
89 changes: 89 additions & 0 deletions advanced_source/cpp_cuda_graphs/download_mnist.py
@@ -0,0 +1,89 @@
from __future__ import division
from __future__ import print_function

import argparse
import gzip
import os
import sys
import urllib

# Python 2/3 compatibility for the download helpers.
try:
    from urllib.error import URLError
    from urllib.request import urlretrieve
except ImportError:
    from urllib2 import URLError
    from urllib import urlretrieve

RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def report_download_progress(chunk_number, chunk_size, file_size):
    # Progress hook for urlretrieve: draws a simple 64-character bar.
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))


def download(destination_path, url, quiet):
    # Fetch a single resource, skipping files that already exist.
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
    else:
        print('Downloading {} ...'.format(url))
        try:
            hook = None if quiet else report_download_progress
            urlretrieve(url, destination_path, reporthook=hook)
        except URLError:
            raise RuntimeError('Error downloading resource!')
        finally:
            if not quiet:
                # Just a newline.
                print()


def unzip(zipped_path, quiet):
    # Decompress a .gz archive next to the original file.
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return
    with gzip.open(zipped_path, 'rb') as zipped_file:
        with open(unzipped_path, 'wb') as unzipped_file:
            unzipped_file.write(zipped_file.read())
    if not quiet:
        print('Unzipped {} ...'.format(zipped_path))


def main():
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            # url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
            url = 'https://ossci-datasets.s3.amazonaws.com/mnist/{}'.format(resource)
            download(path, url, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print('Interrupted')


if __name__ == '__main__':
    main()
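The script can also be run on its own, outside the CMake build. A minimal sketch; the destination directory name is an assumption:

```
# Fetches the four MNIST archives into ./data and decompresses each one
# next to its .gz file; pass -q/--quiet to suppress the progress bar.
$ python download_mnist.py --destination data
```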