cond

    如果 predTrue ,该API返回 true_fn() ,否则返回 false_fn() 。 用户如果不想在 callable 中做任何事,可以把 true_fnfalse_fn 设为 None ,此时本API会把该 callable 视为简单返回 None

    true_fnfalse_fn 需要返回同样嵌套结构(nest structure)的Tensor,如果不想返回任何值也可都返回 None 。 PaddlePaddle里Tensor的嵌套结构是指一个Tensor,或者Tensor的元组(tuple),或者Tensor的列表(list)。

    注解

    1. 因为PaddlePaddle的静态图数据流, true_fnfalse_fn 返回的元组必须形状相同,但是里面的Tensor形状可以不同。

    2. 不论运行哪个分支,在 true_fnfalse_fn 外创建的Tensor和Op都会被运行,即PaddlePaddle并不是惰性语法(lazy semantics)。例如

    参数:

    • pred (Tensor) - 一个形状为[1]的布尔型(boolean)的Tensor,该布尔值决定要返回 true_fn 还是 false_fn 的运行结果。

    • true_fn (callable) - 一个当 pred 是 时被调用的callable,默认值: None

    • false_fn (callable) - 一个当 predFalse 时被调用的callable,默认值: None

    • name (str,可选) – 具体用法请参见 ,一般无需设置,默认值: None

    如果 predTrue ,该API返回 true_fn() ,否则返回 false_fn()

    返回类型:Tensor|list(Tensor)|tuple(Tensor)

    抛出异常:

    • TypeError - 如果 true_fnfalse_fn 不是callable。

    • ValueError - 如果 true_fnfalse_fn 没有返回同样的嵌套结构(nest structure),对嵌套结构的解释见上文。

    1. import paddle
    2. #
    3. # pseudocode:
    4. # return 1, True
    5. # else:
    6. #
    7. def true_func():
    8. return paddle.fill_constant(shape=[1, 2], dtype='int32',
    9. value=1), paddle.fill_constant(shape=[2, 3],
    10. dtype='bool',
    11. value=True)
    12. return paddle.fill_constant(shape=[3, 4], dtype='float32',
    13. value=3), paddle.fill_constant(shape=[4, 5],
    14. dtype='int64',
    15. value=2)
    16. x = paddle.fill_constant(shape=[1], dtype='float32', value=0.1)
    17. y = paddle.fill_constant(shape=[1], dtype='float32', value=0.23)
    18. pred = paddle.less_than(x=x, y=y, name=None)
    19. ret = paddle.nn.cond(pred, true_func, false_func)
    20. # ret is a tuple containing 2 tensors
    21. # ret[0] = [[1 1]]
    22. # ret[1] = [[ True True True]