scatter
scatter
(input, index, updates, name=None, overwrite=True)[源代码]
- 参数:
- input (Variable) - 支持任意纬度的Tensor。支持的数据类型为float32。
- index (Variable) - 表示索引,仅支持1-D Tensor。 支持的数据类型为int32,int64。
- updates (Variable) - 根据索引的值将updates Tensor中的对应值更新到input Tensor中,updates Tensor的维度需要和input tensor保持一致,且除了第一维外的其他的维度的大小需要和input Tensor保持相同。支持的数据类型为float32。
- name (str,可选) - 具体用法请参见 ,一般无需设置,默认值为None。
- overwrite (bool,可选) - 如果index中的索引值有重复且overwrite 为True,旧更新值将被新的更新值覆盖;如果为False,新的更新值将同旧的更新值相加。默认值为True。
- import paddle.fluid as fluid
-
- input = fluid.layers.data(name='data', shape=[3, 2], dtype='float32', append_batch_size=False)
- index = fluid.layers.data(name='index', shape=[4], dtype='int64', append_batch_size=False)
- updates = fluid.layers.data(name='update', shape=[4, 2], dtype='float32', append_batch_size=False)
-
- output = fluid.layers.scatter(input, index, updates, overwrite=False)
-
- exe = fluid.Executor(fluid.CPUPlace())
-
- in_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float32)
- index_data = np.array([2, 1, 0, 1]).astype(np.int64)
- update_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.float32)
-
- res = exe.run(fluid.default_main_program(), feed={'data':in_data, "index":index_data, "update":update_data}, fetch_list=[output])
- print(res)
- # [array([[3., 3.],
- # [6., 6.],