自定义指标

    有时你会遇到这样的情况：特定任务所需的Loss计算方式在框架已有的Loss接口中并不存在，或者现有实现不符合自己的需求，这时就需要自定义Loss。自定义Loss的做法是编写一个继承`paddle.nn.Layer`的类，在`forward`方法中根据模型输出`input`和标签`label`计算并返回损失值。首先来看下面的代码：

    import paddle
    import paddle.nn.functional as F


    class SoftmaxWithCrossEntropy(paddle.nn.Layer):
        """Softmax + cross-entropy loss, averaged over the batch.

        ``forward`` applies softmax along ``axis=1`` of ``input``, computes the
        cross-entropy against ``label``, and reduces the per-sample losses to a
        single value with ``paddle.mean``.
        """

        def __init__(self):
            super(SoftmaxWithCrossEntropy, self).__init__()

        def forward(self, input, label):
            # NOTE(review): assumes input is (batch, num_classes) logits and
            # label holds integer class ids -- confirm against the dataset used.
            loss = F.softmax_with_cross_entropy(input,
                                                label,
                                                return_softmax=False,
                                                axis=1)
            # Average the per-sample losses so the optimizer sees one scalar.
            return paddle.mean(loss)

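    和Loss一样，你也可以通过框架实现自定义的评估指标（Metric）：继承`paddle.metric.Metric`基类，并实现`update`（根据每个batch的预测结果更新内部状态）、`reset`（重置状态）、`accumulate`（根据累积的状态计算最终指标）和`name`（返回指标名称）等方法。下面以二分类的精确率（Precision）为例，具体实现如下：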

    import numpy as np
    import paddle
    from paddle.metric import Metric


    class Precision(Metric):
        """
        Precision (also called positive predictive value) is the fraction of
        relevant instances among the retrieved instances. Refer to
        https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers

        Note that this class manages the precision score only for binary
        classification tasks.

        It counts true positives (``tp``) and false positives (``fp``) across
        ``update`` calls; ``accumulate`` returns ``tp / (tp + fp)``.
        """

        def __init__(self, name='precision', *args, **kwargs):
            super(Precision, self).__init__(*args, **kwargs)
            self.tp = 0  # true positive
            self.fp = 0  # false positive
            self._name = name

        def update(self, preds, labels):
            """
            Update the states based on the current mini-batch prediction results.

            Args:
                preds (numpy.ndarray|paddle.Tensor): The prediction result,
                    usually the output of a two-class sigmoid function. It
                    should be a vector of shape (N,) or (N, 1) with data type
                    'float64' or 'float32'.
                labels (numpy.ndarray|paddle.Tensor): The ground truth
                    (labels); the shape should keep the same as preds.
                    The data type is 'int32' or 'int64'.

            Raises:
                ValueError: If preds or labels is neither a numpy.ndarray
                    nor a paddle.Tensor.
            """
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            elif not isinstance(preds, np.ndarray):
                raise ValueError("The 'preds' must be a numpy ndarray or Tensor.")
            if isinstance(labels, paddle.Tensor):
                labels = labels.numpy()
            elif not isinstance(labels, np.ndarray):
                raise ValueError("The 'labels' must be a numpy ndarray or Tensor.")

            # Round probabilities to hard 0/1 predictions (>= 0.5 becomes 1).
            preds = np.floor(preds + 0.5).astype("int32")
            for pred, label in zip(preds, labels):
                # Only samples predicted positive affect precision.
                if pred == 1:
                    if pred == label:
                        self.tp += 1
                    else:
                        self.fp += 1

        def reset(self):
            """
            Resets all of the metric state.
            """
            self.tp = 0
            self.fp = 0

        def accumulate(self):
            """
            Calculate the precision from the accumulated counts.

            Returns:
                A scalar float: the calculated precision, or 0.0 when no
                sample has been predicted positive yet.
            """
            ap = self.tp + self.fp
            return float(self.tp) / ap if ap != 0 else .0

        def name(self):
            """
            Returns metric name
            """
            return self._name

    `fit`接口的`callbacks`参数支持传入`Callback`类实例（或由其组成的列表），这些回调会在每轮（epoch）训练以及每个batch训练的前后被调用。通过callback可以收集训练过程中的一些数据和参数，或者实现一些自定义操作。下面以保存模型检查点的`ModelCheckpoint`为例：

    1. class ModelCheckpoint(Callback):
    2. def __init__(self, save_freq=1, save_dir=None):
    3. self.save_freq = save_freq
    4. self.save_dir = save_dir
    5. def on_epoch_begin(self, epoch=None, logs=None):
    6. self.epoch = epoch
    7. def _is_save(self):
    8. return self.model and self.save_dir and ParallelEnv().local_rank == 0
    9. def on_epoch_end(self, epoch, logs=None):
    10. if self._is_save() and self.epoch % self.save_freq == 0:
    11. path = '{}/{}'.format(self.save_dir, epoch)
    12. print('save checkpoint at {}'.format(os.path.abspath(path)))
    13. self.model.save(path)
    14. def on_train_end(self, logs=None):
    15. if self._is_save():
    16. path = '{}/final'.format(self.save_dir)
    17. print('save checkpoint at {}'.format(os.path.abspath(path)))