ExponentialMovingAverage
- class
ExponentialMovingAverage
(decay=0.999, thres_steps=None, name=None)[源代码]
,它的指数滑动平均值 (exponential moving average, EMA) 为
用 update()
方法计算出的平均结果将保存在由实例化对象创建和维护的临时变量中,并且可以通过调用 apply()
方法把结果应用于当前模型的参数。同时,可用 方法恢复原始参数。
,因此它们相对于零是有偏的,可以通过除以因子
来校正,因此在调用
apply()
方法时,作用于参数的真实滑动平均值将为:
衰减率调节 一个非常接近于1的很大的衰减率将会导致平均值滑动得很慢。更优的策略是,开始时设置一个相对较小的衰减率。参数 thres_steps
允许用户传递一个变量以设置衰减率,在这种情况下, 真实的衰减率变为 :
通常 thres_steps
可以是全局的训练迭代步数。
- 参数:
- decay (float) – 指数衰减率,通常接近1,如0.999,0.9999,……
- thres_steps (Variable, 可选) – 调节衰减率的阈值步数,默认值为 None。
代码示例
- ()
更新指数滑动平均,在训练过程中需调用此方法。
apply
(executor, need_restore=True)
- 参数:
- need_restore (bool) –是否在结束后恢复原始参数,默认值为
True
。
- need_restore (bool) –是否在结束后恢复原始参数,默认值为
restore
(executor)
恢复参数。
- 参数:
- executor (Executor) – 执行恢复动作的执行器。