6-5, Training Models with a TPU

    In a Colab notebook, select TPU under Edit -> Notebook settings -> Hardware accelerator.

    You can try this out in the Colab notebook "tf_TPU":

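    import tensorflow as tf
    from tensorflow.keras import datasets, layers, losses, metrics, models, optimizers, preprocessing

    MAX_LEN = 300   # padded sequence length (the Embedding output shape is (None, 300, 7))
    BATCH_SIZE = 32

    (x_train,y_train),(x_test,y_test) = datasets.reuters.load_data()
    x_train = preprocessing.sequence.pad_sequences(x_train,maxlen=MAX_LEN)
    x_test = preprocessing.sequence.pad_sequences(x_test,maxlen=MAX_LEN)

    MAX_WORDS = x_train.max()+1   # vocabulary size for the Embedding layer
    CAT_NUM = y_train.max()+1     # number of topic classes

    # Cache the examples once, then reshuffle every epoch and prefetch batches
    ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
              .cache() \
              .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
              .prefetch(tf.data.experimental.AUTOTUNE)

    # Evaluation data does not need shuffling
    ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test)) \
              .batch(BATCH_SIZE).cache() \
              .prefetch(tf.data.experimental.AUTOTUNE)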
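    The step counts in the training log ("Train for 281 steps, validate for 71 steps") follow from the dataset sizes and BATCH_SIZE. This is a minimal sketch. It assumes `load_data()` is called with its default arguments, which split the 11,228 Reuters newswires into 8,982 training and 2,246 test examples. By default `batch()` keeps the final partial batch, so one pass over the data takes ceil(size / batch size) steps.

```python
import math

# Assumed default Reuters split from tf.keras.datasets.reuters.load_data()
NUM_TRAIN = 8982
NUM_TEST = 2246
BATCH_SIZE = 32

def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches produced by Dataset.batch(batch_size) with drop_remainder=False."""
    return math.ceil(num_examples / batch_size)

train_steps = steps_per_epoch(NUM_TRAIN, BATCH_SIZE)
val_steps = steps_per_epoch(NUM_TEST, BATCH_SIZE)
print(train_steps, val_steps)  # 281 71
```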

    tf.keras.backend.clear_session()

    def create_model():
        model = models.Sequential()
        model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
        model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        model.add(layers.Dense(CAT_NUM,activation = "softmax"))
        return model

    def compile_model(model):
        # The last layer already applies softmax, so its outputs are probabilities, not logits
        model.compile(optimizer=optimizers.Nadam(),
                      loss=losses.SparseCategoricalCrossentropy(from_logits=False),
                      metrics=[metrics.SparseCategoricalAccuracy(),metrics.SparseTopKCategoricalAccuracy(5)])
        return model
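    Pairing a `softmax` output layer with `SparseCategoricalCrossentropy(from_logits=True)` makes the loss apply softmax a second time to values that are already probabilities. The plain-Python sketch below uses a hypothetical single example with 46 classes to show the effect. For a perfectly confident, correct prediction, the proper cross-entropy is 0. The double-softmax version can never fall below ln(45 + e) - 1, which is about 2.87.

    This matches the training log. There the loss only moves from about 3.45 to 3.26, while accuracy rises to about 0.61. The accuracy metrics depend only on the argmax, and a second softmax does not change it.

```python
import math

NUM_CLASSES = 46  # CAT_NUM: the Dense layer outputs (None, 46)

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_ce(probs, label):
    """Sparse categorical cross-entropy for one example, given class probabilities."""
    return math.log(1.0 / probs[label])

# Hypothetical example: a perfectly confident, correct softmax output
label = 3
probs = [0.0] * NUM_CLASSES
probs[label] = 1.0

# from_logits=False: the loss uses the probabilities directly
loss_probs = sparse_ce(probs, label)

# from_logits=True: the probabilities are treated as logits and softmaxed again
double = softmax(probs)
loss_double = sparse_ce(double, label)

# The second softmax keeps the argmax, so accuracy metrics are unaffected
assert double.index(max(double)) == label

print(loss_probs, round(loss_double, 3))
```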
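
    # Add the following 6 lines of code
    import os
    # Colab's TPU runtime exposes the TPU address in the COLAB_TPU_ADDR environment variable
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    with strategy.scope():
        # Create and compile the model inside the scope so its variables are replicated across the TPU cores
        model = create_model()
        model.summary()
        model = compile_model(model)

    # Train for 10 epochs (in Colab, start this cell with %%time to get the timing at the end of the log)
    history = model.fit(ds_train,validation_data = ds_test,epochs = 10)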
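    Under `TPUStrategy`, Keras replicates the model on every TPU core and splits each global batch from the dataset evenly across the replicas. With `BATCH_SIZE = 32` and the 8 cores reported in the log, each core processes only 4 examples per step. A common practice is to fix a per-core batch size and multiply it by `strategy.num_replicas_in_sync`.

    The sketch below is plain Python, so it runs without a TPU. `NUM_REPLICAS` stands in for `strategy.num_replicas_in_sync`. `PER_REPLICA_BATCH = 128` is a hypothetical value, not one used in this book.

```python
NUM_REPLICAS = 8          # stands in for strategy.num_replicas_in_sync (8 TPU cores in the log)
GLOBAL_BATCH_SIZE = 32    # BATCH_SIZE used in this section

def per_replica_batch(global_batch: int, num_replicas: int) -> int:
    """Examples each replica handles per step; this sketch requires an even split."""
    if global_batch % num_replicas != 0:
        raise ValueError("this sketch assumes the global batch divides evenly across replicas")
    return global_batch // num_replicas

def scaled_global_batch(per_replica: int, num_replicas: int) -> int:
    """Global batch size to pass to Dataset.batch() for a fixed per-replica batch."""
    return per_replica * num_replicas

print(per_replica_batch(GLOBAL_BATCH_SIZE, NUM_REPLICAS))    # 4

PER_REPLICA_BATCH = 128   # hypothetical per-core batch size
print(scaled_global_batch(PER_REPLICA_BATCH, NUM_REPLICAS))  # 1024
```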

    WARNING:tensorflow:TPU system 10.26.134.242:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
    INFO:tensorflow:Initializing the TPU system: 10.26.134.242:8470
    INFO:tensorflow:Clearing out eager caches
    INFO:tensorflow:Finished initializing TPU system.
    INFO:tensorflow:Found TPU system:
    INFO:tensorflow:*** Num TPU Cores: 8
    INFO:tensorflow:*** Num TPU Workers: 1
    INFO:tensorflow:*** Num TPU Cores Per Worker: 8
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
    Model: "sequential"
    Layer (type)                 Output Shape              Param #
    =================================================================
    embedding (Embedding)        (None, 300, 7)            216874
    _________________________________________________________________
    conv1d (Conv1D)              (None, 296, 64)           2304
    _________________________________________________________________
    max_pooling1d (MaxPooling1D) (None, 148, 64)           0
    _________________________________________________________________
    conv1d_1 (Conv1D)            (None, 146, 32)           6176
    _________________________________________________________________
    max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
    _________________________________________________________________
    flatten (Flatten)            (None, 2336)              0
    _________________________________________________________________
    dense (Dense)                (None, 46)                107502
    =================================================================
    Total params: 332,856
    Trainable params: 332,856
    Non-trainable params: 0
    _________________________________________________________________
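    The parameter counts in the model summary can be checked by hand using three rules:

    - An `Embedding` layer holds `input_dim × output_dim` weights.
    - A `Conv1D` layer holds `kernel_size × in_channels × filters` weights plus one bias per filter.
    - A `Dense` layer holds `inputs × units` weights plus one bias per unit.

    Sequence lengths follow two more rules. With the default `valid` padding, a convolution shortens the sequence by `kernel_size - 1`, and `MaxPool1D(2)` halves it, rounding down. The sketch below assumes MAX_WORDS = 30982. That value is inferred from the Embedding row (216874 / 7) rather than stated in the text.

```python
MAX_LEN = 300
EMBED_DIM = 7
MAX_WORDS = 216874 // EMBED_DIM   # 30982, inferred from the Embedding row of the summary
CAT_NUM = 46

def conv1d(length, in_ch, filters, kernel_size):
    """'valid' Conv1D: returns (output_length, output_channels, param_count)."""
    return length - kernel_size + 1, filters, kernel_size * in_ch * filters + filters

def maxpool1d(length, pool_size):
    """MaxPool1D with stride equal to pool_size and 'valid' padding."""
    return length // pool_size

embed_params = MAX_WORDS * EMBED_DIM
length, ch, conv1_params = conv1d(MAX_LEN, EMBED_DIM, 64, 5)   # (296, 64)
length = maxpool1d(length, 2)                                  # 148
length, ch, conv2_params = conv1d(length, ch, 32, 3)           # (146, 32)
length = maxpool1d(length, 2)                                  # 73
flat = length * ch                                             # 2336 after Flatten
dense_params = flat * CAT_NUM + CAT_NUM

total = embed_params + conv1_params + conv2_params + dense_params
print(embed_params, conv1_params, conv2_params, dense_params, total)
# 216874 2304 6176 107502 332856
```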
    Train for 281 steps, validate for 71 steps
    Epoch 1/10
    281/281 [==============================] - 12s 43ms/step - loss: 3.4466 - sparse_categorical_accuracy: 0.4332 - sparse_top_k_categorical_accuracy: 0.7180 - val_loss: 3.3179 - val_sparse_categorical_accuracy: 0.5352 - val_sparse_top_k_categorical_accuracy: 0.7195
    Epoch 2/10
    281/281 [==============================] - 6s 20ms/step - loss: 3.3251 - sparse_categorical_accuracy: 0.5405 - sparse_top_k_categorical_accuracy: 0.7302 - val_loss: 3.3082 - val_sparse_categorical_accuracy: 0.5463 - val_sparse_top_k_categorical_accuracy: 0.7235
    Epoch 3/10
    281/281 [==============================] - 6s 20ms/step - loss: 3.2961 - sparse_categorical_accuracy: 0.5729 - sparse_top_k_categorical_accuracy: 0.7280 - val_loss: 3.3026 - val_sparse_categorical_accuracy: 0.5499 - val_sparse_top_k_categorical_accuracy: 0.7217
    Epoch 4/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2751 - sparse_categorical_accuracy: 0.5924 - sparse_top_k_categorical_accuracy: 0.7276 - val_loss: 3.2957 - val_sparse_categorical_accuracy: 0.5543 - val_sparse_top_k_categorical_accuracy: 0.7217
    Epoch 5/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2655 - sparse_categorical_accuracy: 0.6008 - sparse_top_k_categorical_accuracy: 0.7290 - val_loss: 3.3022 - val_sparse_categorical_accuracy: 0.5490 - val_sparse_top_k_categorical_accuracy: 0.7231
    Epoch 6/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2616 - sparse_categorical_accuracy: 0.6041 - sparse_top_k_categorical_accuracy: 0.7295 - val_loss: 3.3015 - val_sparse_categorical_accuracy: 0.5503 - val_sparse_top_k_categorical_accuracy: 0.7235
    Epoch 7/10
    281/281 [==============================] - 6s 21ms/step - loss: 3.2595 - sparse_categorical_accuracy: 0.6059 - sparse_top_k_categorical_accuracy: 0.7322 - val_loss: 3.3064 - val_sparse_categorical_accuracy: 0.5454 - val_sparse_top_k_categorical_accuracy: 0.7266
    Epoch 8/10
    281/281 [==============================] - 6s 21ms/step - loss: 3.2591 - sparse_categorical_accuracy: 0.6063 - sparse_top_k_categorical_accuracy: 0.7327 - val_loss: 3.3025 - val_sparse_categorical_accuracy: 0.5481 - val_sparse_top_k_categorical_accuracy: 0.7231
    Epoch 9/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2588 - sparse_categorical_accuracy: 0.6062 - sparse_top_k_categorical_accuracy: 0.7332 - val_loss: 3.2992 - val_sparse_categorical_accuracy: 0.5521 - val_sparse_top_k_categorical_accuracy: 0.7257
    Epoch 10/10
    281/281 [==============================] - 5s 18ms/step - loss: 3.2577 - sparse_categorical_accuracy: 0.6073 - sparse_top_k_categorical_accuracy: 0.7363 - val_loss: 3.2981 - val_sparse_categorical_accuracy: 0.5516 - val_sparse_top_k_categorical_accuracy: 0.7306
    CPU times: user 18.9 s, sys: 3.86 s, total: 22.7 s
    Wall time: 1min 1s

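    If anything in this book needs further discussion with the author, feel free to leave a message on the WeChat official account "Python与算法之美". The author's time and energy are limited, so replies will be given at the author's discretion.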