split

    切分指定操作的参数到多个设备,并且并行计算得到结果。

    当前,支持一下三种情形。

    情形1:并行Embedding

    Embedding操作的参数是个NxM的矩阵,行数为N,列数为M。并行Embedding情形下,参数切分到num_partitions个设备,每个设备上的参数是 (N/num_partitions + 1)行、M列的矩阵。其中,最后一行作为padding idx。

    假设将NxM的参数矩阵切分到两个设备device_0和device_1。那么每个设置上的参数矩阵为(N/2+1)行和M列。device_0上,输入x中的值如果介于[0, N/2-1],则其值保持不变;否则值变更为N/2,经过embedding映射为全0值。类似地,device_1上,输入x中的值V如果介于[N/2, N-1]之间,那么这些值将变更为(V-N/2);否则,值变更为N/2,经过embedding映射为全0值。最后,使用all_reduce_sum操作汇聚各个卡上的结果。

    Linear操作的参数是个NxM的矩阵,行数为N,列数为M。行并行Linear情形下,参数切分到num_partitions个设备,每个设备上的参数是N/num_partitions行、M列的矩阵。

    情形3:列并行Linear

    Linear操作的参数是个NxM的矩阵,行数为N,列数为M。列并行Linear情形下,参数切分到num_partitions个设备,每个设备上的参数是N行、M/num_partitions列的矩阵。

    Tensor