center_loss

    • center_loss(input, label, num_classes, alpha, param_attr, update_center=True)[源代码]

    对于输入,(X)和标签(Y),计算公式为:

    • label (Variable) - 输入的标签,一个形状为为[N x 1]的2维张量,N表示batch size,数据类型为int32。
    • num_class (int32) - 输入类别的数量。
    • param_attr (ParamAttr) - 指定权重参数属性的对象。具体用法请参见 。
    • update_center (bool) - 是否更新类别中心的参数。

    返回:形状为[N x 1]的2维Tensor|LoDTensor。

    代码示例