TensorFlow中graph_util()函数的理解和应用
tensorflow中的graph_util()函数主要用于导出和加载TensorFlow计算图,使得我们可以在不同的环境中共享和使用已经训练好的模型。graph_util模块提供了从GraphDef文件中加载计算图的函数,并提供了将计算图导出为GraphDef文件的函数。GraphDef是一个序列化的Protocol Buffer,包含了一个完整的TensorFlow计算图。
graph_util模块的主要函数和用法如下:
1. graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None)
此函数将会替换计算图中的变量(Variable)为该变量在当前会话中的值,并将计算图持久化为一个具有常量节点的GraphDef文件。参数说明:
- sess: TensorFlow会话
- input_graph_def: 输入的计算图的GraphDef对象
- output_node_names: 需要导出的计算图中的输出节点的名称列表
- variable_names_whitelist: 可选参数,指定需要转换为常量的变量名称列表
- variable_names_blacklist: 可选参数,指定不需要转换的变量名称列表
2. graph_util.remove_training_nodes(input_graph_def)
此函数将从计算图中移除训练节点(包括Variable节点、Assign节点等),返回移除后的新的GraphDef对象。参数说明:
- input_graph_def: 输入的计算图的GraphDef对象
下面是一个使用graph_util.convert_variables_to_constants函数的示例:
import tensorflow as tf
from tensorflow.python.framework import graph_util
# 创建TensorFlow计算图
x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
W = tf.Variable(tf.zeros([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10]), name='biases')
y = tf.nn.softmax(tf.matmul(x, W) + b, name='output')
# 导出计算图的GraphDef对象
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
graph_def = tf.get_default_graph().as_graph_def()
# 将计算图中的变量转换为常量
output_node_names = ['output']
constant_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, output_node_names)
# 保存为pb文件
output_dir = 'output/'
with tf.gfile.FastGFile(output_dir + 'model.pb', mode='wb') as f:
f.write(constant_graph_def.SerializeToString())
以上示例中,首先创建了一个简单的计算图,包含一个输入节点(input)、一个权重变量节点(weights)、一个偏置变量节点(biases)、一个输出节点(output)。然后使用tf.get_default_graph().as_graph_def()获取整个计算图的GraphDef对象。最后,使用graph_util.convert_variables_to_constants函数将计算图中的Variable节点替换为其在当前会话中的值,并将计算图持久化为一个pb文件。
这样,我们就可以在其他环境中使用该pb文件加载计算图,并进行推理操作。例如:
import tensorflow as tf
from tensorflow.python.platform import gfile
# 从pb文件中加载计算图
with tf.Session() as sess:
model_filename = 'output/model.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# 获取输入输出节点
input_node = sess.graph.get_tensor_by_name('input:0')
output_node = sess.graph.get_tensor_by_name('output:0')
# 进行推理操作
input_data = ...
output_data = sess.run(output_node, feed_dict={input_node: input_data})
以上代码中,首先使用gfile.FastGFile从pb文件中加载计算图的GraphDef对象,然后使用tf.import_graph_def导入计算图。接着,通过sess.graph.get_tensor_by_name获取输入和输出节点,并使用sess.run进行推理。
