torchvision.datasets

    • MNIST
    • COCO(用于图像标注和目标检测)(Captioning and Detection)
    • LSUN Classification
    • ImageFolder
    • Imagenet-12
    • CIFAR10 and CIFAR100
    • STL10

    Datasets 拥有以下API:

    __getitem__
    __len__

    由于以上Datasets都是 torch.utils.data.Dataset的子类,所以,他们也可以通过torch.utils.data.DataLoader使用多线程(python的多进程)。

    举例说明:
    torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)

    在构造函数中,不同的数据集直接的构造函数会有些许不同,但是他们共同拥有 keyword 参数。
    In the constructor, each dataset has a slightly different API as needed, but they all take the keyword args:

    • target_transform - 一个函数,输入为target,输出对其的转换。例子,输入的是图片标注的string,输出为word的索引。

      参数说明:

    • root : 和 processed/test.pt 的主目录
    • train : True = 训练集, False = 测试集
    • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。

    COCO

    需要安装

    1. dset.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform])

    例子:

    1. import torchvision.datasets as dset
    2. import torchvision.transforms as transforms
    3. cap = dset.CocoCaptions(root = 'dir where images are',
    4. annFile = 'json annotation file',
    5. transform=transforms.ToTensor())
    6. print('Number of samples: ', len(cap))
    7. img, target = cap[3] # load 4th sample
    8. print("Image Size: ", img.size())
    9. print(target)

    输出:

    检测:

      1. dset.LSUN(db_path, classes='train', [transform, target_transform])
      • db_path = 数据集文件的根目录
      • classes = ‘train’ (所有类别, 训练集), ‘val’ (所有类别, 验证集), ‘test’ (所有类别, 测试集)
        [‘bedroom_train’, ‘church_train’, …] : a list of categories to load

        ImageFolder

        一个通用的数据加载器,数据集中的数据以以下方式组织
        ```
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

      root/cat/123.png
      root/cat/nsdf3.png
      root/cat/asd932_.png

      他有以下成员变量:

      • self.classes - 用一个list保存 类名
      • self.class_to_idx - 类名对应的 索引
      • self.imgs - 保存(img-path, class) tuple的list

      This is simply implemented with an ImageFolder dataset.

      The data is preprocessed as described here

      CIFAR

      1. dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
      2. dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
      • root : cifar-10-batches-py 的根目录
      • train : True = 训练集, False = 测试集
      • download : True = 从互联上下载数据,并将其放在root目录下。如果数据集已经下载,什么都不干。
        1. dset.STL10(root, split='train', transform=None, target_transform=None, download=False)
        参数说明:
      • root : stl10_binary的根目录
      • split : ‘train’ = 训练集, ‘test’ = 测试集, ‘unlabeled’ = 无标签数据集, ‘train+unlabeled’ = 训练 + 无标签数据集 (没有标签的标记为-1)
      • download : True = 从互联上下载数据,并将其放在目录下。如果数据集已经下载,什么都不干。