[TOC]
Overview
Models converted from TensorFlow or PyTorch often lack shape information for their intermediate nodes.
How it works
ONNX itself provides an API for shape inference:
```python
shape_inference.infer_shapes(model)
```
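However, this inference is not based on the tensors stored in the graph (its initializers, i.e. the weights). It is based on the tensor_value_info of each tensor listed in the graph's inputs. So all we need to do is build a tensor_value_info for each tensor from its information and append it to graph.input.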
When I first ran infer_shapes, it seemed to have no effect. The reason was that graph.input only held the tensor_value_info of the input node.
```python
import onnx
import onnx.helper as helper
from onnx import shape_inference, TensorProto

# load model
onnx_model = onnx.load("tf_resnet_v2_50_onnx.onnx")
graph = onnx_model.graph

# rewrite the input tensor of the graph: fix the batch dimension to 1
# (assumes graph.input[0] is a 4-D float placeholder, e.g. an image batch)
input_tensor = graph.input[0]
input_shape = input_tensor.type.tensor_type.shape.dim
input_tensor_new = helper.make_tensor_value_info(
    name=input_tensor.name,
    elem_type=TensorProto.FLOAT,
    shape=[1, input_shape[1].dim_value, input_shape[2].dim_value, input_shape[3].dim_value],
)
graph.input[0].CopyFrom(input_tensor_new)

# append the tensor info of every initializer (weight) to graph.input;
# index 0 stays the placeholder, and initializers already listed are skipped
existing_inputs = {inp.name for inp in graph.input}
for tensor in graph.initializer:
    if tensor.name in existing_inputs:
        continue
    # tensor.data_type is already a TensorProto data type code, so no lookup table is needed
    value_info = helper.make_tensor_value_info(tensor.name, tensor.data_type, tensor.dims)
    graph.input.append(value_info)

# run node shape inference
print("Before shape inference: \n")
print(graph.value_info)
print("------------------------------------------------------------")
print("After shape inference: \n")
inferred_onnx_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_onnx_model)  # check the inferred model, not the original
print(inferred_onnx_model.graph.value_info)
onnx.save(inferred_onnx_model, "./new.onnx")
```