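Image Super-Resolution with Sub-Pixel Convolution

    Date: 2021.01

    Abstract: This tutorial uses a Sub-Pixel convolutional neural network to perform image super-resolution.

    In computer vision, image super-resolution means recovering a high-resolution image from a low-resolution image or image sequence. Super-resolution techniques are divided into super-resolution restoration and super-resolution reconstruction.

    This example briefly shows how to implement image super-resolution with the open-source PaddlePaddle framework, covering dataset definition, model construction, and training.

    Reference paper: "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"

    Paper link: https://arxiv.org/abs/1609.05158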

    2. Environment Setup

    This example uses the BSR_bsds500 dataset, which can be downloaded and unpacked with the following commands:

    !wget --no-check-certificate --no-cookies --header "Cookie: oraclelicense=accept-securebackup-cookie" http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz
    !tar -zxvf BSR_bsds500.tgz

    --2021-01-29 00:11:52-- http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz
    Resolving www.eecs.berkeley.edu (www.eecs.berkeley.edu)... 23.185.0.1, 2620:12a:8001::1, 2620:12a:8000::1
    Connecting to www.eecs.berkeley.edu (www.eecs.berkeley.edu)|23.185.0.1|:80... connected.
    HTTP request sent, awaiting response... 301 Moved Permanently
    Location: https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz [following]
    --2021-01-29 00:11:53-- https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz
    Connecting to www.eecs.berkeley.edu (www.eecs.berkeley.edu)|23.185.0.1|:443... connected.
    HTTP request sent, awaiting response... 301 Moved Permanently
    Location: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz [following]
    --2021-01-29 00:11:53-- https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz
    Resolving www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)... 128.32.244.190
    Connecting to www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)|128.32.244.190|:443... connected.
    ^C
    tar (child): BSR_bsds500.tgz: Cannot open: No such file or directory
    tar (child): Error is not recoverable: exiting now
    tar: Child returned status 2
    tar: Error is not recoverable: exiting now

    3.2 Dataset Overview

    BSR
    ├── BSDS500
    │   └── data
    │       ├── groundTruth
    │       │   ├── test
    │       │   ├── train
    │       │   └── val
    │       └── images
    │           ├── test
    │           ├── train
    │           └── val
    ├── bench
    │   ├── benchmarks
    │   ├── data
    │   │   ├── ...
    │   │   └── ...
    │   └── source
    └── documentation

    As the tree shows, the image files we need are under BSR/BSDS500/data/images: train and test contain 200 images each, and val contains 100.

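    3.3 Dataset Class Definition

    PaddlePaddle loads data through one unified scheme: Dataset (dataset definition) + DataLoader (multi-process data loading).

    We first define the dataset. This means implementing a new Dataset class that inherits from paddle.io.Dataset and implements two of its abstract methods, __getitem__ and __len__: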

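    from paddle.io import Dataset

    class MyDataset(Dataset):
        def __init__(self):
            super().__init__()
            # Placeholder: load (data, label) pairs here
            self.samples = []

        # Return one sample and its label for the given index
        def __getitem__(self, idx):
            x, y = self.samples[idx]
            return x, y

        # Return the total number of samples in the dataset
        def __len__(self):
            return len(self.samples)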

    The BSD_data class implements this interface for the BSDS500 images:

    import io
    import os
    import random

    import numpy as np
    from PIL import Image
    from paddle.io import Dataset

    class BSD_data(Dataset):
        """
        Dataset of BSDS500 images, inheriting from paddle.io.Dataset.
        """
        def __init__(self, mode='train', image_path="BSR/BSDS500/data/images/"):
            """
            Constructor: sets up data reading and selects the train or val split.
            """
            super(BSD_data, self).__init__()
            self.mode = mode.lower()
            if self.mode == 'train':
                self.image_path = os.path.join(image_path, 'train')
            elif self.mode == 'val':
                self.image_path = os.path.join(image_path, 'val')
            else:
                raise ValueError('mode must be "train" or "val"')
            # Side length the original image is resized to
            self.crop_size = 300
            # Upscale factor
            self.upscale_factor = 3
            # Side length of the downscaled image fed to the network
            self.input_size = self.crop_size // self.upscale_factor
            # Seed used when shuffling the file list
            self.seed = 1337
            # List of image paths
            self.temp_images = []
            # Load the data
            self._parse_dataset()

        def transforms(self, img):
            """
            Preprocessing helper: adds a channel axis, (100, 100) => (100, 100, 1),
            and converts the layout from HWC to CHW.
            """
            if len(img.shape) == 2:
                img = np.expand_dims(img, axis=2)
            return img.transpose((2, 0, 1))

        def __getitem__(self, idx):
            """
            Return the image downscaled by a factor of 3 and the original image.
            """
            # Load the original image
            img = self._load_img(self.temp_images[idx])
            # Resize the original RGB image to 300 x 300
            img = img.resize([self.crop_size, self.crop_size], Image.BICUBIC)
            # Convert to YCbCr
            ycbcr = img.convert("YCbCr")
            # The human eye is most sensitive to luminance, so keep only the Y channel
            y, cb, cr = ycbcr.split()
            y = np.asarray(y, dtype='float32')
            y = y / 255.0
            # Build the low-resolution input the same way
            img_ = img.resize([self.input_size, self.input_size], Image.BICUBIC)
            ycbcr_ = img_.convert("YCbCr")
            y_, cb_, cr_ = ycbcr_.split()
            y_ = np.asarray(y_, dtype='float32')
            y_ = y_ / 255.0
            # Add a channel axis and convert HWC to CHW
            y = self.transforms(y)
            x = self.transforms(y_)
            # x is the downscaled image (1, 100, 100); y is the original image (1, 300, 300)
            return x, y

        def __len__(self):
            """
            Return the total number of samples in the dataset.
            """
            return len(self.temp_images)

        def _sort_images(self, img_dir):
            """
            Return the image files in a folder, sorted by file name.
            """
            files = []
            for item in os.listdir(img_dir):
                if item.split('.')[-1].lower() in ["jpg", 'jpeg', 'png']:
                    files.append(os.path.join(img_dir, item))
            return sorted(files)

        def _parse_dataset(self):
            """
            Collect the image paths and shuffle them reproducibly.
            """
            self.temp_images = self._sort_images(self.image_path)
            random.Random(self.seed).shuffle(self.temp_images)

        def _load_img(self, path):
            """
            Read an image from disk.
            """
            with open(path, 'rb') as f:
                img = Image.open(io.BytesIO(f.read()))
                img = img.convert('RGB')
                return img
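
    The Y-channel extraction in __getitem__ can also be reproduced by hand. The following is a minimal sketch, assuming the BT.601 luma weights (0.299, 0.587, 0.114) used by the JFIF RGB to YCbCr conversion; the helper name rgb_to_y is hypothetical and not part of the tutorial.

```python
import numpy as np

def rgb_to_y(rgb):
    """Return the luma (Y) channel of an HxWx3 uint8 RGB array, scaled to [0, 1]."""
    rgb = np.asarray(rgb, dtype="float32")
    # Assumed BT.601 / JFIF luma weights for R, G, B
    weights = np.array([0.299, 0.587, 0.114], dtype="float32")
    y = rgb @ weights          # weighted sum over the channel axis -> (H, W)
    return y / 255.0

# A 2x2 test image: white, black, pure red, pure green
pixels = np.array([[[255, 255, 255], [0, 0, 0]],
                   [[255, 0, 0], [0, 255, 0]]], dtype="uint8")
y = rgb_to_y(pixels)

# Add the channel axis the network expects: (H, W) -> (1, H, W)
y_chw = y[np.newaxis, :, :]
print(y_chw.shape)  # (1, 2, 2)
```

    White maps to 1.0 and black to 0.0, while pure red and pure green map to their respective weights, which shows how strongly each colour contributes to the single channel the network sees.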

    With BSD_data implemented, let's check that it behaves as expected. Because BSD_data supports indexing, we read the first sample and display both images.

    # Test the dataset defined above
    train_dataset = BSD_data(mode='train')
    val_dataset = BSD_data(mode='val')
    print('=============train dataset=============')
    x, y = train_dataset[0]
    # Drop the channel axis and scale back to [0, 255]
    x = x[0]
    y = y[0]
    x = x * 255
    y = y * 255
    img_ = Image.fromarray(np.uint8(x), mode="L")
    img = Image.fromarray(np.uint8(y), mode="L")
    display(img_)
    display(img_.size)
    display(img)
    display(img.size)

    (100, 100)
    (300, 300)

    4. Model Construction

    Sub_Pixel_CNN is a fully convolutional network with a fairly simple structure. Here we build it by subclassing Layer.

    import paddle

    class Sub_Pixel_CNN(paddle.nn.Layer):
        def __init__(self, upscale_factor=3, channels=1):
            super(Sub_Pixel_CNN, self).__init__()
            # Keep the factor so forward() stays consistent with conv4's channel count
            self.upscale_factor = upscale_factor
            self.conv1 = paddle.nn.Conv2D(channels, 64, 5, stride=1, padding=2)
            self.conv2 = paddle.nn.Conv2D(64, 64, 3, stride=1, padding=1)
            self.conv3 = paddle.nn.Conv2D(64, 32, 3, stride=1, padding=1)
            self.conv4 = paddle.nn.Conv2D(32, channels * (upscale_factor ** 2), 3, stride=1, padding=1)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            # Rearrange (N, C*r*r, H, W) into (N, C, H*r, W*r)
            x = paddle.nn.functional.pixel_shuffle(x, self.upscale_factor)
            return x
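
    The last step of forward, the sub-pixel (pixel shuffle) rearrangement, is what turns 9 low-resolution channels into one image three times larger. The following NumPy sketch illustrates the idea; it assumes the usual NCHW ordering in which input channel c*r*r + i*r + j lands at offset (i, j) inside each r x r output block, and it is an illustration rather than Paddle's own implementation.

```python
import numpy as np

def pixel_shuffle(x, r):
    """Rearrange an array of shape (N, C*r*r, H, W) into (N, C, H*r, W*r)."""
    n, c_in, h, w = x.shape
    c_out = c_in // (r * r)
    # Split the channel axis into (C, r, r)
    x = x.reshape(n, c_out, r, r, h, w)
    # Interleave the two r axes with the spatial axes: (N, C, H, r, W, r)
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(n, c_out, h * r, w * r)

# Nine channels holding one pixel each become a single 3x3 image
tiny = np.arange(9).reshape(1, 9, 1, 1)
print(pixel_shuffle(tiny, 3)[0, 0])

# The shapes match the tutorial's network: (1, 9, 100, 100) -> (1, 1, 300, 300)
features = np.zeros((1, 9, 100, 100), dtype="float32")
upscaled = pixel_shuffle(features, 3)
print(upscaled.shape)
```

    No values are computed or interpolated in this step; every output pixel is copied from exactly one input channel, which is why the preceding convolution must produce channels * upscale_factor ** 2 feature maps.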

    4.1 Model Wrapping

    # Wrap the network with the high-level Model API
    model = paddle.Model(Sub_Pixel_CNN())

    4.2 Model Visualization

    Call the summary interface provided by PaddlePaddle to visualize the assembled model, which makes it easy to inspect and confirm its structure and parameters.

    model.summary((1, 1, 100, 100))

    ---------------------------------------------------------------------------
     Layer (type)       Input Shape          Output Shape         Param #
    ===========================================================================
       Conv2D-1      [[1, 1, 100, 100]]   [1, 64, 100, 100]        1,664
       Conv2D-2     [[1, 64, 100, 100]]   [1, 64, 100, 100]       36,928
       Conv2D-3     [[1, 64, 100, 100]]   [1, 32, 100, 100]       18,464
       Conv2D-4     [[1, 32, 100, 100]]    [1, 9, 100, 100]        2,601
    ===========================================================================
    Total params: 59,657
    Trainable params: 59,657
    Non-trainable params: 0
    ---------------------------------------------------------------------------
    Input size (MB): 0.04
    Forward/backward pass size (MB): 12.89
    Params size (MB): 0.23
    Estimated Total Size (MB): 13.16
    ---------------------------------------------------------------------------

    {'total_params': 59657, 'trainable_params': 59657}
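
    The parameter counts in the summary can be checked by hand: a 2D convolution has in_channels * out_channels * k * k weights plus one bias per output channel. A short sketch, using a hypothetical helper conv2d_params and the layer sizes from Sub_Pixel_CNN:

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights (in * out * k * k) plus one bias per output channel."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels

upscale_factor, channels = 3, 1
# (in_channels, out_channels, kernel_size) for conv1..conv4
layers = [
    (channels, 64, 5),
    (64, 64, 3),
    (64, 32, 3),
    (32, channels * upscale_factor ** 2, 3),
]
counts = [conv2d_params(*layer) for layer in layers]
print(counts, sum(counts))  # [1664, 36928, 18464, 2601] 59657
```

    The pixel shuffle step adds no parameters, so the four convolutions account for the full total of 59,657.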

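    5. Model Training

    Create the Model instance from the network code, then use the prepare interface to define the optimizer, loss function, and evaluation metrics used in training. Once this initial configuration is done, call the fit interface to start training. fit only needs the training dataset and test dataset defined earlier, the number of epochs, and the batch size (batch_size).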

    The loss value printed in the log is the current step, and the metric is the average value of previous step.
    Epoch 1/20
    step 13/13 [==============================] - loss: 0.2466 - 2s/step
    Epoch 2/20
    step 13/13 [==============================] - loss: 0.0802 - 2s/step
    Epoch 3/20
    step 13/13 [==============================] - loss: 0.0474 - 2s/step
    Epoch 4/20
    step 13/13 [==============================] - loss: 0.0340 - 2s/step
    Epoch 5/20
    step 13/13 [==============================] - loss: 0.0267 - 2s/step
    Epoch 6/20
    step 13/13 [==============================] - loss: 0.0179 - 2s/step
    Epoch 7/20
    step 13/13 [==============================] - loss: 0.0215 - 2s/step
    Epoch 8/20
    step 13/13 [==============================] - loss: 0.0162 - 2s/step
    Epoch 9/20
    step 13/13 [==============================] - loss: 0.0137 - 2s/step
    Epoch 10/20
    step 13/13 [==============================] - loss: 0.0099 - 2s/step
    Epoch 11/20
    step 13/13 [==============================] - loss: 0.0074 - 2s/step
    Epoch 12/20
    step 13/13 [==============================] - loss: 0.0117 - 2s/step
    Epoch 13/20
    step 13/13 [==============================] - loss: 0.0065 - 2s/step
    Epoch 14/20
    step 13/13 [==============================] - loss: 0.0086 - 2s/step
    Epoch 15/20
    step 13/13 [==============================] - loss: 0.0085 - 2s/step
    Epoch 16/20
    step 13/13 [==============================] - loss: 0.0067 - 2s/step
    Epoch 17/20
    step 13/13 [==============================] - loss: 0.0068 - 2s/step
    Epoch 18/20
    step 13/13 [==============================] - loss: 0.0044 - 2s/step
    Epoch 19/20
    step 13/13 [==============================] - loss: 0.0069 - 2s/step
    Epoch 20/20
    step 13/13 [==============================] - loss: 0.0087 - 2s/step

    6. Model Prediction

    6.1 Prediction

    We can run prediction on a dataset directly with the model.predict interface; we only need to pass the prediction dataset to it.

    predict_results = model.predict(val_dataset)

    Predict begin...
    step 100/100 [==============================] - 38ms/step
    Predict samples: 100

    6.2 Defining the Result Visualization Functions

    import io
    import math

    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
    from mpl_toolkits.axes_grid1.inset_locator import mark_inset

    def psnr(img1, img2):
        """
        Compute the PSNR (peak signal-to-noise ratio) of two 8-bit images.
        """
        mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
        if mse < 1.0e-10:
            return 100
        PIXEL_MAX = 1
        return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

    def plot_results(img, title='results', prefix='out'):
        """
        Plot an image with a zoomed-in inset and save it to disk.
        """
        img_array = np.asarray(img, dtype='float32') / 255.0
        fig, ax = plt.subplots()
        im = ax.imshow(img_array[::-1], origin="lower")
        plt.title(title)
        # Inset showing a 2x zoom of one region
        axins = zoomed_inset_axes(ax, 2, loc=2)
        axins.imshow(img_array[::-1], origin="lower")
        x1, x2, y1, y2 = 200, 300, 100, 200
        axins.set_xlim(x1, x2)
        axins.set_ylim(y1, y2)
        plt.yticks(visible=False)
        plt.xticks(visible=False)
        mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")
        plt.savefig(str(prefix) + "-" + title + ".png")
        plt.show()

    def get_lowres_image(img, upscale_factor):
        """
        Downscale the image by upscale_factor.
        """
        return img.resize(
            (img.size[0] // upscale_factor, img.size[1] // upscale_factor),
            Image.BICUBIC,
        )

    def upscale_image(model, img):
        """
        Take a small image and return it upsampled by a factor of three.
        """
        # Convert the image to YCbCr
        ycbcr = img.convert("YCbCr")
        y, cb, cr = ycbcr.split()
        y = np.asarray(y, dtype='float32')
        y = y / 255.0
        img = np.expand_dims(y, axis=0)    # (1, H, W): one image
        img = np.expand_dims(img, axis=0)  # (1, 1, H, W): one batch
        img = np.expand_dims(img, axis=0)  # (1, 1, 1, H, W): an iterable of batches
        out = model.predict(img)           # predict expects an iterable of batches
        out_img_y = out[0][0][0]           # take the prediction output
        out_img_y *= 255.0
        # Clip to the valid range so the uint8 cast below cannot wrap around
        out_img_y = np.clip(out_img_y, 0, 255)
        # Convert the image back to RGB
        out_img_y = out_img_y.reshape((np.shape(out_img_y)[1], np.shape(out_img_y)[2]))
        out_img_y = Image.fromarray(np.uint8(out_img_y), mode="L")
        out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
        out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
        out_img = Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert(
            "RGB"
        )
        return out_img

    def main(model, img_path, upscale_factor=3):
        # Read the image
        with open(img_path, 'rb') as f:
            img = Image.open(io.BytesIO(f.read()))
        # Downscale by a factor of three
        lowres_input = get_lowres_image(img, upscale_factor)
        w = lowres_input.size[0] * upscale_factor
        h = lowres_input.size[1] * upscale_factor
        # Upscale the downscaled image by a factor of three again
        lowres_img = lowres_input.resize((w, h))
        # Make sure the unscaled image has the same size as the other two
        highres_img = img.resize((w, h))
        # Downscaled image upscaled by the Efficient Sub-Pixel CNN
        prediction = upscale_image(model, lowres_input)
        psnr_low = psnr(np.asarray(lowres_img), np.asarray(highres_img))
        psnr_pre = psnr(np.asarray(prediction), np.asarray(highres_img))
        # Show the three images
        plot_results(lowres_img, "lowres")
        plot_results(highres_img, "highres")
        plot_results(prediction, "prediction")
        print("psmr_low:", psnr_low, "psmr_pre:", psnr_pre)
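
    As a sanity check on the formula used by psnr, PSNR = 20 * log10(MAX / sqrt(MSE)), here is a small worked example. It is a sketch with a hypothetical helper, psnr_unit, that takes arrays already scaled to [0, 1]. A uniform error of 0.1 gives MSE = 0.01, so the result is 20 dB.

```python
import math
import numpy as np

def psnr_unit(a, b, pixel_max=1.0):
    """PSNR in dB for two arrays whose values lie in [0, pixel_max]."""
    a = np.asarray(a, dtype="float64")
    b = np.asarray(b, dtype="float64")
    mse = float(np.mean((a - b) ** 2))
    if mse < 1.0e-10:
        return 100.0  # same cap as the psnr function above
    return 20 * math.log10(pixel_max / math.sqrt(mse))

ref = np.full((4, 4), 0.5)
noisy = ref + 0.1          # every pixel is off by 0.1
print(round(psnr_unit(ref, noisy), 2))  # 20.0
```

    Higher values mean the two images are closer, and each factor-of-ten drop in the root-mean-square error adds 20 dB.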

    Pick one image (here from the test split) to check the prediction quality, showing the original image, the downscaled image, and the predicted result.

    main(model, 'BSR/BSDS500/data/images/test/100007.jpg')

    Predict begin...
    step 1/1 [==============================] - 75ms/step

    [Figures: plots of the low-resolution, high-resolution, and predicted images]

    psmr_low: 30.381882136539197 psmr_pre: 29.074438702896636

    7. Model Saving