空间变换器网络教程

    作者: Ghassen HAMROUNI

    在本教程中,您将学习如何使用称为空间变换器网络的视觉注意机制来扩充您的网络。你可以在 阅读有关空间变换器网络的更多内容。

    空间变换器网络是对任何空间变换的差异化关注的概括。空间变换器网络(简称STN)允许神经网络学习如何在输入图像上执行空间变换,以增强模型的几何不变性。例如,它可以裁剪感兴趣的区域,缩放并校正图像的方向。它可能是一种有用的机制,因为CNN对于旋转和缩放以及更一般的仿射变换并不是不变的。关于STN的最棒的事情之一是能够简单地将其插入任何现有的CNN,只需很少的修改。

    1. # Training dataset
    2. train_loader = torch.utils.data.DataLoader(
    3. datasets.MNIST(root='.', train=True, download=True,
    4. transform=transforms.Compose([
    5. transforms.ToTensor(),
    6. transforms.Normalize((0.1307,), (0.3081,))
    7. ])), batch_size=64, shuffle=True, num_workers=4)
    8. # Test dataset
    9. test_loader = torch.utils.data.DataLoader(
    10. datasets.MNIST(root='.', train=False, transform=transforms.Compose([
    11. transforms.ToTensor(),
    12. transforms.Normalize((0.1307,), (0.3081,))
    13. ])), batch_size=64, shuffle=True, num_workers=4)

    输出:

    空间变换器网络归结为三个主要组成部分:

    • 本地网络(Localisation Network)是常规CNN,其对变换参数进行回归。不会从该数据集中明确地学习转换,而是网络自动学习增强全局准确性的空间变换。
    • 网格生成器( Grid Genator)在输入图像中生成与输出图像中的每个像素相对应的坐标网格。
    • 采样器(Sampler)使用变换的参数并将其应用于输入图像。

    https://pytorch.org/tutorials/_images/stn-arch.png

    我们使用最新版本的Pytorch,它应该包含affine_grid和grid_sample模块。

    1. class Net(nn.Module):
    2. def __init__(self):
    3. super(Net, self).__init__()
    4. self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    5. self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    6. self.conv2_drop = nn.Dropout2d()
    7. self.fc1 = nn.Linear(320, 50)
    8. self.fc2 = nn.Linear(50, 10)
    9. # Spatial transformer localization-network
    10. self.localization = nn.Sequential(
    11. nn.Conv2d(1, 8, kernel_size=7),
    12. nn.MaxPool2d(2, stride=2),
    13. nn.ReLU(True),
    14. nn.Conv2d(8, 10, kernel_size=5),
    15. nn.MaxPool2d(2, stride=2),
    16. nn.ReLU(True)
    17. )
    18. # Regressor for the 3 * 2 affine matrix
    19. self.fc_loc = nn.Sequential(
    20. nn.Linear(10 * 3 * 3, 32),
    21. nn.ReLU(True),
    22. )
    23. # Initialize the weights/bias with identity transformation
    24. self.fc_loc[2].weight.data.zero_()
    25. self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
    26. # Spatial transformer network forward function
    27. def stn(self, x):
    28. xs = self.localization(x)
    29. xs = xs.view(-1, 10 * 3 * 3)
    30. theta = self.fc_loc(xs)
    31. theta = theta.view(-1, 2, 3)
    32. grid = F.affine_grid(theta, x.size())
    33. x = F.grid_sample(x, grid)
    34. return x
    35. def forward(self, x):
    36. # transform the input
    37. x = self.stn(x)
    38. # Perform the usual forward pass
    39. x = F.relu(F.max_pool2d(self.conv1(x), 2))
    40. x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    41. x = x.view(-1, 320)
    42. x = F.relu(self.fc1(x))
    43. x = F.dropout(x, training=self.training)
    44. x = self.fc2(x)
    45. return F.log_softmax(x, dim=1)
    46. model = Net().to(device)

    现在我们使用SGD(随机梯度下降)算法来训练模型。网络正在以有监督的方式学习分类任务。同时,该模型以端到端的方式自动学习STN。

    现在,我们将检查我们学习的视觉注意机制的结果。我们定义了一个小辅助函数,以便在训练时可视化变换。

    1. def convert_image_np(inp):
    2. """Convert a Tensor to numpy image."""
    3. inp = inp.numpy().transpose((1, 2, 0))
    4. std = np.array([0.229, 0.224, 0.225])
    5. inp = np.clip(inp, 0, 1)
    6. return inp
    7. # We want to visualize the output of the spatial transformers layer
    8. # after the training, we visualize a batch of input images and
    9. # the corresponding transformed batch using STN.
    10. def visualize_stn():
    11. with torch.no_grad():
    12. # Get a batch of training data
    13. data = next(iter(test_loader))[0].to(device)
    14. input_tensor = data.cpu()
    15. transformed_input_tensor = model.stn(data).cpu()
    16. in_grid = convert_image_np(
    17. torchvision.utils.make_grid(input_tensor))
    18. out_grid = convert_image_np(
    19. torchvision.utils.make_grid(transformed_input_tensor))
    20. # Plot the results side-by-side
    21. f, axarr = plt.subplots(1, 2)
    22. axarr[0].imshow(in_grid)
    23. axarr[0].set_title('Dataset Images')
    24. axarr[1].imshow(out_grid)
    25. axarr[1].set_title('Transformed Images')
    26. for epoch in range(1, 20 + 1):
    27. train(epoch)
    28. test()
    29. # Visualize the STN transformation on some input batch
    30. visualize_stn()
    31. plt.ioff()

    输出: