新增Pass
CxxPredictor加载模型后,在执行预测前会先优化模型。模型优化过程是通过Pass实现的。 具体调用关系如下: 图片
CreatePredictor(CxxConfig)
函数调用了Predictor->Build(CxxConfig)- CxxPredictor的构建过程(Build)分为两步:
- Predictor->LoadModel() 加载模型文件到program中
- Predicotr->optimizer_.Run() 对Program中的原始图形结构进行优化
- 对图结构的优化是通过调用
Pass->Apply(const std::unique_ptr<SSAGraph>& graph)
方法实现的。
- 对图结构的优化是通过调用
- CxxPredictor的构建过程(Build)分为两步:
每一类Pass定义了一种优化过程,包括:原模型中的kernel选取、OP融合、冗余OP去除、子图创建、内存优化、类型推导、类型转换等。
代码位置:lite/core/mir/pass.h
主要类成员: const Kind kind_
: Pass类型。pass 有三种基本基本类型 :修改图结构的ProgramPass
、修改状态量的StmtPass
和Debug过程采集信息与控制可视化的DebugPass
。std::string name_
:pass 的名称 std::set<TargetType> bound_targets_
: Pass运行的硬件平台,optimizer.Run()优化过程会根据硬件平台选择匹配的Pass。———根据硬件平台自动选择需要的pass std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_
: Pass 绑定的kernel (what’s this used for) 主要接口: Pass::Apply(const std::unique_ptr& graph)
: Pass优化过程的具体操作,是新注册Pass需要实现的接口。输入为SSAGraph
型指针,是对模型结构的拓扑表示。
2、Pass管理 paddle::lite::mir::PassManager
class PassManager {
public:
// 内部静态变量PassManager,用来存储使用的Pass和图优化操作
static PassManager& Global() {
static PassManager x;
return x;
}
// 执行所有的 Pass
void Run(const std::unique_ptr<SSAGraph>& graph) {
for (auto& pass : passes_) {
LOG(INFO) << "Running MIR pass " << pass->name();
pass->Apply(graph);
}
private:
std::list<std::unique_ptr> passes_; //存储所有的 Pass
std::map<std::string, mir::Pass*> pass_map_; //使用map变量存储 PassName::Pass
}
代码位置:lite/core/mir/pass_manager.h
主要类成员: std::list:unique_ptr> passes_;
: List类型,存储了所有已注册Pass。 std::map<std::string, mir::Pass*> pass_map_;
: Map类型,存储了所有”Pass名称-Pass类”键对,用于根据名称查找Pass。
主要接口: static PassManager& Global()
返回PassManager全局静态变量,该变量存储了所有已注册的Pass bool AddNewPass(const std::string& name, Pass* pass)
添加新的Pass到PassManager中
代码位置:lite/core/mir/pass_registry.h
主要接口: REGISTER_MIR_PASS(name__, class__)
:宏定义函数,用于注册Pass。注册Pass过程实现的是 PassManager::Global().AddNewPass(name__, class__)
,将新注册Pass添加到全局变量PassManager
中。
1. Pass 注册流程
在lite/core/mir
或其子目录下继承Pass基类
,实现Pass::Apply
接口,并使用宏REGISTER_MIR_PASS(name__, class__)
将Pass注册到PassManager
即完成了新Pass注册。
**以新建 **new_demo_pass
为例,具体流程如下: (1)在lite/core/mir
路径下新建example_pass.cc
和 new_demo_pass.h
文件 (2)在example_pass.h
文件中继承Pass基类(ProgramPass、StmtPass或DebugPass)定义自己的Pass类。
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ExamplePass : public ProgramPass {
void Apply(const std::unique_ptr<SSAGraph> &graph) override {}
...
};
} // namespace mir
} // namespace lite
} // namespace paddle
(3)在example_pass.cc
文件中实现ExamplePass::Apply()
接口,并注册ExamplePass
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/example_pass.h"
namespace paddle {
namespace lite {
...
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(example_pass, paddle::lite::mir::ExamplePass)
.BindTargets({TARGET(kARM)}); // Pass执行的目标硬件平台
// .BindKernel("conv2d"); //Pass绑定的 kernel
lite_cc_library(mir_passes
SRCS
demo_pass.cc // 新建的Pass文件
...
memory_optimize_pass.cc
DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
将Pass注册到PassManager后不会自动生效。需要在optimizer->run()
函数中添加该Pass才会在模型优化过程中调用。 (1)在paddle_use_passes.h
文件中调用该Pass
(2)要想在优化模型时调用该Pass,需要在optimizer->run()
函数中手动添加调用。
修改lite/core/optimizer.h
文件,添加new_demo_pass
到Optimizer::Run()
函数;
class Optimizer {
public:
void Run(...) {
...
if (passes.empty()) {
RunPasses(std::vector<std::string>{
{"new_demo_pass" //将新注册的Pass添加在这里
...
}
...
}
(3)只有CxxPredictor才会在模型加载后根据Pass优化模型。
...
#include "paddle_use_passes.h" // 引用Pass优化模型
void RunModel() {
// 1. 创建 CxxConfig
CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places(Place{TARGET(kARM), PRECISION(kFloat)});
// 2. 创建CxxPredictor,该过程包括加载模型和用Pass优化模型
std::shared_ptr> predictor =
Creat<CxxConfig>(config);
}
Fusion Pass
是一种常见图结构优化Pass,可将多个连续OP融合成单个等效OP,减少数据交换并简化图结构。Pass运行时调用Fuser
自动查找并替换指定图结构,所以注册FuserPass
时还需要实现对应的Fuser类。
下面以fc_fuse_pass
为例,详细说明FusionPass
的效果和注册方法。
fc_fuse_pass
的作用
将相邻的mul
算子和 element_wise add
算子 融合成一个 FC
算子
mul(X) = X * W
elementwise_add( mul(x) ) = X * W + Bias
//----------> after fusion
FC(X) = X * W +Bias
Pass 运行效果如下: 图片 mul和elementwise_add的原有参数映射到FC的参数上: 图片
1、创建FcFuser
(1)在lite/core/mir/fusion
路径下新建fc_fuser.cc
和 fc_fuser.h
文件 (2)在fc_fuser.h
文件中继承FuseBase
定义自己的Fuser类。
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class FcFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
} // namespace fusion
} // namespace mir
} // namespace paddle
对于 FcFuser
:BuildPattern描述的Pattern是mul+elementwise add
,GenOpDesc创建的FC_op,InsertNewNode函数的效果是用新建的FC_op
替换模型中的mul+elementwise add
pattern。
(3) 在fc_fuser.cc
文件中实现 BuildPattern()
、GenOpDesc()
、InsertNewNode()
接口
下面以FcFuser为例介绍三种接口的实现:
2、注册fc_fuse_pass
(1)在lite/core/mir/fusion
路径下新建fc_fuse_pass.cc
和 fc_fuse_pass.h
文件 (2)在fc_fuse_pass.h
文件中,继承ProgramPass
定义FcFusePass
。
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class FcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override; namespace mir namespace lite namespace paddle
(3)在fc_fuse_pass.cc
文件中实现FcFusePass::Apply()
接口,并注册FcFusePass
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/example_pass.h"
namespace paddle {
namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FcFuser fuser;
fuser(graph.get());namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)}) // FcFusePass 可以在任何硬件平台执行
.BindKernel("fc"); // FcFusePass 绑定 fc_kernel
(4)修改lite/core/mir/fusion/CMakeLists.txt
文件,将fc_fuser.cc
编译到mir_fusers
库
lite_cc_library(fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
...
CACHE INTERNAL "fusers")
(5)修改lite/core/mir/CMakeLists.txt
文件,将fc_fuse_pass.cc
编译到mir_pass
库
lite_cc_library(mir_passes
SRCS
fusion/fc_fuse_pass.cc
...
DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
3、使用 fc_fuse_pass
(1) lite/api/paddle_use_passes.h
使用USE_LITE_PASS
宏来引入新加入的pass
(2) 在lite/core/optimizer.h
文件的Optimizer::Run()
函数中添加新注册的pass
class Optimizer {
public:
void Run(Program&& program,
const std::vector<Place>& valid_places,
core::KernelPickFactor kernel_pick_factor,
const std::vector<std::string>& passes = {}) {
...
if (passes.empty()) {
RunPasses(std::vector<std::string>{
{"lite_fc_fuse_pass", // the newly registered pass
...
"argument_type_display_pass"}});
} else {
RunPasses(passes);
}
exec_scope_ = program.exec_scope();