自定义层需要继承 tf.keras.layers.Layer
类,并重写 init
、 build
和 call
三个方法,如下所示:
例如,如果我们要自己实现一个 中的全连接层( ),可以按如下方式编写。此代码在 build
方法中创建两个变量,并在 call
方法中使用创建的变量进行运算:
自定义损失函数和评估指标
自定义损失函数需要继承 tf.keras.losses.Loss
类,重写 call
方法即可,输入真实值 和模型预测值 y_pred
,输出模型预测值和真实值之间通过自定义的损失函数计算出的损失值。下面的示例为均方差损失函数:
自定义评估指标需要继承 tf.keras.metrics.Metric
类,并重写 init
、 update_state
和 result
三个方法。下面的示例对前面用到的 评估指标类做了一个简单的重实现:
- LeCun, L. Bottou, Y. Bengio, and P. Haffner. “Gradient-based learning applied to document recognition.” Proceedings of the IEEE, 86(11):2278-2324, November 1998.
Graves, Alex. “Generating Sequences With Recurrent Neural Networks.” ArXiv:1308.0850 [Cs], August 4, 2013. http://arxiv.org/abs/1308.0850.