6-2 Three Ways of Training

    There are three ways of training a model: the pre-defined fit method, the pre-defined train_on_batch method, and a customized training loop.

    Note: the fit_generator method is not recommended in tf.keras, since its functionality has been merged into fit.

    import tensorflow as tf
    from tensorflow.keras import datasets, preprocessing, models, layers, optimizers, losses, metrics

    MAX_LEN = 300
    BATCH_SIZE = 32
    (x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
    x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
    x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

    MAX_WORDS = x_train.max() + 1
    CAT_NUM = y_train.max() + 1

    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
        .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE).cache()

    ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
        .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE).cache()
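    In the snippet above, cache() is chained after prefetch(). According to the tf.data performance guide, prefetch is usually the last transformation in the pipeline and cache is placed earlier, so that later epochs read from the cache while prefetch overlaps preprocessing with training. A commonly recommended ordering would look like this (a sketch, not the original code of this section):

    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
        .cache() \
        .shuffle(buffer_size=1000) \
        .batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE)  # prefetch last, overlapping input preparation with training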

    The pre-defined fit method is powerful: it accepts training data given as numpy arrays, as a tf.data.Dataset, or as a Python generator.
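    For example, the same compiled model could also be trained directly on the padded numpy arrays, or on a Python generator that yields batches. A minimal sketch follows; the batch_generator helper is hypothetical and not part of the original example, and both fit calls assume the model has already been built and compiled as shown below:

    # 1) numpy arrays: fit slices and batches them internally
    # history = model.fit(x_train, y_train, batch_size=BATCH_SIZE,
    #                     validation_data=(x_test, y_test), epochs=10)

    # 2) Python generator: fit consumes it directly (fit_generator is no longer needed)
    def batch_generator(x, y, batch_size=BATCH_SIZE):
        while True:  # Keras expects the generator to yield indefinitely
            for i in range(0, len(x), batch_size):
                yield x[i:i + batch_size], y[i:i + batch_size]

    # history = model.fit(batch_generator(x_train, y_train),
    #                     steps_per_epoch=len(x_train) // BATCH_SIZE, epochs=10)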

    tf.keras.backend.clear_session()

    def create_model():
        model = models.Sequential()
        model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
        model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        model.add(layers.Dense(CAT_NUM, activation="softmax"))
        return model

    def compile_model(model):
        model.compile(optimizer=optimizers.Nadam(),
                      loss=losses.SparseCategoricalCrossentropy(),
                      metrics=[metrics.SparseCategoricalAccuracy(),
                               metrics.SparseTopKCategoricalAccuracy(5)])
        return model

    model = create_model()
    model.summary()
    model = compile_model(model)
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    embedding (Embedding)        (None, 300, 7)            216874
    _________________________________________________________________
    conv1d (Conv1D)              (None, 296, 64)           2304
    _________________________________________________________________
    max_pooling1d (MaxPooling1D) (None, 148, 64)           0
    _________________________________________________________________
    conv1d_1 (Conv1D)            (None, 146, 32)           6176
    _________________________________________________________________
    max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
    _________________________________________________________________
    flatten (Flatten)            (None, 2336)               0
    _________________________________________________________________
    dense (Dense)                (None, 46)                107502
    =================================================================
    Total params: 332,856
    Trainable params: 332,856
    Non-trainable params: 0
    _________________________________________________________________
    history = model.fit(ds_train, validation_data=ds_test, epochs=10)
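    The History object returned by fit records per-epoch losses and metrics in its history dict, which makes it easy to inspect learning curves. A minimal sketch, assuming matplotlib is available:

    import matplotlib.pyplot as plt

    # history.history maps each metric name to a list with one value per epoch
    plt.plot(history.history["loss"], label="train loss")
    plt.plot(history.history["val_loss"], label="valid loss")
    plt.xlabel("epoch")
    plt.legend()
    plt.show()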

    The pre-defined train_on_batch method gives fine-grained control over the training procedure for each batch without using callbacks, and is therefore even more flexible than the fit method.
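    A single call consumes exactly one batch and returns the loss together with the metric values, in the order given by model.metrics_names. A minimal sketch, assuming a compiled model and one batch (x, y) taken from ds_train; this is exactly the pattern used in the full example below:

    # returns [loss, sparse_categorical_accuracy, sparse_top_k_categorical_accuracy]
    train_result = model.train_on_batch(x, y)

    # test_on_batch evaluates without updating the weights;
    # reset_metrics=False accumulates the metric states across batches
    valid_result = model.test_on_batch(x, y, reset_metrics=False)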

    tf.keras.backend.clear_session()

    def create_model():
        model = models.Sequential()
        model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
        model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        model.add(layers.Dense(CAT_NUM, activation="softmax"))
        return model

    def compile_model(model):
        model.compile(optimizer=optimizers.Nadam(),
                      loss=losses.SparseCategoricalCrossentropy(),
                      metrics=[metrics.SparseCategoricalAccuracy(),
                               metrics.SparseTopKCategoricalAccuracy(5)])
        return model

    model = create_model()
    model.summary()
    model = compile_model(model)  # train_on_batch requires a compiled model
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    embedding (Embedding)        (None, 300, 7)            216874
    _________________________________________________________________
    conv1d (Conv1D)              (None, 296, 64)           2304
    _________________________________________________________________
    max_pooling1d (MaxPooling1D) (None, 148, 64)           0
    _________________________________________________________________
    conv1d_1 (Conv1D)            (None, 146, 32)           6176
    _________________________________________________________________
    max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
    _________________________________________________________________
    flatten (Flatten)            (None, 2336)               0
    _________________________________________________________________
    dense (Dense)                (None, 46)                107502
    =================================================================
    Total params: 332,856
    Trainable params: 332,856
    Non-trainable params: 0
    _________________________________________________________________
    def train_model(model, ds_train, ds_valid, epochs):
        for epoch in tf.range(1, epochs + 1):
            model.reset_metrics()

            # Reduce the learning rate at the late stage of training.
            if epoch == 5:
                model.optimizer.lr.assign(model.optimizer.lr / 2.0)
                tf.print("Lowering optimizer Learning Rate...\n\n")

            for x, y in ds_train:
                # train_result holds the loss/metrics of the most recent batch
                train_result = model.train_on_batch(x, y)

            for x, y in ds_valid:
                # accumulate the validation metrics across all batches
                valid_result = model.test_on_batch(x, y, reset_metrics=False)

            if epoch % 1 == 0:
                printbar()  # utility defined earlier in this book: prints a timestamped separator
                tf.print("epoch = ", epoch)
                print("train:", dict(zip(model.metrics_names, train_result)))
                print("valid:", dict(zip(model.metrics_names, valid_result)))
                print("")

    train_model(model, ds_train, ds_test, 10)
    ================================================================================13:09:19
    epoch = 1
    train: {'loss': 0.82411176, 'sparse_categorical_accuracy': 0.77272725, 'sparse_top_k_categorical_accuracy': 0.8636364}
    valid: {'loss': 1.9265995, 'sparse_categorical_accuracy': 0.5743544, 'sparse_top_k_categorical_accuracy': 0.75779164}

    ================================================================================13:09:27
    epoch = 2
    train: {'loss': 0.6006621, 'sparse_categorical_accuracy': 0.90909094, 'sparse_top_k_categorical_accuracy': 0.95454544}
    valid: {'loss': 1.844159, 'sparse_categorical_accuracy': 0.6126447, 'sparse_top_k_categorical_accuracy': 0.7920748}

    ================================================================================13:09:35
    epoch = 3
    train: {'loss': 0.36935613, 'sparse_categorical_accuracy': 0.90909094, 'sparse_top_k_categorical_accuracy': 0.95454544}
    valid: {'loss': 2.163433, 'sparse_categorical_accuracy': 0.63312554, 'sparse_top_k_categorical_accuracy': 0.8045414}

    ================================================================================13:09:42
    epoch = 4
    train: {'loss': 0.2304088, 'sparse_categorical_accuracy': 0.90909094, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 2.8911984, 'sparse_categorical_accuracy': 0.6344613, 'sparse_top_k_categorical_accuracy': 0.7978629}

    Lowering optimizer Learning Rate...

    ================================================================================13:09:51
    epoch = 5
    train: {'loss': 0.111194365, 'sparse_categorical_accuracy': 0.95454544, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 3.6431572, 'sparse_categorical_accuracy': 0.6295637, 'sparse_top_k_categorical_accuracy': 0.7978629}

    ================================================================================13:09:59
    epoch = 6
    train: {'loss': 0.07741702, 'sparse_categorical_accuracy': 0.95454544, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 4.074161, 'sparse_categorical_accuracy': 0.6255565, 'sparse_top_k_categorical_accuracy': 0.794301}

    ================================================================================13:10:07
    epoch = 7
    train: {'loss': 0.056113098, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 4.4461513, 'sparse_categorical_accuracy': 0.6273375, 'sparse_top_k_categorical_accuracy': 0.79652715}

    ================================================================================13:10:17
    epoch = 8
    train: {'loss': 0.043448802, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 4.7687583, 'sparse_categorical_accuracy': 0.6224399, 'sparse_top_k_categorical_accuracy': 0.79741764}

    ================================================================================13:10:26
    epoch = 9
    train: {'loss': 0.035002146, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 5.130505, 'sparse_categorical_accuracy': 0.6175423, 'sparse_top_k_categorical_accuracy': 0.794301}

    ================================================================================13:10:34
    epoch = 10
    train: {'loss': 0.028303564, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
    valid: {'loss': 5.4559293, 'sparse_categorical_accuracy': 0.6148709, 'sparse_top_k_categorical_accuracy': 0.7947462}

    The customized training loop requires no compilation of the model at all: the loss is computed explicitly and the optimizer back-propagates it to update the parameters, which gives us the highest flexibility.

    tf.keras.backend.clear_session()

    def create_model():
        model = models.Sequential()
        model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
        model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        model.add(layers.Dense(CAT_NUM, activation="softmax"))
        return model

    model = create_model()
    model.summary()
    optimizer = optimizers.Nadam()
    loss_func = losses.SparseCategoricalCrossentropy()

    train_loss = metrics.Mean(name='train_loss')
    train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')
    valid_loss = metrics.Mean(name='valid_loss')
    valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    @tf.function
    def train_step(model, features, labels):
        with tf.GradientTape() as tape:
            predictions = model(features, training=True)
            loss = loss_func(labels, predictions)
        # Back-propagate the loss and update the trainable parameters.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss.update_state(loss)
        train_metric.update_state(labels, predictions)

    @tf.function
    def valid_step(model, features, labels):
        predictions = model(features, training=False)
        batch_loss = loss_func(labels, predictions)
        valid_loss.update_state(batch_loss)
        valid_metric.update_state(labels, predictions)

    def train_model(model, ds_train, ds_valid, epochs):
        for epoch in tf.range(1, epochs + 1):

            for features, labels in ds_train:
                train_step(model, features, labels)

            for features, labels in ds_valid:
                valid_step(model, features, labels)

            logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

            if epoch % 1 == 0:
                printbar()
                tf.print(tf.strings.format(logs,
                    (epoch, train_loss.result(), train_metric.result(),
                     valid_loss.result(), valid_metric.result())))
                tf.print("")

            # Reset the metric states at the end of each epoch.
            train_loss.reset_states()
            valid_loss.reset_states()
            train_metric.reset_states()
            valid_metric.reset_states()

    train_model(model, ds_train, ds_test, 10)
    ================================================================================13:12:03
    Epoch=1,Loss:2.02051544,Accuracy:0.460253835,Valid Loss:1.75700927,Valid Accuracy:0.536954582

    ================================================================================13:12:09
    Epoch=2,Loss:1.510795,Accuracy:0.610665798,Valid Loss:1.55349839,Valid Accuracy:0.616206586

    ================================================================================13:12:17
    Epoch=3,Loss:1.19221532,Accuracy:0.696170092,Valid Loss:1.52315605,Valid Accuracy:0.651380241

    ================================================================================13:12:23
    Epoch=4,Loss:0.90101546,Accuracy:0.766310394,Valid Loss:1.68327653,Valid Accuracy:0.648263574

    ================================================================================13:12:30
    Epoch=5,Loss:0.655430496,Accuracy:0.831329346,Valid Loss:1.90872383,Valid Accuracy:0.641139805

    ================================================================================13:12:37
    Epoch=6,Loss:0.492730737,Accuracy:0.877866864,Valid Loss:2.09966016,Valid Accuracy:0.63223511

    ================================================================================13:12:44
    Epoch=7,Loss:0.391238362,Accuracy:0.904030263,Valid Loss:2.27431226,Valid Accuracy:0.625111282

    ================================================================================13:12:51
    Epoch=8,Loss:0.327761739,Accuracy:0.922066331,Valid Loss:2.42568827,Valid Accuracy:0.617542326

    ================================================================================13:12:58
    Epoch=9,Loss:0.285573095,Accuracy:0.930527747,Valid Loss:2.55942106,Valid Accuracy:0.612644672

    ================================================================================13:13:05
    Epoch=10,Loss:0.255482465,Accuracy:0.936094403,Valid Loss:2.67789412,Valid Accuracy:0.612199485
