自定义Task

    当自定义一个Task时,我们并不需要重新实现eval、finetune等通用接口。一般来讲,新的Task与其他Task的区别在于

    • 网络结构

    • 评估指标

    这两者的差异可以通过重载BasicTask的组网事件和运行事件来实现

    进行前向网络组网的函数,用户需要自定义实现该函数,函数需要返回对应预测结果的Variable list

    _add_label

    添加label的函数,用户需要自定义实现该函数,函数需要返回对应输入label的Variable list

    1. # 代码示例
    2. return [fluid.layers.data(name="label", dtype="int64", shape=[1])]

    _add_metrics

    添加度量指标的函数,用户需要自定义实现该函数,函数需要返回对应度量指标的Variable list

    1. # 代码示例
    2. def _add_metrics(self):
    3. return [fluid.layers.accuracy(input=self.outputs[0], label=self.label)]

    运行事件

    BasicTask定义了一系列的运行时回调事件,在特定的时机时触发对应的事件,在自定的Task中,通过重载实现对应的回调函数,用户可以实现所需的功能

    _build_env_start_event

    当需要进行一个新的运行环境构建时,该事件被触发。通过重载实现该函数,用户可以在一个环境开始构建前进行对应操作,例如写日志

    1. # 代码示例
    2. logger.info("Start to build env {}".format(self.phase))

    _build_env_end_event

    当一个新的运行环境构建完成时,该事件被触发。通过继承实现该函数,用户可以在一个环境构建结束后进行对应操作,例如写日志

    1. # 代码示例
    2. def _build_env_end_event(self):
    3. logger.info("End of build env {}".format(self.phase))

    当开始一次finetune时,该事件被触发。通过继承实现该函数,用户可以在开始一次finetune操作前进行对应操作,例如写日志

    _finetune_end_event

    1. # 代码示例
    2. def _finetune_end_event(self):
    3. logger.info("PaddleHub finetune finished.")

    _eval_start_event

    当开始一次evaluate时,该事件被触发。通过继承实现该函数,用户可以在开始一次evaluate操作前进行对应操作,例如写日志

    1. # 代码示例
    2. def _eval_start_event(self):
    3. logger.info("Evaluation on {} dataset start".format(self.phase))

    _eval_end_event

    当结束一次evaluate时,该事件被触发。通过继承实现该函数,用户可以在完成一次evaluate操作后进行对应操作,例如计算运行速度、评估指标等

    1. # 代码示例
    2. def _eval_end_event(self, run_states):
    3. run_step = 0
    4. for run_state in run_states:
    5. run_step += run_state.run_step
    6. run_time_used = time.time() - run_states[0].run_time_begin
    7. logger.info("[%s dataset evaluation result] [step/sec: %.2f]" %
    8. (self.phase, run_speed))
    • : 一个list对象,list中的每一个元素都是RunState对象,该list包含了整个评估过程的状态数据。

    _predict_start_event

    当开始一次predict时,该事件被触发。通过继承实现该函数,用户可以在开始一次predict操作前进行对应操作,例如写日志

    1. # 代码示例
    2. def _predict_start_event(self):
    3. logger.info("PaddleHub predict start")

    当结束一次predict时,该事件被触发。通过继承实现该函数,用户可以在结束一次predict操作后进行对应操作,例如写日志

    _log_interval_event

    调用finetune 或者 finetune_and_eval接口时,每当命中用户设置的日志打印周期时()。通过继承实现该函数,用户可以在finetune过程中定期打印所需数据,例如计算运行速度、loss、准确率等

    1. # 代码示例
    2. def _log_interval_event(self, run_states):
    3. avg_loss, avg_acc, run_speed = self._calculate_metrics(run_states)
    4. self.env.loss_scalar.add_record(self.current_step, avg_loss)
    5. self.env.acc_scalar.add_record(self.current_step, avg_acc)
    6. logger.info("step %d: loss=%.5f acc=%.5f [step/sec: %.2f]" %
    7. (self.current_step, avg_loss, avg_acc, run_speed))
    • run_states: 一个list对象,list中的每一个元素都是RunState对象,该list包含了整个从上一次该事件被触发到本次被触发的状态数据

    _save_ckpt_interval_event

    调用finetune 或者 finetune_and_eval接口时,每当命中用户设置的保存周期时(),该事件被触发。通过继承实现该函数,用户可以在定期保存checkpoint

    1. # 代码示例
    2. def _save_ckpt_interval_event(self):
    3. self.save_checkpoint(self.current_epoch, self.current_step)

    _eval_interval_event

    调用finetune_and_eval接口时,每当命中用户设置的评估周期时(),该事件被触发。通过继承实现该函数,用户可以实现自定义的评估指标计算

    1. # 代码示例
    2. def _eval_interval_event(self):
    3. self.eval(phase="dev")

    _run_step_event

    1. # 代码示例
    2. def _run_step_event(self, run_state):