Skip to content

Commit

Permalink
Add SBU Captioned Photo Dataset (pytorch#665)
Browse files Browse the repository at this point in the history
* Add SBU Captioned Photo Dataset

* Add SBU to the dataset docs
  • Loading branch information
adamjstewart authored and fmassa committed Dec 4, 2018
1 parent d563769 commit 878a771
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 2 deletions.
9 changes: 9 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,12 @@ PhotoTour
.. autoclass:: PhotoTour
:members: __getitem__
:special-members:

SBU
~~~


.. autoclass:: SBU
:members: __getitem__
:special-members:

3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot
from .sbu import SBU

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot')
'Omniglot', 'SBU')
109 changes: 109 additions & 0 deletions torchvision/datasets/sbu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from PIL import Image
from six.moves import zip
from .utils import download_url, check_integrity

import os
import torch.utils.data as data


class SBU(data.Dataset):
"""`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
Args:
root (string): Root directory of dataset where tarball
``SBUCaptionedPhotoDataset.tar.gz`` exists.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'

def __init__(self, root, transform=None, target_transform=None, download=True):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

# Read the caption for each photo
self.photos = []
self.captions = []

file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')
file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt')

for line1, line2 in zip(open(file1), open(file2)):
url = line1.rstrip()
photo = os.path.basename(url)
filename = os.path.join(self.root, 'dataset', photo)
if os.path.exists(filename):
caption = line2.rstrip()
self.photos.append(photo)
self.captions.append(caption)

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a caption for the photo.
"""
filename = os.path.join(self.root, 'dataset', self.photos[index])
img = Image.open(filename).convert('RGB')
if self.transform is not None:
img = self.transform(img)

target = self.captions[index]
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
"""The number of photos in the dataset."""
return len(self.photos)

def _check_integrity(self):
"""Check the md5 checksum of the downloaded tarball."""
root = self.root
fpath = os.path.join(root, self.filename)
if not check_integrity(fpath, self.md5_checksum):
return False
return True

def download(self):
"""Download and extract the tarball, and download each individual photo."""
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url(self.url, self.root, self.filename, self.md5_checksum)

# Extract file
with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar:
tar.extractall(path=self.root)

# Download individual photos
with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh:
for line in fh:
url = line.rstrip()
try:
download_url(url, os.path.join(self.root, 'dataset'))
except OSError:
# The images point to public images on Flickr.
# Note: Images might be removed by users at anytime.
pass
12 changes: 11 additions & 1 deletion torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,20 @@ def makedir_exist_ok(dirpath):
raise


def download_url(url, root, filename, md5):
def download_url(url, root, filename=None, md5=None):
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str): Name to save the file under. If None, use the basename of the URL
md5 (str): MD5 checksum of the download. If None, do not check
"""
from six.moves import urllib

root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)

makedir_exist_ok(root)
Expand Down

0 comments on commit 878a771

Please sign in to comment.