TensorFlow Frontend前端

时间:2021-03-16 11:59:01   收藏:0   阅读:0

TensorFlow Frontend前端

TensorFlow前端有助于将TensorFlow模型导入TVM。

Supported versions:

Tested models:

Preparing a Model for Inference准备推理模型

Remove Unneeded Nodes删除不需要的节点

导出过程将删除许多不需要进行推理的节点,但不幸的是会留下一些剩余的节点。应该手动删除的节点:

Convert None Dimensions to Constants将无尺寸Dimensions转换为常数

TVM对动态张量形状的支持最少。None应将尺寸替换为常量。例如,模型可以接受带有shape的输入(None,20)。这应转换为的形状(1,20)。应该相应地修改模型,以确保这些形状在整个图形中都匹配。

Export

TensorFlow前端需要冻结的protobuf(.pb)或保存的模型作为输入。不支持检查点(.ckpt)。TensorFlow前端所需的graphdef,可以从活动会话中提取,可以使用TFParser帮助器类提取。

应该导出该模型并进行许多转换,以准备模型进行推理。设置`add_shapes=True`也很重要,因为这会将每个节点的输出形状嵌入到图形中。这是一个给定会话将模型导出为protobuf的函数:

import tensorflow as tf

from tensorflow.tools.graph_transforms import TransformGraph

 

def export_pb(session):

    with tf.gfile.GFile("myexportedmodel.pb", "wb") as f:

        inputs = ["myinput1", "myinput2"] # replace with your input names

        outputs = ["myoutput1"] # replace with your output names

        graph_def = session.graph.as_graph_def(add_shapes=True)

        graph_def = tf.graph.util.convert_variables_to_constants(session, graph_def, outputs)

        graph_def = TransformGraph(

            graph_def,

            inputs,

            outputs,

            [

                "remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)",

                "sort_by_execution_order", # sort by execution order after each transform to ensure correct node ordering

                "remove_attribute(attribute_name=_XlaSeparateCompiledGradients)",

                "remove_attribute(attribute_name=_XlaCompile)",

                "remove_attribute(attribute_name=_XlaScope)",

                "sort_by_execution_order",

                "remove_device",

                "sort_by_execution_order",

                "fold_batch_norms",

                "sort_by_execution_order",

                "fold_old_batch_norms",

                "sort_by_execution_order"

            ]

        )

        f.write(graph_def.SerializeToString())

Another method is to export and freeze the graph.

Import the Model

Explicit Shape:

确保可以在整个图形中知道形状,将`shape`参数传递给`from_tensorflow`。该词典将输入名称映射到输入形状。

Data Layout

大多数TensorFlow模型以NHWC布局发布。NCHW布局通常提供更好的性能,尤其是在GPU上。该TensorFlow前端可以通过传递参数自动转换模型的数据布局`layout=‘NCHW‘`到`from_tensorflow`。

Best Practices

Supported Ops

 

评论(0
© 2014 mamicode.com 版权所有 京ICP备13008772号-2  联系我们:gaon5@hotmail.com
迷上了代码!