WMT16

    import paddle
    from paddle.text.datasets import WMT16


    class SimpleNet(paddle.nn.Layer):
        """Toy layer that reduces each of its three inputs to a scalar sum."""

        def __init__(self):
            super().__init__()

        def forward(self, src_ids, trg_ids, trg_ids_next):
            """Return the sums of the three id tensors as a 3-tuple."""
            return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)


    # Training split, with the source and target dictionary sizes set to 50.
    wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)

    # The layer holds no per-sample state, so one instance serves every sample.
    model = SimpleNet()

    # Push the first few samples through the model.
    for i in range(10):
        # Indexing the dataset is expected to yield the triple
        # (src_ids, trg_ids, trg_ids_next) -- see the WMT16 API reference.
        src_ids, trg_ids, trg_ids_next = wmt16[i]

        # Convert the raw id sequences to tensors before calling the layer.
        src_ids = paddle.to_tensor(src_ids)
        trg_ids = paddle.to_tensor(trg_ids)
        trg_ids_next = paddle.to_tensor(trg_ids_next)

        src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
        print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())