• 欢迎光临~

4-3AutoGraph的使用规范——eat_tensorflow2_in_30_days

开发技术 开发技术 2022-06-21 次浏览

4-3 AutoGraph的使用规范

有三种计算图的构建方式:静态计算图,动态计算图,以及Autograph

TensorFlow2.0主要使用的是动态计算图和Autograph

  • 动态计算图易于调试,编码效率较高,但执行效率较低
  • 静态计算图执行效率很高,但较难调试
  • 而Autograph机制可以将动态计算图转换成静态计算图,兼收执行效率和编码效率之利

当然Autograph 机制能够转换的代码并不是没哟任何约束的,有一些编码规范需要遵循,否则可能会转换失败或者不符合预期

这里将着重介绍Autograph的编码规范和Autograph转换成静态图的原理

并介绍使用tf.Module来更好地构建Autograph

Autograph编码规范总结

  • 被@tf.function修饰的函数应尽可能使用TensorFlow中的函数而不是Python中的其他函数。例如使用tf.print而不是print,使用tf.range而不是range,使用tf.constant(True)而不是True
  • 避免在@tf.function修饰的函数内部定义tf.Variable
  • 被@tf.function修饰的函数不可修改该函数外部的Python列表或字典等数据结构变量

Autograph编码规范解析

  • 被@tf.function修饰的函数应尽可能使用TensorFlow中的函数而不是Python中的其他函数
import numpy as np
import tensorflow as tf

@tf.function
def np_random():
    # np.random.randn函数返回一个或一组样本,具有标准正太分布
    a = np.random.randn(3, 3)
    tf.print(a)
    
@tf.function
def tf_random():
    
    # tf.random.normal服从指定正态分布的序列
    a = tf.random.normal((3, 3))
    tf.print(a)
    
# np.random每次执行都是一样的结果
np_random()
np_random()

"""
array([[-0.84051143, -0.11712408, -0.17738803],
       [ 0.7147196 ,  1.42842053, -0.56037017],
       [-0.0487268 ,  1.05235275,  1.01622511]])
array([[-0.84051143, -0.11712408, -0.17738803],
       [ 0.7147196 ,  1.42842053, -0.56037017],
       [-0.0487268 ,  1.05235275,  1.01622511]])
"""

# tf_random每次执行都会有重新生成随机数
tf_random()
tf_random()

"""
[[1.19916523 0.203395322 1.3903774]
 [-2.06304955 -0.38222155 -1.46414936]
 [0.491630137 0.0822804719 -0.254222572]]
[[-0.549568892 2.08878803 0.558463752]
 [-0.36475572 0.136399537 -0.0849579573]
 [0.253954887 -0.276775241 -1.54324198]]
"""
  • 避免在@tf.function修饰的函数内部定义tf.Variable
# 避免在@tf.function修饰的函数内部定义tf.Variable
x = tf.Variable(1.0, dtype=tf.float32)

@tf.function
def outer_var():
    x.assign_add(1.0)
    tf.print(x)
    return x

outer_var()
outer_var()

"""
2
3
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
"""
@tf.function
def inner_var():
    x = tf.Variable(1.0, dtype=tf.float32)
    x.assign_add(1.0)
    tf.print(x)
    return x

#  执行将报错
# inner_var()
  • 被@tf.function修饰的函数不可修改该函数外部的Python列表或字典等数据结构变量
tensor_list = []

@tf.function  # 加上这一行换成Autograph结果将不符合预期
def append_tensor(x):
    tensor_list.append(x)
    return tensor_list

append_tensor(tf.constant(5.0))
append_tensor(tf.constant(6.0))
print(tensor_list)

"""
[<tf.Tensor 'x:0' shape=() dtype=float32>]
"""
tensor_list = []

# @tf.function  # 加上这一行换成Autograph结果将不符合预期
def append_tensor(x):
    tensor_list.append(x)
    return tensor_list

append_tensor(tf.constant(5.0))
append_tensor(tf.constant(6.0))
print(tensor_list)

"""
[<tf.Tensor: shape=(), dtype=float32, numpy=5.0>, <tf.Tensor: shape=(), dtype=float32, numpy=6.0>]
"""
程序员灯塔
转载请注明原文链接:4-3AutoGraph的使用规范——eat_tensorflow2_in_30_days
喜欢 (0)