Pet Image Segmentation with a U-Net Convolutional Neural Network

    Date: 2021.01

    Abstract: This tutorial implements image segmentation with U-Net.

    I. Introduction

    In computer vision, image segmentation is the process of partitioning a digital image into multiple sub-regions. Its purpose is to simplify or change the representation of an image so that the image is easier to understand and analyze. Segmentation is typically used to locate objects and boundaries (lines, curves, and so on) in an image. More precisely, image segmentation assigns a label to every pixel in an image such that pixels with the same label share certain visual characteristics. Segmentation has many application areas, such as autonomous driving, land-parcel detection, and meter reading.

    This tutorial briefly shows how to implement image segmentation with the PaddlePaddle open-source framework. We use U-Net, a well-known network structure in the image segmentation field. It is a deep network improved upon FCN, consisting of a downsampling stage (encoder, feature extraction) and an upsampling stage (decoder, resolution restoration); it is named U-Net because the model structure resembles the letter U.

    II. Environment Setup

    Import some common basic modules and confirm your PaddlePaddle version.

    III. Dataset

    This tutorial uses the Oxford-IIIT Pet dataset (official site: https://www.robots.ox.ac.uk/~vgg/data/pets/).

    The dataset statistics are as follows:

    (Figure: dataset statistics)

    The dataset consists of two archives:

    1. Original images: https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz

    2. Segmentation annotations: https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz

    3.1 Dataset Download

        !curl -O http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
        !curl -O http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
        !tar -xf images.tar.gz
        !tar -xf annotations.tar.gz

    3.2 Dataset Overview

    First, let's look at the file structure downloaded to disk to get familiar with the dataset.

    1. First look at the images.tar.gz archive. It extracts into an images directory, which is quite simple: it directly contains image files named by class name and index, each one a pet photo.

        .
        ├── samoyed_7.jpg
        ├── ......
        └── samoyed_81.jpg

    2. Then look at annotations.tar.gz. The extracted directory contains the contents below. The README file describes every directory and file in some detail; consult it for an explanation of each entry.

        .
        ├── README
        ├── list.txt
        ├── test.txt
        ├── trainval.txt
        ├── trimaps
        │   ├── Abyssinian_1.png
        │   ├── Abyssinian_10.png
        │   ├── ......
        │   └── yorkshire_terrier_99.png
        └── xmls
            ├── Abyssinian_1.xml
            ├── Abyssinian_10.xml
            ├── ......
            └── yorkshire_terrier_190.xml

    In this tutorial we mainly use the images and annotations/trimaps directories, i.e. the original images and the trimap files: the former serve as training inputs and the latter as the corresponding labels.

    Let's see how many training samples the dataset provides.

        import os

        IMAGE_SIZE = (160, 160)
        train_images_path = "images/"
        label_images_path = "annotations/trimaps/"
        image_count = len([os.path.join(train_images_path, image_name)
                           for image_name in os.listdir(train_images_path)
                           if image_name.endswith('.jpg')])
        print("Number of image samples for training:", image_count)

        # process the dataset and split it into training and test sets
        def _sort_images(image_dir, image_type):
            """
            Sort the images in a folder by file name
            """
            files = []
            for image_name in os.listdir(image_dir):
                if image_name.endswith('.{}'.format(image_type)) \
                        and not image_name.startswith('.'):
                    files.append(os.path.join(image_dir, image_name))
            return sorted(files)

        def write_file(mode, images, labels):
            with open('./{}.txt'.format(mode), 'w') as f:
                for i in range(len(images)):
                    f.write('{}\t{}\n'.format(images[i], labels[i]))

        """
        Since the files are scattered across folders, while what training needs is the
        image-to-label correspondence, the first step is to organize the raw data into
        two arrays, one for images and one for labels, aligned one-to-one.
        That makes it easy to look up the label for any image later; the raw folder
        layout alone cannot be used directly.
        Here we use a very simple trick: sort by file name. This works because the image
        and label files were named with exactly this logic -- the base names match, and
        only the extensions differ.
        """
        images = _sort_images(train_images_path, 'jpg')
        labels = _sort_images(label_images_path, 'png')
        eval_num = int(image_count * 0.15)

        write_file('train', images[:-eval_num], labels[:-eval_num])
        write_file('test', images[-eval_num:], labels[-eval_num:])
        write_file('predict', images[-eval_num:], labels[-eval_num:])
        Number of image samples for training: 7390

    3.3 PetDataset Sample Preview

    After splitting the dataset, let's verify it matches expectations: read the image paths from the split files, load the image data, and display it with matplotlib. Note that the segmentation label files are 1-channel grayscale images, so remember to pass cmap='gray' when calling the imshow interface.

    (Figures: sample images alongside their grayscale segmentation labels)

    3.4 Dataset Class Definition

    PaddlePaddle's unified data loading scheme is Dataset (dataset definition) + DataLoader (multi-process data loading).

    First we define the dataset: implement a new Dataset class that inherits from the parent class paddle.io.Dataset and implements its two abstract methods, __getitem__ and __len__:

        class MyDataset(Dataset):
            def __init__(self):
                ...

            # return one sample and its label per iteration
            def __getitem__(self, idx):
                return x, y

            # return the total number of samples in the dataset
            def __len__(self):
                return count(samples)

    Inside the dataset, image preprocessing APIs can be applied (resizing, flipping, format conversion, and so on).

    Loaded images do not always match our requirements. For example, some of the downloaded images are in RGBA format, which does not satisfy our 3-channel requirement and needs format conversion. So here we implement a generic image-reading helper to make sure every image we read meets our needs.

        import io
        import random
        import numpy as np
        from PIL import Image as PilImage
        from paddle.io import Dataset
        from paddle.vision.transforms import transforms as T

        class PetDataset(Dataset):
            """
            Dataset definition
            """
            def __init__(self, mode='train'):
                """
                Constructor
                """
                self.image_size = IMAGE_SIZE
                self.mode = mode.lower()

                assert self.mode in ['train', 'test', 'predict'], \
                    "mode should be 'train' or 'test' or 'predict', but got {}".format(self.mode)

                self.train_images = []
                self.label_images = []

                with open('./{}.txt'.format(self.mode), 'r') as f:
                    for line in f.readlines():
                        image, label = line.strip().split('\t')
                        self.train_images.append(image)
                        self.label_images.append(label)

            def _load_img(self, path, color_mode='rgb', transforms=[]):
                """
                Unified image-loading helper that normalizes image size and channels
                """
                with open(path, 'rb') as f:
                    img = PilImage.open(io.BytesIO(f.read()))
                    if color_mode == 'grayscale':
                        # if image is not already an 8-bit, 16-bit or 32-bit grayscale image
                        # convert it to an 8-bit grayscale image.
                        if img.mode not in ('L', 'I;16', 'I'):
                            img = img.convert('L')
                    elif color_mode == 'rgba':
                        if img.mode != 'RGBA':
                            img = img.convert('RGBA')
                    elif color_mode == 'rgb':
                        if img.mode != 'RGB':
                            img = img.convert('RGB')
                    else:
                        raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')

                    return T.Compose([
                        T.Resize(self.image_size)
                    ] + transforms)(img)

            def __getitem__(self, idx):
                """
                Return image, label
                """
                train_image = self._load_img(self.train_images[idx],
                                             transforms=[
                                                 T.Transpose(),
                                                 T.Normalize(mean=127.5, std=127.5)
                                             ])  # load the input image
                label_image = self._load_img(self.label_images[idx],
                                             color_mode='grayscale',
                                             transforms=[T.Grayscale()])  # load the label image

                # return image, label
                train_image = np.array(train_image, dtype='float32')
                label_image = np.array(label_image, dtype='int64')
                return train_image, label_image

            def __len__(self):
                """
                Return the total number of samples in the dataset
                """
                return len(self.train_images)

    IV. Model Architecture

    U-Net is a U-shaped network structure that can be viewed as two major stages: the image is first downsampled by the Encoder to obtain high-level semantic feature maps, then upsampled by the Decoder to restore the feature maps to the original image resolution.

    4.1 Define a SeparableConv2D Layer

    To reduce the number of trainable parameters in the convolutions and improve performance, we subclass paddle.nn.Layer to define a custom SeparableConv2D layer. It decomposes a filter_size * filter_size * num_filters Conv2D into two smaller convolutions: first a depthwise convolution that applies a filter_size * filter_size * 1 kernel to each input channel separately (the number of output channels equals the number of input channels), followed by a pointwise convolution with a 1 * 1 * num_filters kernel.

        from paddle.nn import functional as F

        class SeparableConv2D(paddle.nn.Layer):
            def __init__(self,
                         in_channels,
                         out_channels,
                         kernel_size,
                         stride=1,
                         padding=0,
                         dilation=1,
                         groups=None,
                         weight_attr=None,
                         bias_attr=None,
                         data_format="NCHW"):
                super(SeparableConv2D, self).__init__()

                self._padding = padding
                self._stride = stride
                self._dilation = dilation
                self._in_channels = in_channels
                self._data_format = data_format

                # parameters of the first (depthwise) convolution, without bias
                filter_shape = [in_channels, 1] + self.convert_to_list(kernel_size, 2, 'kernel_size')
                self.weight_conv = self.create_parameter(shape=filter_shape, attr=weight_attr)

                # parameters of the second (pointwise) convolution
                filter_shape = [out_channels, in_channels] + self.convert_to_list(1, 2, 'kernel_size')
                self.weight_pointwise = self.create_parameter(shape=filter_shape, attr=weight_attr)
                self.bias_pointwise = self.create_parameter(shape=[out_channels],
                                                            attr=bias_attr,
                                                            is_bias=True)

            def convert_to_list(self, value, n, name, dtype=int):
                if isinstance(value, dtype):
                    return [value, ] * n
                else:
                    try:
                        value_list = list(value)
                    except TypeError:
                        raise ValueError("The " + name +
                                         "'s type must be list or tuple. Received: " + str(value))
                    if len(value_list) != n:
                        raise ValueError("The " + name + "'s length must be " + str(n) +
                                         ". Received: " + str(value))
                    for single_value in value_list:
                        try:
                            dtype(single_value)
                        except (ValueError, TypeError):
                            raise ValueError(
                                "The " + name + "'s type must be a list or tuple of " +
                                str(n) + " " + str(dtype) + ". Received: " + str(value) +
                                " including element " + str(single_value) +
                                " of type " + str(type(single_value)))
                    return value_list

            def forward(self, inputs):
                # depthwise convolution: one filter per input channel (groups=in_channels)
                conv_out = F.conv2d(inputs,
                                    self.weight_conv,
                                    padding=self._padding,
                                    stride=self._stride,
                                    dilation=self._dilation,
                                    groups=self._in_channels,
                                    data_format=self._data_format)

                # pointwise 1x1 convolution mixes channels up to out_channels
                out = F.conv2d(conv_out,
                               self.weight_pointwise,
                               bias=self.bias_pointwise,
                               padding=0,
                               stride=1,
                               dilation=1,
                               groups=1,
                               data_format=self._data_format)
                return out

    4.2 Define the Encoder

    We wrap the Encoder downsampling path of the network into a Layer for easy reuse and less code duplication. Downsampling traces the descending half of the U shape: a unit structure is repeated over and over, steadily increasing the channel count while shrinking the spatial size, with a residual connection added. We abstract all of this into a single unified wrapper.

        class Encoder(paddle.nn.Layer):
            def __init__(self, in_channels, out_channels):
                super(Encoder, self).__init__()

                self.relus = paddle.nn.LayerList(
                    [paddle.nn.ReLU() for i in range(2)])
                self.separable_conv_01 = SeparableConv2D(in_channels,
                                                         out_channels,
                                                         kernel_size=3,
                                                         padding='same')
                self.bns = paddle.nn.LayerList(
                    [paddle.nn.BatchNorm2D(out_channels) for i in range(2)])
                self.separable_conv_02 = SeparableConv2D(out_channels,
                                                         out_channels,
                                                         kernel_size=3,
                                                         padding='same')
                self.pool = paddle.nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
                self.residual_conv = paddle.nn.Conv2D(in_channels,
                                                      out_channels,
                                                      kernel_size=1,
                                                      stride=2,
                                                      padding='same')

            def forward(self, inputs):
                previous_block_activation = inputs

                y = self.relus[0](inputs)
                y = self.separable_conv_01(y)
                y = self.bns[0](y)
                y = self.relus[1](y)
                y = self.separable_conv_02(y)
                y = self.bns[1](y)
                y = self.pool(y)

                residual = self.residual_conv(previous_block_activation)
                y = paddle.add(y, residual)

                return y

    4.3 Define the Decoder

    After the channel count peaks and the high-level semantic feature map is obtained, the network starts the decode phase: upsampling gradually reduces the channel count while the feature-map size grows, until the original image resolution is restored. This, too, is done by repeating a residual block of the same structure, so again to reduce code duplication we define the process as a Layer to use in the model.

        class Decoder(paddle.nn.Layer):
            def __init__(self, in_channels, out_channels):
                super(Decoder, self).__init__()

                self.relus = paddle.nn.LayerList(
                    [paddle.nn.ReLU() for i in range(2)])
                self.conv_transpose_01 = paddle.nn.Conv2DTranspose(in_channels,
                                                                   out_channels,
                                                                   kernel_size=3,
                                                                   padding=1)
                self.conv_transpose_02 = paddle.nn.Conv2DTranspose(out_channels,
                                                                   out_channels,
                                                                   kernel_size=3,
                                                                   padding=1)
                self.bns = paddle.nn.LayerList(
                    [paddle.nn.BatchNorm2D(out_channels) for i in range(2)])
                self.upsamples = paddle.nn.LayerList(
                    [paddle.nn.Upsample(scale_factor=2.0) for i in range(2)])
                self.residual_conv = paddle.nn.Conv2D(in_channels,
                                                      out_channels,
                                                      kernel_size=1,
                                                      padding='same')

            def forward(self, inputs):
                previous_block_activation = inputs

                y = self.relus[0](inputs)
                y = self.conv_transpose_01(y)
                y = self.bns[0](y)
                y = self.relus[1](y)
                y = self.conv_transpose_02(y)
                y = self.bns[1](y)
                y = self.upsamples[0](y)

                residual = self.upsamples[1](previous_block_activation)
                residual = self.residual_conv(residual)
                y = paddle.add(y, residual)

                return y

    4.4 Build the Full Model

    Assemble the overall network following the U-shaped structure: three downsampling blocks and four upsampling blocks.

        class PetNet(paddle.nn.Layer):
            def __init__(self, num_classes):
                super(PetNet, self).__init__()

                self.conv_1 = paddle.nn.Conv2D(3, 32,
                                               kernel_size=3,
                                               stride=2,
                                               padding='same')
                self.bn = paddle.nn.BatchNorm2D(32)
                self.relu = paddle.nn.ReLU()

                in_channels = 32
                self.encoders = []
                self.encoder_list = [64, 128, 256]
                self.decoder_list = [256, 128, 64, 32]

                # define the sub-layers in a loop based on the downsampling
                # configuration, to avoid writing the same code repeatedly
                for out_channels in self.encoder_list:
                    block = self.add_sublayer('encoder_{}'.format(out_channels),
                                              Encoder(in_channels, out_channels))
                    self.encoders.append(block)
                    in_channels = out_channels

                self.decoders = []

                # define the sub-layers in a loop based on the upsampling
                # configuration, to avoid writing the same code repeatedly
                for out_channels in self.decoder_list:
                    block = self.add_sublayer('decoder_{}'.format(out_channels),
                                              Decoder(in_channels, out_channels))
                    self.decoders.append(block)
                    in_channels = out_channels

                self.output_conv = paddle.nn.Conv2D(in_channels,
                                                    num_classes,
                                                    kernel_size=3,
                                                    padding='same')

            def forward(self, inputs):
                y = self.conv_1(inputs)
                y = self.bn(y)
                y = self.relu(y)

                for encoder in self.encoders:
                    y = encoder(y)

                for decoder in self.decoders:
                    y = decoder(y)

                y = self.output_conv(y)
                return y

    Call the summary interface provided by PaddlePaddle to visualize the assembled model, which makes it easy to inspect and confirm the model structure and parameter counts.

        ---------------------------------------------------------------------------------
          Layer (type)          Input Shape            Output Shape         Param #
        =================================================================================
          Conv2D-1           [[1, 3, 160, 160]]      [1, 32, 80, 80]           896
          BatchNorm2D-1      [[1, 32, 80, 80]]       [1, 32, 80, 80]           128
          ReLU-1             [[1, 32, 80, 80]]       [1, 32, 80, 80]             0
          ReLU-2             [[1, 32, 80, 80]]       [1, 32, 80, 80]             0
          SeparableConv2D-1  [[1, 32, 80, 80]]       [1, 64, 80, 80]         2,400
          BatchNorm2D-2      [[1, 64, 80, 80]]       [1, 64, 80, 80]           256
          ReLU-3             [[1, 64, 80, 80]]       [1, 64, 80, 80]             0
          SeparableConv2D-2  [[1, 64, 80, 80]]       [1, 64, 80, 80]         4,736
          BatchNorm2D-3      [[1, 64, 80, 80]]       [1, 64, 80, 80]           256
          MaxPool2D-1        [[1, 64, 80, 80]]       [1, 64, 40, 40]             0
          Conv2D-2           [[1, 32, 80, 80]]       [1, 64, 40, 40]         2,112
          Encoder-1          [[1, 32, 80, 80]]       [1, 64, 40, 40]             0
          ReLU-4             [[1, 64, 40, 40]]       [1, 64, 40, 40]             0
          SeparableConv2D-3  [[1, 64, 40, 40]]       [1, 128, 40, 40]        8,896
          BatchNorm2D-4      [[1, 128, 40, 40]]      [1, 128, 40, 40]          512
          ReLU-5             [[1, 128, 40, 40]]      [1, 128, 40, 40]            0
          SeparableConv2D-4  [[1, 128, 40, 40]]      [1, 128, 40, 40]       17,664
          BatchNorm2D-5      [[1, 128, 40, 40]]      [1, 128, 40, 40]          512
          MaxPool2D-2        [[1, 128, 40, 40]]      [1, 128, 20, 20]            0
          Conv2D-3           [[1, 64, 40, 40]]       [1, 128, 20, 20]        8,320
          Encoder-2          [[1, 64, 40, 40]]       [1, 128, 20, 20]            0
          ReLU-6             [[1, 128, 20, 20]]      [1, 128, 20, 20]            0
          SeparableConv2D-5  [[1, 128, 20, 20]]      [1, 256, 20, 20]       34,176
          BatchNorm2D-6      [[1, 256, 20, 20]]      [1, 256, 20, 20]        1,024
          ReLU-7             [[1, 256, 20, 20]]      [1, 256, 20, 20]            0
          SeparableConv2D-6  [[1, 256, 20, 20]]      [1, 256, 20, 20]       68,096
          BatchNorm2D-7      [[1, 256, 20, 20]]      [1, 256, 20, 20]        1,024
          MaxPool2D-3        [[1, 256, 20, 20]]      [1, 256, 10, 10]            0
          Conv2D-4           [[1, 128, 20, 20]]      [1, 256, 10, 10]       33,024
          Encoder-3          [[1, 128, 20, 20]]      [1, 256, 10, 10]            0
          ReLU-8             [[1, 256, 10, 10]]      [1, 256, 10, 10]            0
          Conv2DTranspose-1  [[1, 256, 10, 10]]      [1, 256, 10, 10]      590,080
          BatchNorm2D-8      [[1, 256, 10, 10]]      [1, 256, 10, 10]        1,024
          ReLU-9             [[1, 256, 10, 10]]      [1, 256, 10, 10]            0
          Conv2DTranspose-2  [[1, 256, 10, 10]]      [1, 256, 10, 10]      590,080
          BatchNorm2D-9      [[1, 256, 10, 10]]      [1, 256, 10, 10]        1,024
          Upsample-1         [[1, 256, 10, 10]]      [1, 256, 20, 20]            0
          Upsample-2         [[1, 256, 10, 10]]      [1, 256, 20, 20]            0
          Conv2D-5           [[1, 256, 20, 20]]      [1, 256, 20, 20]       65,792
          Decoder-1          [[1, 256, 10, 10]]      [1, 256, 20, 20]            0
          ReLU-10            [[1, 256, 20, 20]]      [1, 256, 20, 20]            0
          Conv2DTranspose-3  [[1, 256, 20, 20]]      [1, 128, 20, 20]      295,040
          BatchNorm2D-10     [[1, 128, 20, 20]]      [1, 128, 20, 20]          512
          ReLU-11            [[1, 128, 20, 20]]      [1, 128, 20, 20]            0
          Conv2DTranspose-4  [[1, 128, 20, 20]]      [1, 128, 20, 20]      147,584
          BatchNorm2D-11     [[1, 128, 20, 20]]      [1, 128, 20, 20]          512
          Upsample-3         [[1, 128, 20, 20]]      [1, 128, 40, 40]            0
          Upsample-4         [[1, 256, 20, 20]]      [1, 256, 40, 40]            0
          Conv2D-6           [[1, 256, 40, 40]]      [1, 128, 40, 40]       32,896
          Decoder-2          [[1, 256, 20, 20]]      [1, 128, 40, 40]            0
          ReLU-12            [[1, 128, 40, 40]]      [1, 128, 40, 40]            0
          Conv2DTranspose-5  [[1, 128, 40, 40]]      [1, 64, 40, 40]        73,792
          BatchNorm2D-12     [[1, 64, 40, 40]]       [1, 64, 40, 40]           256
          ReLU-13            [[1, 64, 40, 40]]       [1, 64, 40, 40]             0
          Conv2DTranspose-6  [[1, 64, 40, 40]]       [1, 64, 40, 40]        36,928
          BatchNorm2D-13     [[1, 64, 40, 40]]       [1, 64, 40, 40]           256
          Upsample-5         [[1, 64, 40, 40]]       [1, 64, 80, 80]             0
          Upsample-6         [[1, 128, 40, 40]]      [1, 128, 80, 80]            0
          Conv2D-7           [[1, 128, 80, 80]]      [1, 64, 80, 80]         8,256
          Decoder-3          [[1, 128, 40, 40]]      [1, 64, 80, 80]             0
          ReLU-14            [[1, 64, 80, 80]]       [1, 64, 80, 80]             0
          Conv2DTranspose-7  [[1, 64, 80, 80]]       [1, 32, 80, 80]        18,464
          BatchNorm2D-14     [[1, 32, 80, 80]]       [1, 32, 80, 80]           128
          ReLU-15            [[1, 32, 80, 80]]       [1, 32, 80, 80]             0
          Conv2DTranspose-8  [[1, 32, 80, 80]]       [1, 32, 80, 80]         9,248
          BatchNorm2D-15     [[1, 32, 80, 80]]       [1, 32, 80, 80]           128
          Upsample-7         [[1, 32, 80, 80]]       [1, 32, 160, 160]           0
          Upsample-8         [[1, 64, 80, 80]]       [1, 64, 160, 160]           0
          Conv2D-8           [[1, 64, 160, 160]]     [1, 32, 160, 160]       2,080
          Decoder-4          [[1, 64, 80, 80]]       [1, 32, 160, 160]           0
          Conv2D-9           [[1, 32, 160, 160]]     [1, 4, 160, 160]        1,156
        =================================================================================
        Total params: 2,059,268
        Trainable params: 2,051,716
        Non-trainable params: 7,552
        ---------------------------------------------------------------------------------
        Input size (MB): 0.29
        Forward/backward pass size (MB): 117.77
        Params size (MB): 7.86
        Estimated Total Size (MB): 125.92
        ---------------------------------------------------------------------------------

        {'total_params': 2059268, 'trainable_params': 2051716}

    V. Model Training

    5.1 Start Training

    Generate a Model instance from the network code, then use the prepare interface to configure the optimizer, loss function, and evaluation metrics for later training. Once the setup is complete, call the fit interface to start the training process; fit only needs the previously defined training dataset, evaluation dataset, number of training epochs (epochs), and batch size (batch_size).

        train_dataset = PetDataset(mode='train')  # training dataset
        val_dataset = PetDataset(mode='test')     # evaluation dataset

        optim = paddle.optimizer.RMSProp(learning_rate=0.001,
                                         rho=0.9,
                                         momentum=0.0,
                                         epsilon=1e-07,
                                         centered=False,
                                         parameters=model.parameters())
        model.prepare(optim, paddle.nn.CrossEntropyLoss(axis=1))
        model.fit(train_dataset,
                  val_dataset,
                  epochs=15,
                  batch_size=32,
                  verbose=1)
        The loss value printed in the log is the current step, and the metric is the average value of previous step.
        Epoch 1/15
        step 197/197 [==============================] - loss: 0.8845 - 247ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.7308 - 228ms/step
        Eval samples: 1108
        Epoch 2/15
        step 197/197 [==============================] - loss: 0.4457 - 248ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.6239 - 229ms/step
        Eval samples: 1108
        Epoch 3/15
        step 197/197 [==============================] - loss: 0.4924 - 245ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.5226 - 228ms/step
        Eval samples: 1108
        Epoch 4/15
        step 197/197 [==============================] - loss: 0.5653 - 247ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.5351 - 229ms/step
        Eval samples: 1108
        Epoch 5/15
        step 197/197 [==============================] - loss: 0.5002 - 244ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.4558 - 228ms/step
        Eval samples: 1108
        Epoch 6/15
        step 197/197 [==============================] - loss: 0.3979 - 245ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.3849 - 230ms/step
        Eval samples: 1108
        Epoch 7/15
        step 197/197 [==============================] - loss: 0.2664 - 247ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.3900 - 228ms/step
        Eval samples: 1108
        Epoch 8/15
        step 197/197 [==============================] - loss: 0.2803 - 246ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.3872 - 228ms/step
        Eval samples: 1108
        Epoch 9/15
        step 197/197 [==============================] - loss: 0.4651 - 246ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.4824 - 228ms/step
        Eval samples: 1108
        Epoch 10/15
        step 197/197 [==============================] - loss: 0.3553 - 246ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.4708 - 229ms/step
        Eval samples: 1108
        Epoch 11/15
        step 197/197 [==============================] - loss: 0.3170 - 251ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.3867 - 230ms/step
        Eval samples: 1108
        Epoch 12/15
        step 197/197 [==============================] - loss: 0.3067 - 246ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.4145 - 229ms/step
        Eval samples: 1108
        Epoch 13/15
        step 197/197 [==============================] - loss: 0.3447 - 249ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.4658 - 230ms/step
        Eval samples: 1108
        Epoch 14/15
        step 197/197 [==============================] - loss: 0.3662 - 249ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.3955 - 237ms/step
        Eval samples: 1108
        Epoch 15/15
        step 197/197 [==============================] - loss: 0.3253 - 247ms/step
        Eval begin...
        The loss value printed in the log is the current batch, and the metric is the average value of previous step.
        step 35/35 [==============================] - loss: 0.4501 - 229ms/step
        Eval samples: 1108

    VI. Model Prediction

    6.1 Prepare the Prediction Dataset and Predict

    We keep using PetDataset to instantiate the dataset used for prediction. For convenience, we did not prepare separate prediction data here; the evaluation data is reused.

    We can call the model.predict interface directly to run prediction over the dataset; simply pass the prediction dataset to the interface.

        predict_dataset = PetDataset(mode='predict')
        predict_results = model.predict(predict_dataset)

        Predict begin...
        step 1108/1108 [==============================] - 14ms/step

    6.2 Visualize the Predictions

    Let's take 3 animals from the prediction dataset and see how the predictions look, displaying the original image, the label, and the prediction result side by side.