1-2,图片数据建模流程范例

    训练集有airplane和automobile图片各5000张,测试集有airplane和automobile图片各1000张。

    cifar2任务的目标是训练一个模型来对飞机airplane和机动车automobile两种图片进行分类。

    我们准备的Cifar2数据集的文件结构如下所示。

    在tensorflow中准备图片数据的常用方案有两种,第一种是使用tf.keras中的ImageDataGenerator工具构建图片数据生成器。

    第二种是使用tf.data.Dataset搭配tf.image中的一些图片处理方法构建数据管道。

    第一种方法更为简单,其使用范例可以参考以下文章。

    第二种方法是TensorFlow的原生方法,更加灵活,使用得当的话也可以获得更好的性能。

    我们此处介绍第二种方法。

    1. from tensorflow.keras import datasets,layers,models
    2. BATCH_SIZE = 100
    3. def load_image(img_path,size = (32,32)):
    4. label = tf.constant(1,tf.int8) if tf.strings.regex_full_match(img_path,".*automobile.*") \
    5. else tf.constant(0,tf.int8)
    6. img = tf.io.read_file(img_path)
    7. img = tf.image.decode_jpeg(img) #注意此处为jpeg格式
    8. img = tf.image.resize(img,size)/255.0
    9. return(img,label)
    1. #使用并行化预处理num_parallel_calls 和预存数据prefetch来提升性能
    2. ds_train = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg") \
    3. .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    4. .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
    5. .prefetch(tf.data.experimental.AUTOTUNE)
    6. ds_test = tf.data.Dataset.list_files("./data/cifar2/test/*/*.jpg") \
    7. .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    8. .batch(BATCH_SIZE) \
    9. .prefetch(tf.data.experimental.AUTOTUNE)
    1. %matplotlib inline
    2. %config InlineBackend.figure_format = 'svg'
    3. #查看部分样本
    4. from matplotlib import pyplot as plt
    5. plt.figure(figsize=(8,8))
    6. for i,(img,label) in enumerate(ds_train.unbatch().take(9)):
    7. ax=plt.subplot(3,3,i+1)
    8. ax.imshow(img.numpy())
    9. ax.set_title("label = %d"%label)
    10. ax.set_xticks([])
    11. ax.set_yticks([])
    12. plt.show()

    1-2,图片数据建模流程范例 - 图2

    1. for x,y in ds_train.take(1):
    2. print(x.shape,y.shape)
    1. (100, 32, 32, 3) (100,)

    二,定义模型

    使用Keras接口有以下3种方式构建模型:使用Sequential按层顺序构建模型,使用函数式API构建任意结构模型,继承Model基类构建自定义模型。

    此处选择使用函数式API构建模型。

    1. tf.keras.backend.clear_session() #清空会话
    2. inputs = layers.Input(shape=(32,32,3))
    3. x = layers.Conv2D(32,kernel_size=(3,3))(inputs)
    4. x = layers.MaxPool2D()(x)
    5. x = layers.Conv2D(64,kernel_size=(5,5))(x)
    6. x = layers.MaxPool2D()(x)
    7. x = layers.Dropout(rate=0.1)(x)
    8. x = layers.Flatten()(x)
    9. x = layers.Dense(32,activation='relu')(x)
    10. outputs = layers.Dense(1,activation = 'sigmoid')(x)
    11. model = models.Model(inputs = inputs,outputs = outputs)
    12. model.summary()
    1. Model: "model"
    2. Layer (type) Output Shape Param #
    3. =================================================================
    4. input_1 (InputLayer) [(None, 32, 32, 3)] 0
    5. _________________________________________________________________
    6. _________________________________________________________________
    7. max_pooling2d (MaxPooling2D) (None, 15, 15, 32) 0
    8. _________________________________________________________________
    9. conv2d_1 (Conv2D) (None, 11, 11, 64) 51264
    10. _________________________________________________________________
    11. max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0
    12. _________________________________________________________________
    13. dropout (Dropout) (None, 5, 5, 64) 0
    14. _________________________________________________________________
    15. flatten (Flatten) (None, 1600) 0
    16. _________________________________________________________________
    17. dense (Dense) (None, 32) 51232
    18. _________________________________________________________________
    19. dense_1 (Dense) (None, 1) 33
    20. =================================================================
    21. Total params: 103,425
    22. Trainable params: 103,425
    23. Non-trainable params: 0
    24. _________________________________________________________________

    训练模型通常有3种方法,内置fit方法,内置train_on_batch方法,以及自定义训练循环。此处我们选择最常用也最简单的内置fit方法。

    1. Train for 100 steps, validate for 20 steps
    2. Epoch 1/10
    3. 100/100 [==============================] - 16s 156ms/step - loss: 0.4830 - accuracy: 0.7697 - val_loss: 0.3396 - val_accuracy: 0.8475
    4. Epoch 2/10
    5. 100/100 [==============================] - 14s 142ms/step - loss: 0.3437 - accuracy: 0.8469 - val_loss: 0.2997 - val_accuracy: 0.8680
    6. Epoch 3/10
    7. 100/100 [==============================] - 13s 131ms/step - loss: 0.2871 - accuracy: 0.8777 - val_loss: 0.2390 - val_accuracy: 0.9015
    8. Epoch 4/10
    9. 100/100 [==============================] - 12s 117ms/step - loss: 0.2410 - accuracy: 0.9040 - val_loss: 0.2005 - val_accuracy: 0.9195
    10. Epoch 5/10
    11. 100/100 [==============================] - 13s 130ms/step - loss: 0.1992 - accuracy: 0.9213 - val_loss: 0.1949 - val_accuracy: 0.9180
    12. Epoch 6/10
    13. 100/100 [==============================] - 14s 136ms/step - loss: 0.1737 - accuracy: 0.9323 - val_loss: 0.1723 - val_accuracy: 0.9275
    14. Epoch 7/10
    15. 100/100 [==============================] - 14s 139ms/step - loss: 0.1531 - accuracy: 0.9412 - val_loss: 0.1670 - val_accuracy: 0.9310
    16. Epoch 8/10
    17. 100/100 [==============================] - 13s 134ms/step - loss: 0.1299 - accuracy: 0.9525 - val_loss: 0.1553 - val_accuracy: 0.9340
    18. Epoch 9/10
    19. 100/100 [==============================] - 14s 137ms/step - loss: 0.1158 - accuracy: 0.9556 - val_loss: 0.1581 - val_accuracy: 0.9340
    20. Epoch 10/10
    21. 100/100 [==============================] - 14s 142ms/step - loss: 0.1006 - accuracy: 0.9617 - val_loss: 0.1614 - val_accuracy: 0.9345

    四,评估模型

    1. %load_ext tensorboard
    2. #%tensorboard --logdir ./data/keras_model
    1. from tensorboard import notebook
    2. notebook.list()
    1. #在tensorboard中查看模型
    2. notebook.start("--logdir ./data/keras_model")

    1. import pandas as pd
    2. dfhistory = pd.DataFrame(history.history)
    3. dfhistory.index = range(1,len(dfhistory) + 1)
    4. dfhistory.index.name = 'epoch'
    5. dfhistory
    1. %config InlineBackend.figure_format = 'svg'
    2. def plot_metric(history, metric):
    3. train_metrics = history.history[metric]
    4. val_metrics = history.history['val_'+metric]
    5. epochs = range(1, len(train_metrics) + 1)
    6. plt.plot(epochs, train_metrics, 'bo--')
    7. plt.plot(epochs, val_metrics, 'ro-')
    8. plt.title('Training and validation '+ metric)
    9. plt.xlabel("Epochs")
    10. plt.ylabel(metric)
    11. plt.legend(["train_"+metric, 'val_'+metric])
    12. plt.show()
    1. plot_metric(history,"loss")

    1-2,图片数据建模流程范例 - 图5

    1. #可以使用evaluate对数据进行评估
    2. val_loss,val_accuracy = model.evaluate(ds_test,workers=4)
    3. print(val_loss,val_accuracy)
    1. 0.16139143370091916 0.9345

    可以使用model.predict(ds_test)进行预测。

    也可以使用model.predict_on_batch(x_test)对一个批量进行预测。

    1. model.predict(ds_test)
    1. array([[9.9996173e-01],
    2. [9.5104784e-01],
    3. [2.8648047e-04],
    4. ...,
    5. [1.1484033e-03],
    6. [3.5589080e-02],
    7. [9.8537153e-01]], dtype=float32)
    1. for x,y in ds_test.take(1):
    2. print(model.predict_on_batch(x[0:20]))
    1. tf.Tensor(
    2. [[3.8065155e-05]
    3. [8.8236779e-01]
    4. [9.1433197e-01]
    5. [9.9921846e-01]
    6. [6.4052093e-01]
    7. [4.9970779e-03]
    8. [2.6735585e-04]
    9. [9.9842811e-01]
    10. [7.9198682e-01]
    11. [7.4823302e-01]
    12. [8.7208226e-03]
    13. [9.3951421e-03]
    14. [9.9790359e-01]
    15. [9.9998581e-01]
    16. [2.1642199e-05]
    17. [1.7915063e-02]
    18. [2.5839690e-02]
    19. [9.7538447e-01]
    20. [9.7393811e-01]
    21. [9.7333014e-01]], shape=(20, 1), dtype=float32)

    六,保存模型

    推荐使用TensorFlow原生方式保存模型。

    1. # 保存权重,该方式仅仅保存权重张量
    2. model.save_weights('./data/tf_model_weights.ckpt',save_format = "tf")
    1. # 保存模型结构与模型参数到文件,该方式保存的模型具有跨平台性便于部署
    2. model.save('./data/tf_model_savedmodel', save_format="tf")
    3. print('export saved model.')
    4. model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel')
    5. model_loaded.evaluate(ds_test)

    如果对本书内容理解上有需要进一步和作者交流的地方,欢迎在公众号”Python与算法之美”下留言。作者时间和精力有限,会酌情予以回复。

    也可以在公众号后台回复关键字:加群,加入读者交流群和大家讨论。