CTCLoss
该接口用于计算 CTC loss。该接口的底层调用了第三方 baidu-research::warp-ctc 的实现。 也可以叫做 softmax with CTC,因为 Warp-CTC 库中插入了 softmax 激活函数来对输入的值进行归一化。
log_probs (Tensor): - 经过 padding 的概率序列,其 shape 必须是 [max_logit_length, batch_size, num_classes + 1]。其中 max_logit_length 是最长输入序列的长度。该输入不需要经过 softmax 操作,因为该 OP 的内部对 input 做了 softmax 操作。数据类型仅支持float32。
labels (Tensor): - 经过 padding 的标签序列,其 shape 为 [batch_size, max_label_length],其中 max_label_length 是最长的 label 序列的长度。数据类型支持int32。
label_lengths (Tensor): - 表示 label 中每个序列的长度,shape为 [batch_size] 。数据类型支持int64。
Tensor
,输入 log_probs
和标签 labels
间的 ctc loss。如果 是 'none'
,则输出 loss 的维度为 [batch_size]。如果 reduction
是 'mean'
或 'sum'
, 则输出Loss的维度为 [1]。数据类型与输入 log_probs
一致。