9.6 目标检测数据集(皮卡丘)
前面说了,皮卡丘数据集使用MXNet提供的im2rec工具将图像转换成了二进制的RecordIO格式,但是我们后续要使用PyTorch,所以我先用脚本将其转换成了PNG图片并用json文件存放对应的label信息。在继续阅读前,请务必确保运行了这个脚本,保证数据已准备好。文件夹下的结构应如下所示。
先导入相关库。
%matplotlib inline
import os
import json
import numpy as np
import torch
import torchvision
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
assert os.path.exists(os.path.join(data_dir, "train"))
我们先定义一个数据集类PikachuDetDataset
,数据集每个样本包含label
和image
,其中label是一个
的向量,即m个边界框,每个边界框由[class, x_min, y_min, x_max, y_max]
表示,这里的皮卡丘数据集中每个图像只有一个边界框,因此m=1。image
是一个所有元素都位于[0.0, 1.0]
的浮点tensor
,代表图片数据。
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def load_data_pikachu(batch_size, edge_size=256, data_dir = '../../data/pikachu'):
train_dataset = PikachuDetDataset(data_dir, 'train', image_size)
val_dataset = PikachuDetDataset(data_dir, 'val', image_size)
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=4)
val_iter = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
shuffle=False, num_workers=4)
return train_iter, val_iter
下面我们读取一个小批量并打印图像和标签的形状。图像的形状和之前实验中的一样,依然是(批量大小, 通道数, 高, 宽)。而标签的形状则是(批量大小,
, 5),其中
等于数据集中单个图像最多含有的边界框个数。小批量计算虽然高效,但它要求每张图像含有相同数量的边界框,以便放在同一个批量中。由于每张图像含有的边界框个数可能不同,我们为边界框个数小于
的图像填充非法边界框,直到每张图像均含有
个边界框。这样,我们就可以每次读取小批量的图像了。图像中每个边界框的标签由长度为5的数组表示。数组中第一个元素是边界框所含目标的类别。当值为-1时,该边界框为填充用的非法边界框。数组的剩余4个元素分别表示边界框左上角的
和
轴坐标以及右下角的
和
轴坐标(值域在0到1之间)。这里的皮卡丘数据集中每个图像只有一个边界框,因此
。
输出:
torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])
- 目标检测的数据读取跟图像分类的类似。然而,在引入边界框后,标签形状和图像增广(如随机裁剪)发生了变化。
[1] im2rec工具。
[2] GluonCV 工具包。https://gluon-cv.mxnet.io/
注:除代码外本节与原书基本相同,