警告

    并不是任何函数都可以被 @tf.function 修饰!@tf.function 使用静态编译将函数内的代码转换成计算图,因此对函数内可使用的语句有一定限制(仅支持Python语言的一个子集),且需要函数内的操作本身能够被构建为计算图。建议在函数内只使用TensorFlow的原生操作,不要使用过于复杂的Python语句,函数参数只包括TensorFlow张量或NumPy数组,并最好是能够按照计算图的思想去构建函数(换言之,@tf.function 只是给了你一种更方便的写计算图的方法,而不是一颗能给任何函数加速的 )。详细内容可参考 AutoGraph Capabilities and Limitations

    运行400个Batch进行测试,加入 @tf.function 的程序耗时35.5秒,未加入 @tf.function 的纯Eager Execution程序耗时43.8秒。可见 @tf.function 带来了一定的性能提升。一般而言,当模型由较多小的操作组成的时候, @tf.function 带来的提升效果较大。而当模型的操作数量较少,但单一操作均很耗时的时候,则 @tf.function 带来的性能提升不会太大。

    当被 @tf.function 修饰的函数第一次被调用的时候,进行以下操作:

    • 在Eager Execution模式关闭的环境下,函数内的代码依次运行。也就是说,每个 tf. 方法都只是定义了计算节点,而并没有进行任何实质的计算。这与TensorFlow 1.X的Graph Execution是一致的;

    • 使用AutoGraph将函数中的Python控制流语句转换成TensorFlow计算图中的对应节点(比如说 whilefor 语句转换为 tf.whileif 语句转换为 等等;

    • 基于上面的两步,建立函数内代码的计算图表示(为了保证图的计算顺序,图中还会自动加入一些 tf.control_dependencies 节点);

    • 运行一次这个计算图;

    • 基于函数的名字和输入的函数参数的类型生成一个哈希值,并将建立的计算图缓存到一个哈希表中。

    以下是一个测试题:

    1. import tensorflow as tf
    2. import numpy as np
    3.  
    4. @tf.function
    5. def f(x):
    6. print("The function is running in Python")
    7. tf.print(x)
    8.  
    9. a = tf.constant(1, dtype=tf.int32)
    10. f(a)
    11. b = tf.constant(2, dtype=tf.int32)
    12. f(b)
    13. b_ = np.array(2, dtype=np.int32)
    14. f(b_)
    15. c = tf.constant(0.1, dtype=tf.float32)
    16. f(c)
    17. d = tf.constant(0.2, dtype=tf.float32)
    18. f(d)
    19. f(1)
    20. f(1)
    21. f(0.1)
    22. f(0.2)
    23. f(0.1)

    思考一下,上面这段程序的结果是什么?

    答案是:

    当计算 f(a) 时,由于是第一次调用该函数,TensorFlow进行了以下操作:

    • 将函数内的代码依次运行了一遍(因此输出了文本);

    • 构建了计算图,然后运行了一次该计算图(因此输出了1)。这里 tf.print(x) 可以作为计算图的节点,但Python内置的 print 则不能被转换成计算图的节点。因此,计算图中只包含了 tf.print(x) 这一操作;

    • 将该计算图缓存到了一个哈希表中(如果之后再有类型为 tf.int32 ,shape为空的张量输入,则重复使用已构建的计算图)。

    计算 f(b) 时,由于b的类型与a相同,所以TensorFlow重复使用了之前已构建的计算图并运行(因此输出了2)。这里由于并没有真正地逐行运行函数中的代码,所以函数第一行的文本输出代码没有运行。计算 f(b_) 时,TensorFlow自动将numpy的数据结构转换成了TensorFlow中的张量,因此依然能够复用之前已构建的计算图。

    计算 f(c) 时,虽然张量 c 的shape和 ab 均相同,但类型为 tf.float32 ,因此TensorFlow重新运行了函数内代码(从而再次输出了文本)并建立了一个输入为 tf.float32 类型的计算图。

    之后的计算结果则显示出 @tf.function 对Python内置的整数和浮点数类型的处理方式。简而言之,只有当值完全一致的时候, @tf.function 才会复用之前建立的计算图,而并不会自动将Python内置的整数或浮点数等转换成张量。因此,当函数参数包含Python内置整数或浮点数时,需要额外小心。一般而言,应当只在指定超参数等少数场合使用Python内置类型作为被 @tf.function 修饰的函数的参数。

    下一个思考题:

    1. import tensorflow as tf
    2.  
    3. a = tf.Variable(0.0)
    4.  
    5. @tf.function
    6. def g():
    7. a.assign(a + 1.0)
    8. return a
    9.  
    10. print(g())
    11. print(g())
    12. print(g())

    这段代码的输出是:

    正如同正文里的例子一样,你可以在被 @tf.function 修饰的函数里调用 tf.Variabletf.keras.optimizerstf.keras.Model 等包含有变量的数据结构。一旦被调用,这些结构将作为隐含的参数提供给函数。当这些结构内的值在函数内被修改时,在函数外也同样生效。

    前面提到,@tf.function 使用名为AutoGraph的机制将函数中的Python控制流语句转换成TensorFlow计算图中的对应节点。以下是一个示例,使用 tf.autograph 模块的低层API tf.autograph.to_code 将函数 square_if_positive 转换成TensorFlow计算图:

    1. import tensorflow as tf
    2.  
    3. @tf.function
    4. def square_if_positive(x):
    5. if x > 0:
    6. x = x * x
    7. else:
    8. x = 0
    9. return x
    10. a = tf.constant(1)
    11. b = tf.constant(-1)
    12. print(square_if_positive(a), square_if_positive(b))
    13. print(tf.autograph.to_code(square_if_positive.python_function))

    输出:

    我们注意到,原函数中的Python控制流 if…else… 被转换为了 x = ag__.if_stmt(cond, if_true, if_false, get_state, set_state) 这种计算图式的写法。AutoGraph起到了类似编译器的作用,能够帮助我们通过更加自然的Python控制流轻松地构建带有条件/循环的计算图,而无需手动使用TensorFlow的API进行构建。

    不过,如果你依然钟情于TensorFlow传统的Graph Execution模式也没有问题。TensorFlow 2.0提供了 tf.compat.v1 模块以支持TensorFlow 1.X版本的API。同时,只要在编写模型的时候稍加注意,Keras的模型是可以同时兼容Eager Execution模式和Graph Execution模式的。注意,在Graph Execution模式下, model(input_tensor) 只需运行一次以完成图的建立操作。

    例如,通过以下代码,同样可以在MNIST数据集上训练前面所建立的MLP或CNN模型:

    1. optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    2. num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
    3. # 建立计算图
    4. X_placeholder = tf.compat.v1.placeholder(name='X', shape=[None, 28, 28, 1], dtype=tf.float32)
    5. y_placeholder = tf.compat.v1.placeholder(name='y', shape=[None], dtype=tf.int32)
    6. y_pred = model(X_placeholder)
    7. loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_placeholder, y_pred=y_pred)
    8. loss = tf.reduce_mean(loss)
    9. train_op = optimizer.minimize(loss)
    10. sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    11. # 建立Session
    12. with tf.compat.v1.Session() as sess:
    13. sess.run(tf.compat.v1.global_variables_initializer())
    14. for batch_index in range(num_batches):
    15. X, y = data_loader.get_batch(batch_size)
    16. # 使用Session.run()将数据送入计算图节点,进行训练以及计算损失函数
    17. _, loss_value = sess.run([train_op, loss], feed_dict={X_placeholder: X, y_placeholder: y})
    18. print("batch %d: loss %f" % (batch_index, loss_value))
    19.  
    20. num_batches = int(data_loader.num_test_data // batch_size)
    21. for batch_index in range(num_batches):
    22. start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    23. y_pred = model.predict(data_loader.test_data[start_index: end_index])
    24. sess.run(sparse_categorical_accuracy.update(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred))