Improve dataset download logic
ppwwyyxx committed Jul 4, 2018
1 parent 5854c7d commit a8b72a8
Showing 6 changed files with 40 additions and 24 deletions.
5 changes: 4 additions & 1 deletion tensorpack/dataflow/dataset/bsds500.py
@@ -10,7 +10,10 @@
from ..base import RNGDataFlow

__all__ = ['BSDS500']


DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
DATA_SIZE = 70763455
IMG_W, IMG_H = 481, 321


@@ -35,7 +38,7 @@ def __init__(self, name, data_dir=None, shuffle=True):
if data_dir is None:
data_dir = get_dataset_path('bsds500_data')
if not os.path.isdir(os.path.join(data_dir, 'BSR')):
download(DATA_URL, data_dir)
download(DATA_URL, data_dir, expect_size=DATA_SIZE)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(data_dir, filename)
import tarfile
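The new DATA_SIZE constant is the byte count of BSR_bsds500.tgz, which download() verifies via expect_size. A minimal sketch of a hypothetical helper (not part of this commit) for obtaining or re-checking such a constant from the server's Content-Length header:

# Hypothetical helper, not part of tensorpack: ask the server for the file's
# size so constants like DATA_SIZE can be generated or re-checked.
from six.moves import urllib

def remote_content_length(url):
    # urlopen returns response headers before the body is read, so this
    # does not download the archive itself.
    resp = urllib.request.urlopen(url)
    try:
        return int(resp.info().get('Content-Length', -1))
    finally:
        resp.close()

# e.g. remote_content_length(DATA_URL) should equal DATA_SIZE (70763455)
# as long as the upstream file has not changed.
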
13 changes: 6 additions & 7 deletions tensorpack/dataflow/dataset/cifar.py
@@ -6,6 +6,7 @@
import os
import pickle
import numpy as np
import tarfile
import six
from six.moves import range

@@ -16,13 +17,12 @@
__all__ = ['Cifar10', 'Cifar100']


DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
DATA_URL_CIFAR_10 = ('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071)
DATA_URL_CIFAR_100 = ('http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 169001437)


def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
"""Download and extract the tarball from Alex's website. Copied from tensorflow example """
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
cifar_foldername = 'cifar-10-batches-py'
@@ -33,10 +33,9 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
return
else:
DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100
download(DATA_URL, dest_directory)
filename = DATA_URL.split('/')[-1]
filename = DATA_URL[0].split('/')[-1]
filepath = os.path.join(dest_directory, filename)
import tarfile
download(DATA_URL[0], dest_directory, expect_size=DATA_URL[1])
tarfile.open(filepath, 'r:gz').extractall(dest_directory)


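DATA_URL_CIFAR_10 and DATA_URL_CIFAR_100 are now (url, expected_size) pairs instead of bare URL strings. A minimal sketch of how such a pair is consumed, mirroring maybe_download_and_extract() above; fetch_and_extract is a hypothetical name, and download() is the function from tensorpack/utils/fs.py changed later in this commit:

import tarfile

from tensorpack.utils.fs import download

DATA_URL_CIFAR_10 = ('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071)

def fetch_and_extract(spec, dest_directory):
    # spec is a (url, expected_size) pair such as DATA_URL_CIFAR_10.
    url, expected = spec
    # download() returns the local file path; with expect_size it skips the
    # transfer when a file of that size is already in dest_directory.
    fpath = download(url, dest_directory, expect_size=expected)
    tarfile.open(fpath, 'r:gz').extractall(dest_directory)

# fetch_and_extract(DATA_URL_CIFAR_10, '/tmp/cifar10')
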
4 changes: 2 additions & 2 deletions tensorpack/dataflow/dataset/ilsvrc.py
@@ -14,7 +14,7 @@

__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']

CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008)


class ILSVRCMeta(object):
@@ -53,7 +53,7 @@ def get_synset_1000(self):
return dict(enumerate(lines))

def _download_caffe_meta(self):
fpath = download(CAFFE_ILSVRC12_URL, self.dir, expect_size=17858008)
fpath = download(CAFFE_ILSVRC12_URL[0], self.dir, expect_size=CAFFE_ILSVRC12_URL[1])
tarfile.open(fpath, 'r:gz').extractall(self.dir)

def get_image_list(self, name, dir_structure='original'):
12 changes: 8 additions & 4 deletions tensorpack/graph_builder/utils.py
@@ -8,6 +8,7 @@

from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.scope_utils import under_name_scope, cached_name_scope
from ..tfutils.common import get_tf_version_number
from ..utils.argtools import call_only_once
from ..utils import logger

@@ -66,13 +67,16 @@ def __init__(self, worker_device, ps_devices):
self.ps_sizes = [0] * len(self.ps_devices)

def __call__(self, op):
def sanitize_name(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if get_tf_version_number() >= 1.8:
from tensorflow.python.training.device_util import canonicalize
else:
def canonicalize(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()

if op.device:
return op.device
if op.type not in ['Variable', 'VariableV2']:
return sanitize_name(self.worker_device)
return canonicalize(self.worker_device)

device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
@@ -84,7 +88,7 @@ def sanitize_name(name): # tensorflow/tensorflow#11484

self.ps_sizes[device_index] += var_size

return sanitize_name(device_name)
return canonicalize(device_name)

def __str__(self):
return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
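On TF < 1.8 the device setter falls back to a local canonicalize() built on tf.DeviceSpec; on newer TF it reuses tensorflow.python.training.device_util.canonicalize. A standalone sketch of the fallback, showing why the round trip matters (tensorflow/tensorflow#11484):

import tensorflow as tf

def canonicalize(name):
    # Round-trip through DeviceSpec to normalize shorthand device strings.
    return tf.DeviceSpec.from_string(name).to_string()

# e.g. canonicalize('/gpu:0') == canonicalize('/device:GPU:0') == '/device:GPU:0',
# so device strings written by users and those reported on TensorFlow ops can be
# compared reliably inside LeastLoadedDeviceSetter.

The TF >= 1.8 helper may additionally fill in job/replica/task defaults, but for the comparisons made here the normalization above is the part that matters.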
19 changes: 11 additions & 8 deletions tensorpack/libinfo.py
@@ -8,16 +8,21 @@
import cv2 # noqa
if int(cv2.__version__.split('.')[0]) == 3:
cv2.ocl.setUseOpenCL(False)
# check if cv is built with cuda
# check if cv is built with cuda or openmp
info = cv2.getBuildInformation().split('\n')
for line in info:
if 'use cuda' in line.lower():
answer = line.split()[-1].lower()
if answer == 'yes':
splits = line.split()
if not len(splits):
continue
answer = splits[-1].lower()
if answer in ['yes', 'no']:
if 'cuda' in line.lower() and answer == 'yes':
# issue#1197
print("OpenCV is built with CUDA support. "
"This may cause slow initialization or sometimes segfault with TensorFlow.")
break
if answer == 'openmp':
print("OpenCV is built with OpenMP support. This usually results in poor performance. For details, see "
"https://github.com/tensorpack/benchmarks/blob/master/ImageNet/benchmark-opencv-resize.py")
except (ImportError, TypeError):
pass
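The loop above scans cv2.getBuildInformation() for lines whose last token is YES/NO (CUDA) or OpenMP. A small sketch for inspecting those lines on a local installation; the exact field names ("Use Cuda", "NVIDIA CUDA", "Parallel framework", ...) vary across OpenCV versions:

import cv2

# Print only the build-info lines the startup check cares about.
for line in cv2.getBuildInformation().split('\n'):
    lowered = line.lower()
    if 'cuda' in lowered or 'parallel framework' in lowered:
        print(line.strip())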

@@ -41,9 +46,7 @@
try:
import tensorflow as tf # noqa
_version = tf.__version__.split('.')
assert int(_version[0]) >= 1, "TF>=1.0 is required!"
if int(_version[1]) < 3:
print("TF<1.3 support will be removed after 2018-03-15! Actually many examples already require TF>=1.3.")
assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
_HAS_TF = True
except ImportError:
_HAS_TF = False
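The new assertion checks the major and minor version components separately. A hedged alternative sketch, not part of this commit, that compares the version as a tuple instead:

import tensorflow as tf

_version = tf.__version__.split('.')
# Tuple comparison: (2, 0) >= (1, 3) is True, whereas testing the minor
# component on its own would reject such a hypothetical release.
assert (int(_version[0]), int(_version[1])) >= (1, 3), "TF>=1.3 is required!"
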
11 changes: 9 additions & 2 deletions tensorpack/utils/fs.py
@@ -13,7 +13,7 @@


def mkdir_p(dirname):
""" Make a dir recursively, but do nothing if the dir exists
""" Like "mkdir -p", make a dir recursively, but do nothing if the dir exists
Args:
dirname(str):
@@ -38,6 +38,13 @@ def download(url, dir, filename=None, expect_size=None):
filename = url.split('/')[-1]
fpath = os.path.join(dir, filename)

if os.path.isfile(fpath):
if expect_size is not None and os.stat(fpath).st_size == expect_size:
logger.info("File {} exists! Skip download.".format(filename))
return fpath
else:
logger.warn("File {} exists. Will overwrite with a new download!".format(filename))

def hook(t):
last_b = [0]

@@ -62,7 +69,7 @@ def inner(b, bsize, tsize=None):
logger.error("You may have downloaded a broken file, or the upstream may have modified the file.")

# TODO human-readable size
print('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
logger.info('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
return fpath


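With the existing-file check, download() is effectively idempotent when expect_size is supplied: a repeated call that finds a file of the expected size only logs "exists! Skip download." and returns the path. A usage sketch; the URL and size are the CIFAR-10 values from this commit, and the target directory is illustrative:

import os

from tensorpack.utils.fs import download

url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
size = 170498071
target_dir = '/tmp/cifar10'
os.path.isdir(target_dir) or os.makedirs(target_dir)

# First call: downloads the archive (or warns and overwrites a file whose
# size does not match expect_size).
fpath = download(url, target_dir, expect_size=size)
# Second call: a file of the expected size is already there, so the
# transfer is skipped and the same path is returned.
fpath = download(url, target_dir, expect_size=size)
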
