Skip to content

torchvision.datasets

torchvision.datasets中包含了以下数据集

  1. MNIST
  2. COCO(用于图像标注和目标检测)(Captioning and Detection)
  3. LSUN Classification
  4. ImageFolder
  5. Imagenet-12
  6. CIFAR10 and CIFAR100
  7. STL10
  8. SVHN
  9. PhotoTour

所有数据集都是torch.utils.data.Dataset的子类, 即它们具有getitemlen实现方法。因此,它们都可以传递给torch.utils.data.DataLoader 可以使用torch.multiprocessing工作人员并行加载多个样本的数据。例如:

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads) 

所有的数据集都有几乎相似的 API。他们都有两个共同的参数: transform 和 target_transform 分别转换输入和目标。

MNIST

dset.MNIST(root, train=True, transform=None, target_transform=None, download=False) 

参数说明:

  • root : 数据集,存在于根目录 processed/training.pt 和 processed/test.pt 中。
  • train : True = 训练集, False = 测试集
  • download : 如果为 true,请从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
  • transform :接收 PIL 映像并返回转换版本的函数/变换。例如:transforms.RandomCrop
  • target_transform:一个接收目标并转换它的函数/变换。

COCO

需要安装COCO API

dset.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform]) 

参数说明:

  • root(string) - 图像下载到的根目录。
  • annFile(string) - json 注释文件的路径。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如:transforms.ToTensor
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。

例子:

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target) 

输出:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background'] 

getitem(index) 参数: index(int) - 索引 返回: 元组(图像,目标)。目标是图片的标题列表。 返回类型: 元组

检测:

dset.CocoDetection(root="dir where images are", annFile="json annotation file", [transform, target_transform]) 

参数:

  • root(string) - 图像下载到的根目录。
  • annFile(string) - json 注释文件的路径。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如,transforms.ToTensor
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。

getitem(index) 参数: index(int) - 索引 返回: 元组(图像,目标)。目标是图片的标题列表。 返回类型: 元组

LSUN

dset.LSUN(db_path, classes='train', [transform, target_transform]) 

参数说明:

  • db_path = 数据集文件的根目录
  • classe = train (所有类别, 训练集), val (所有类别, 验证集), test (所有类别, 测试集) [bedroom_train, church_train, …] : 要加载的类别列表
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如,transforms.RandomCrop
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。

ImageFolder

一个通用的数据加载器,数据集中的数据以以下方式组织

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
dset.ImageFolder(root="root folder path", [transform, target_transform]) 

参数说明:

  • root(string) - 根目录路径。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如,transforms.RandomCrop
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。
  • loader - 加载给定其路径的图像的函数。

Imagenet-12

这应该简单地用ImageFolder数据集实现。数据按照这里所述进行预处理 这里有一个例子

CIFAR

dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False) 

参数说明:

  • root : cifar-10-batches-py 的根目录
  • train : True = 训练集, False = 测试集
  • download : True = 从互联上下载数据,并将其放在 root 目录下。如果数据集已经下载,什么都不干。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如,transforms.RandomCrop
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。

STL10

dset.STL10(root, split='train', transform=None, target_transform=None, download=False) 

参数说明:

  • root : stl10_binary 的根目录
  • split : 'train' = 训练集, 'test' = 测试集, 'unlabeled' = 无标签数据集, 'train+unlabeled' = 训练 + 无标签数据集 (没有标签的标记为-1)
  • download : True = 从互联上下载数据,并将其放在 root 目录下。如果数据集已经下载,什么都不干。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如,transforms.RandomCrop
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。

SVHN

class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False) 

参数:

  • root(string) - 目录 SVHN 存在的数据集的根目录 。
  • split(string) - {'train','test','extra'}之一。因此选择数据集。“额外”是额外的训练集。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。例如,transforms.RandomCrop
  • target_transform(可调用,可选) - 一个接收目标并转换它的函数/变换。
  • download (bool,可选) - 如果为 true,请从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。

PhotoTour

class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False) 

参数:

  • root(string) - 图像所在的根目录。
  • name(string) - 要加载的数据集的名称。
  • transform(可调用,可选) - 接收 PIL 映像并返回转换版本的函数/变换。
  • download (bool,可选) - 如果为 true,请从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。

译者署名

用户名 头像 职能 签名
Song 翻译 人生总要追求点什么

我们一直在努力

apachecn/AiLearning

【布客】中文翻译组