SampleEmbeddingHelper

    SampleEmbeddingHelper是 的子类。作为解码helper,它通过采样而非使用 argmax 并将采样结果送入embedding层,以此作为下一解码步的输入。

    参数:

    • embedding_fn (callable) - 作用于 argmax 结果的函数,通常是一个将词id转换为词嵌入的embedding层,注意 ,这里要使用 cn_api_fluid_embedding 而非 embedding,因为选中的id的形状是

      ,如果使用后者则还需要在这里提供unsqueeze。

    • start_tokens (Variable) - 形状为

      、数据类型为int64、 值为起始标记id的tensor。

    • softmax_temperature (float,可选) - 该值用于在softmax计算前除以logits。温度越高(大于1.0)随机性越大,温度越低则越趋向于argmax。该值必须大于0,默认值None等同于1.0。

    • seed (int,可选) - 采样使用的随机种子。默认为None,表示不使用固定的随机种子。

    示例代码

    ( time, outputs, states )

    参数:

    • outputs (Variable) - tensor变量,通常其数据类型为float32或float64,形状为 [batch_size,vocabulary_size][batch_size,vocabulary_size] ,表示当前解码步预测产生的logit(未归一化的概率),和由 BasicDecoder.output_fn(BasicDecoder.cell.call()) 返回的 是同一内容。

    • states (Variable) - 单个tensor变量或tensor变量组成的嵌套结构,和由 BasicDecoder.cell.call() 返回的 new_states 是同一内容。

    返回:数据类型为int64形状为 [batch_size][batch_size] 的tensor,表示采样得到的id。

    返回类型:Variable