模型组网
飞桨框架2.0中,组网相关的API都在paddle.nn
目录下,你可以通过 Sequential
或 SubClass
的方式构建具体的模型。组网相关的API类别与具体的API列表如下表:
针对顺序的线性网络结构你可以直接使用Sequential来快速完成组网,可以减少类的定义等代码编写。具体代码如下:
针对一些比较复杂的网络结构,就可以使用Layer子类定义的方式来进行模型代码编写,在__init__
构造函数中进行组网Layer的声明,在forward
中使用声明的Layer变量进行前向计算。子类组网方式也可以实现sublayer的复用,针对相同的layer可以在构造函数中一次性定义,在forward中多次调用。
# Layer类继承方式组网
class Mnist(paddle.nn.Layer):
def __init__(self):
self.flatten = paddle.nn.Flatten()
self.linear_1 = paddle.nn.Linear(784, 512)
self.linear_2 = paddle.nn.Linear(512, 10)
self.relu = paddle.nn.ReLU()
self.dropout = paddle.nn.Dropout(0.2)
def forward(self, inputs):
y = self.flatten(inputs)
y = self.dropout(y)
y = self.linear_2(y)
return y
mnist_2 = Mnist()
你除了可以通过上述方式组建模型外,还可以使用飞桨框架内置的模型,路径为 paddle.vision.models
,具体列表如下:
飞桨框架内置模型: ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet']
使用方式如下:
你可以通过paddle.summary()
方法查看模型的结构与每一层输入输出形状,具体如下: