BaseTransform

    视觉中图像变换的基类。

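    调用逻辑：

    .. code-block:: text

        if keys is None:
            _get_params -> _apply_image()
        else:
            _get_params -> _apply_*() for * in keys

    如果你想要定义自己的图像变换方法，需要在子类中重写对应的 _apply_* 方法（例如 _apply_image、_apply_boxes、_apply_mask）。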

    代码示例：

    .. code-block:: python

        import numpy as np
        from PIL import Image
        import paddle.vision.transforms.functional as F
        from paddle.vision.transforms import BaseTransform

        def _get_image_size(img):
            if F._is_pil_image(img):
                return img.size
            elif F._is_numpy_image(img):
                return img.shape[:2][::-1]
            else:
                raise TypeError("Unexpected type {}".format(type(img)))

        class CustomRandomFlip(BaseTransform):
            def __init__(self, prob=0.5, keys=None):
                super(CustomRandomFlip, self).__init__(keys)
                self.prob = prob

            def _get_params(self, inputs):
                image = inputs[self.keys.index('image')]
                params = {}
                params['flip'] = np.random.random() < self.prob
                params['size'] = _get_image_size(image)
                return params

            def _apply_image(self, image):
                if self.params['flip']:
                    return F.hflip(image)
                return image

            # if you only want to transform image, do not need to rewrite this function
            def _apply_coords(self, coords):
                if self.params['flip']:
                    w = self.params['size'][0]
                    coords[:, 0] = w - coords[:, 0]
                return coords

            # if you only want to transform image, do not need to rewrite this function
            def _apply_boxes(self, boxes):
                # expand each (x1, y1, x2, y2) box into its 4 corner points
                idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
                coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
                # transform the corners, then regroup them per box: (N, 4, 2)
                coords = self._apply_coords(coords).reshape((-1, 4, 2))
                minxy = coords.min(axis=1)
                maxxy = coords.max(axis=1)
                trans_boxes = np.concatenate((minxy, maxxy), axis=1)
                return trans_boxes

            # if you only want to transform image, do not need to rewrite this function
            def _apply_mask(self, mask):
                if self.params['flip']:
                    return F.hflip(mask)
                return mask

        # create fake inputs
        fake_img = Image.fromarray((np.random.rand(400, 500, 3) * 255.).astype('uint8'))
        fake_boxes = np.array([[2, 3, 200, 300], [50, 60, 80, 100]])
        fake_mask = fake_img.convert('L')

        # only transform for image:
        flip_transform = CustomRandomFlip(1.0)
        converted_img = flip_transform(fake_img)

        # transform for image, boxes and mask
        flip_transform = CustomRandomFlip(1.0, keys=('image', 'boxes', 'mask'))
        (converted_img, converted_boxes, converted_mask) = flip_transform((fake_img, fake_boxes, fake_mask))
        print('converted boxes', converted_boxes)
    45. fake_boxes = np.array([[2, 3, 200, 300], [50, 60, 80, 100]])
    46. fake_mask = fake_img.convert('L')
    47. # only transform for image:
    48. flip_transform = CustomRandomFlip(1.0)
    49. converted_img = flip_transform(fake_img)
    50. # transform for image, boxes and mask
    51. flip_transform = CustomRandomFlip(1.0, keys=('image', 'boxes', 'mask'))
    52. print('converted boxes', converted_boxes)