Forcing a float32 TensorFlow model down to a float16 half-precision model
After training finishes in TensorFlow, we often want to compress the model before deployment. One option is the approach described in an earlier article, 《【Ubuntu】Tensorflow对训练后的模型做8位(uint8)量化转换》 (8-bit uint8 post-training quantization). Another is half-precision quantization: this article shows how to rewrite the compute nodes and constants (Const) in a TensorFlow pb file so that a float32 model becomes a float16 model, cutting the file size roughly in half.
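Once the conversion below has run, the size reduction is easy to verify on disk. A minimal sketch, assuming the file paths used in section 4:

import os

# Paths follow the example in section 4; substitute your own model files.
fp32_bytes = os.path.getsize("test.pb")
fp16_bytes = os.path.getsize("test/output_fp16.pb")
print("fp32: %d bytes, fp16: %d bytes, ratio: %.2fx"
      % (fp32_bytes, fp16_bytes, float(fp32_bytes) / fp16_bytes))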
1 Loading the pb model
Wrap the loading logic in a helper function:
def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            # Binary frozen graph
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            # Text-format (pbtxt) graph
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        # Print all operation names as a quick sanity check
        ops = graph.get_operations()
        for op in ops:
            print(op.name)
        return sess
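A minimal usage sketch; "model.pb" below is a placeholder for your own frozen graph:

# Load a frozen graph and report how many operations it contains.
sess = load_graph("model.pb")  # hypothetical path
ops = sess.graph.get_operations()
print("loaded %d ops, first: %s" % (len(ops), ops[0].name))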
2 Rewriting BatchNorm
BatchNorm is sensitive to numerical precision and its statistics need to stay in float32, so these nodes require special handling.
# Replace FusedBatchNorm with FusedBatchNormV2 so that float32 is still used
# internally; the "U" attr keeps scale/offset/mean/variance in float32
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")
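Before converting, it can help to confirm that the model actually contains fused batch-norm nodes. A small helper sketch (not part of the original script; the path is hypothetical):

def count_fused_batch_norm(graph_def):
    # Count FusedBatchNorm nodes, including versioned variants such as FusedBatchNormV3.
    return sum(1 for node in graph_def.node if node.op.startswith("FusedBatchNorm"))

sess = load_graph("model.pb")  # hypothetical path
print("fused batch-norm nodes:", count_fused_batch_norm(sess.graph.as_graph_def()))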
3 Graph conversion
Rebuild the graph, copying the parameters from the original pb graph and converting them to float16:
def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16',
                          input_name=None, output_names=None):
    # Pick the target data type for the new graph
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    # Load the model to convert
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    # Create the new GraphDef
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    # Walk the compute nodes of the loaded graph
    for node in source_graph_def.node:
        # Replace FusedBatchNorm nodes with FusedBatchNormV2
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # Copy the node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        # Copy the attrs; the ones that matter are "T" (dtype) and "value" (weights)
        attrs = list(node.attr.keys())
        # Keep BatchNorm attributes unchanged
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # Handle the attributes of every other node individually
        for attr in attrs:
            # Keep whitelisted nodes unchanged
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            # Change float32 to the configured target type
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
            # "value" is where the weights live
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # Strip nodes that are unreachable from the requested outputs
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # Write the converted graph_def to disk
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")
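After conversion it is worth confirming that the Const nodes were actually re-typed. A sketch, assuming the output file from section 4:

def print_const_dtypes(graph_def):
    # Print each Const node's tensor dtype; DT_HALF is enum 19, DT_FLOAT is 1.
    for node in graph_def.node:
        if node.op == "Const":
            print(node.name, node.attr["value"].tensor.dtype)

sess = load_graph("test/output_fp16.pb")  # path from the example in section 4
print_const_dtypes(sess.graph.as_graph_def())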
4 Complete code
import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np

# object detection api input and output nodes
# Note: TransformGraph expects lists of node names, so input_name is a list here
input_name = ["input_tf"]
output_names = ["output:0"]
keep_fp32_node_name = []  # names of nodes that must stay float32

def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            # Binary frozen graph
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            # Text-format (pbtxt) graph
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        ops = graph.get_operations()
        for op in ops:
            print(op.name)
        return sess

# Replace FusedBatchNorm with FusedBatchNormV2 so that float32 is still used
# internally; the "U" attr keeps scale/offset/mean/variance in float32
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")

def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16',
                          input_name=None, output_names=None):
    # Pick the target data type for the new graph
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    # Load the model to convert
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    # Create the new GraphDef
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    # Walk the compute nodes of the loaded graph
    for node in source_graph_def.node:
        # Replace FusedBatchNorm nodes with FusedBatchNormV2
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # Copy the node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        # Copy the attrs; the ones that matter are "T" (dtype) and "value" (weights)
        attrs = list(node.attr.keys())
        # Keep BatchNorm attributes unchanged
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # Handle the attributes of every other node individually
        for attr in attrs:
            # Keep whitelisted nodes unchanged
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            # Change float32 to the configured target type
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
            # "value" is where the weights live
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # Strip nodes that are unreachable from the requested outputs
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # Write the converted graph_def to disk
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")

save_path = "test"
name = "output_fp16.pb"
model_path = "test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type,
                      input_name=input_name, output_names=output_names)
# Check that the converted model can be loaded
sess = load_graph(save_path + "/" + name)
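Note that the conversion also re-types the input placeholder, so inference feeds must now be float16. A hedged inference sketch (the input shape below is made up; use your model's real shape):

import numpy as np

sess = load_graph("test/output_fp16.pb")
# Tensor names follow input_name/output_names above; the shape is hypothetical.
dummy = np.random.rand(1, 224, 224, 3).astype(np.float16)
result = sess.run(output_names[0], feed_dict={input_name[0] + ":0": dummy})
print(result.shape)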