NVIDIA TensorRT 支持多种类型的层,并且其功能不断扩展; 但是,在某些情况下,支持的层不能满足模型的特定需求。 在这种情况下,可以通过实现自定义层(通常称为插件)来扩展 TensorRT。
TensorRT 包含可以加载到您的应用程序中的插件。 开源插件列表参考GitHub: TensorRT plugins(https://github.com/NVIDIA/TensorRT/tree/main/plugin#tensorrt-plugins
).
要在您的应用程序中使用 TensorRT 插件,必须加载 libnvinfer_plugin.so(Windows 上为 nvinfer_plugin.dll)库,并且必须通过在您的应用程序代码中调用 initLibNvInferPlugins 来注册所有插件。参考NvInferPlugin.h(https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/_nv_infer_plugin_8h.html
)获取更多信息。
如果这些插件不能满足你的需求,你可以自己编写。
9.1 使用C++API增加自定义层
您可以通过派生自 TensorRT 的插件基类之一来实现自定义层。
从插件的基类之一派生插件类。 它们在支持不同类型/格式的 I/O 或具有动态形状的网络方面具有不同的表达能力。 下表总结了基类,从最少表达到最多表达。
如果插件用于一般用途,请提供 FP32 实现,以使其能够在任何网络上正常运行。
在哪个TensorRT版本引入 | 混合 I/O 格式/类型 | 动态形状? | 支持隐式/显式批处理模式? | |
IPluginV2Ext | 5.1 | 有限的 | 不支持 | 隐式和显式批处理模式 |
IPluginV2IOExt | 6.0.1 | 一般 | 不支持 | 隐式和显式批处理模式 |
IPluginV2DynamicExt | 6.0.1 | 一般 | 支持 | 仅显式批处理模式 |
为了在网络中使用插件,您必须首先在 TensorRT 的 PluginRegistry 中注册它。最好不要直接注册插件,而是为插件注册一个工厂类的实例,派生自 PluginCreator。插件创建者类还提供有关插件的其他信息:名称、版本和插件字段参数。
您可以通过两种方式向注册表注册插件:
- TensorRT 提供了一个宏 REGISTER_TENSORRT_PLUGIN,用于向注册表静态注册插件creator。 请注意,REGISTER_TENSORRT_PLUGIN 始终在默认命名空间(“”)下注册creator。
- 通过创建您自己的类似于 initLibNvInferPlugins 的入口点并在插件注册表上调用 registerCreator 来动态注册。这比静态注册更可取,因为它提供了可能更低的内存占用,并允许插件在唯一的命名空间下注册。这可确保在不同插件库的构建期间不会发生名称冲突。
调用 IPluginCreator::createPlugin() 返回类型为 IPluginV2 的插件对象。 您可以使用 addPluginV2() 将插件添加到 TensorRT 网络,这会使用给定的插件创建网络层。
例如,您可以按如下方式向您的网络添加一个插件层。
// Look up the plugin in the registry auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion); const PluginFieldCollection* pluginFC = creator->getFieldNames(); // Populate the fields parameters for the plugin layer PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields); // Create the plugin object using the layerName and the plugin meta data IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData); // Add the plugin to the TensorRT network auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj); … (build rest of the network and serialize engine) // Destroy the plugin object pluginObj->destroy() … (free allocated pluginData)
注意:上面描述的 createPlugin 方法在堆上创建一个新的插件对象并返回指向它的指针。 确保如前所示销毁 pluginObj,以避免内存泄漏。
在序列化过程中,TensorRT 引擎内部存储了所有 IPluginV2 类型插件的插件类型、插件版本和命名空间(如果存在)。在反序列化期间,TensorRT 从插件注册表中查找插件创建器并调用 IPluginCreator::deserializePlugin()。当引擎被删除时,在引擎构建期间创建的插件对象的克隆被引擎通过调用 IPluginV2::destroy() 方法销毁。您有责任确保您创建的插件对象在添加到网络后被释放。
注意: 不要序列化所有插件参数:仅序列化插件在运行时正常运行所需的参数。 构建时参数可以省略。 以相同的顺序序列化和反序列化插件参数。 在反序列化期间,验证插件参数是否初始化为默认值或反序列化值。 未初始化的参数会导致未定义的行为。
9.1.1 示例:使用 C++ 添加具有动态形状支持的自定义层
为了支持动态形状,你的插件必须继承至IPluginV2DynamicExt
BarPlugin是一个插件有2个输入和2个输出。
- 第一个输出是第二个输入的拷贝
- 第二个输出是两个输入的串联,沿第一个维度,所有类型/格式必须相同并且是线性格式。
BarPlugin必须继承至
class BarPlugin : public IPluginV2DynamicExt { ...override virtual methods inherited from IPluginV2DynamicExt. };
4个被动态形状影响的方法是:
- getOutputDimensions
- supportsFormatCombination
- configurePlugin
- enqueue
getOutputDimensions 根据输入维度返回输出维度的符号表达式。您可以使用传递给 getOutputDimensions 的 IExprBuilder 从输入表达式构建表达式。在示例中,不必为案例 1 构建新表达式,因为第二个输出的维度与第一个输入的维度相同。
DimsExprs BarPlugin::getOutputDimensions(int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) { switch (outputIndex) { case 0: { // First dimension of output is sum of input // first dimensions. DimsExprs output(inputs[0]); output.d[0] = exprBuilder.operation(DimensionOperation::kSUM, inputs[0].d[0], inputs[1].d[0]); return output; } case 1: return inputs[0]; default: throw std::invalid_argument(“invalid output”); }
supportsFormatCombination 必须指示是否允许格式组合。该接口将输入/输出统一索引为“连接”,第一个输入从 0 开始,然后按顺序输入其余部分,然后对输出进行编号。在示例中,输入是连接 0 和 1,输出是连接 2 和 3。
TensorRT 使用 supportsFormatCombination 询问给定的格式/类型组合是否适合连接,给定的格式/类型用于索引较少的连接。因此可以假设较小的索引连接已经被审查并专注于与索引 pos 的连接。
bool BarPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override { assert(0 <= pos && pos < 4); const auto* in = inOut; const auto* out = inOut + nbInputs; switch (pos) { case 0: return in[0].format == TensorFormat::kLINEAR; case 1: return in[1].type == in[0].type && in[1].format == TensorFormat::kLINEAR; case 2: return out[0].type == in[0].type && out[0].format == TensorFormat::kLINEAR; case 3: return out[1].type == in[0].type && out[1].format == TensorFormat::kLINEAR; } throw std::invalid_argument(“invalid connection number”); }
此处的局部变量 in 和 out 允许通过输入或输出编号而不是连接编号检查 inOut。
重要:检查索引小于 pos 的连接的格式/类型,但绝不能检查索引大于 pos 的连接的格式/类型。该示例使用案例 3 根据连接 0 检查连接 3,而不使用案例 0 根据连接 3 检查连接 0。
TensorRT 使用 configurePlugin 在运行时设置插件。 这个插件不需要 configurePlugin 做任何事情,所以它是一个空操作:
void BarPlugin::configurePlugin( const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) override { }
如果插件需要知道它可能遇到的最小或最大维度,它可以检查字段 DynamicPluginTensorDesc::min 或 DynamicPluginTensorDesc::max 以获取任何输入或输出。格式和构建时维度信息可以在 DynamicPluginTensorDesc::desc 中找到。 任何运行时维度都显示为 -1。 实际尺寸提供给 BarPlugin::enqueue。
最后,覆盖 BarPlugin::enqueue 必须完成这项工作。 由于形状是动态的,因此 enqueue 会收到一个 PluginTensorDesc,它描述了每个输入和输出的实际维度、类型和格式。
9.1.2 示例:使用 C++ 添加具有 INT8 I/O 支持的自定义层
PoolPlugin 是一个插件,用于演示如何为自定义池化层扩展 INT8 I/O。
class PoolPlugin : public IPluginV2IOExt { ...override virtual methods inherited from IPluginV2IOExt. };
大部分纯虚方法都是插件通用的。 影响INT8 I/O的主要方法有:
- supportsFormatCombination
- configurePlugin
- enqueue
supportsFormatCombination 的覆盖必须指示允许哪种 INT8 I/O 组合。这个方法的使用有点类似上述动态形状的方法。在此示例中,支持的 I/O 张量格式为具有 FP32、FP16 或 INT8 数据类型的线性 CHW,但 I/O 张量必须具有相同的数据类型。
bool PoolPlugin::supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override { assert(nbInputs == 1 && nbOutputs == 1 && pos < nbInputs + nbOutputs); bool condition = inOut[pos].format == TensorFormat::kLINEAR; condition &= ((inOut[pos].type == DataType::kFLOAT) || (inOut[pos].type == DataType::kHALF) || (inOut[pos].type == DataType::kINT8)); condition &= inOut[pos].type == inOut[0].type; return condition; }
重要:
- 如果 INT8 校准必须与带有 INT8 I/O 插件的网络一起使用,则该插件必须支持 FP32 I/O,因为 TensorRT 使用 FP32 来校准图。
- 如果不支持 FP32 I/O 变体或未使用 INT8 校准,则必须明确设置所有必需的 INT8 I/O 张量标度。
- 校准无法确定插件内部张量的动态范围。 对量化数据进行操作的插件必须为内部张量计算自己的动态范围。
TensorRT调用configurePlugin方法通过PluginTensorDesc将信息传递给插件,这些信息作为成员变量存储,序列化和反序列化。
void PoolPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) { ... mPoolingParams.mC = mInputDims.d[0]; mPoolingParams.mH = mInputDims.d[1]; mPoolingParams.mW = mInputDims.d[2]; mPoolingParams.mP = mOutputDims.d[1]; mPoolingParams.mQ = mOutputDims.d[2]; mInHostScale = in[0].scale >= 0.0F ? in[0].scale : -1.0F; mOutHostScale = out[0].scale >= 0.0F ? out[0].scale : -1.0F; }
每个张量的 INT8 I/O 比例可以从 PluginTensorDesc::scale 获得。
最后, UffPoolPluginV2::enqueue 必须完成这项工作。 它包括一组核心算法,可在运行时使用实际批量大小、输入、输出、cuDNN 流和配置的信息来执行自定义层。
int PoolPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) { ... CHECK(cudnnPoolingForward(mCudnn, mPoolingDesc, &kONE, mSrcDescriptor, input, &kZERO, mDstDescriptor, output)); ... return 0; }
9.2 使用 Python API 添加自定义层
尽管 C++ API 是实现自定义层的首选语言,但由于可以访问 CUDA 和 cuDNN 等库,您还可以在 Python 应用程序中使用自定义层。
您可以使用 C++ API 创建自定义层,在 Python 中使用 pybind11 打包层,然后将插件加载到 Python 应用程序中。更多的信息可以参考Creating a Network Definition in Python(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#network_python
).
相同的自定义层实现可用于 C++ 和 Python。
9.2.1 示例:使用 Python 将自定义层添加到 TensorRT 网络
可以使用插件节点将自定义层添加到 Python 中的任何 TensorRT 网络。
Python API 通过 add_plugin_v2 函数,将插件节点添加到网络中。下面的例子说明了这一点。 它创建了一个简单的 TensorRT 网络,并通过查找 TensorRT 插件注册表添加了一个 leaky ReLU 插件节点。
import tensorrt as trt import numpy as np TRT_LOGGER = trt.Logger() trt.init_libnvinfer_plugins(TRT_LOGGER, '') PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list def get_trt_plugin(plugin_name): plugin = None for plugin_creator in PLUGIN_CREATORS: if plugin_creator.name == plugin_name: lrelu_slope_field = trt.PluginField("neg_slope", np.array([0.1], dtype=np.float32), trt.PluginFieldType.FLOAT32) field_collection = trt.PluginFieldCollection([lrelu_slope_field]) plugin = plugin_creator.create_plugin(name=plugin_name, field_collection=field_collection) return plugin def main(): builder = trt.Builder(TRT_LOGGER) network = builder.create_network() config = builder.create_builder_config() config.max_workspace_size = 2**20 input_layer = network.add_input(name="input_layer", dtype=trt.float32, shape=(1, 1)) lrelu = network.add_plugin_v2(inputs=[input_layer], plugin=get_trt_plugin("LReLU_TRT")) lrelu.get_output(0).name = "outputs" network.mark_output(lrelu.get_output(0))
9.3 使用解析器导入模型时使用自定义层
ONNX 解析器会自动尝试将无法识别的节点导入为插件。 如果在插件注册表中找到与节点具有相同 op_type 的插件,则解析器将节点的属性作为插件字段参数转发给插件创建者,以便创建插件。
默认情况下,解析器使用“1”作为插件版本,“”作为插件命名空间。 可以通过在相应的 ONNX 节点中设置 plugin_version 和 plugin_namespace 字符串属性来覆盖此行为。
在某些情况下,您可能希望在将 ONNX 图导入 TensorRT 之前对其进行修改。 例如,用插件节点替换一组操作。为此,您可以使用 ONNX GraphSurgeon 工具。有关如何使用 ONNX-GraphSurgeon 替换子图的详细信息可以参考this example(https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon/examples/08_replacing_a_subgraph
).
9.4 插件接口说明
所有新插件都应从 IPluginCreator 和9.1 使用C++API增加自定义层中描述的插件基类之一派生。此外,新的插件还应该调用REGISTER_TENSORRT_PLUGIN(...) 宏向TensorRT 插件注册表注册插件或创建一个相当于initLibNvInferPlugins() 的init 函数。
9.4.1 将插件从 TensorRT 6.x 或 7.x 迁移到 TensorRT 8.x.x
仍然支持 IPluginV2 和 IPluginV2Ext 以分别向后兼容 TensorRT 5.1 和 6.0.x。 但是,新插件应以 IPluginV2DynamicExt 或 IPluginV2IOExt 接口为目标,而旧插件应使用这些接口重构。
IPluginV2DynamicExt中的新特性如下:
virtual DimsExprs getOutputDimensions(int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) = 0; virtual bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) = 0; virtual void configurePlugin(const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) = 0; virtual size_t getWorkspaceSize(const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const = 0; virtual int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) = 0;
IPluginV2IOExt的新特性如下:
virtual void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) = 0; virtual bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const = 0;
迁移到 IPluginV2DynamicExt 或 IPluginV2IOExt 的指南:
- getOutputDimensions 在给定输入的情况下实现输出张量维度的表达式。
- supportsFormatCombination 检查插件是否支持指定 I/O 的格式和数据类型。
- configurePlugin 模仿 IPluginV2Ext 中等效 configurePlugin 的行为,但接受张量描述符。
- getWorkspaceSize 和 enqueue 模仿 IPluginV2Ext 中等效 API 的行为,但接受张量描述符。
参考以下说明获取更多详细信息。
9.4.2 IPluginV2 API 说明
以下部分描述了 IPluginV2 类的功能。 要将插件层连接到相邻层并设置输入和输出数据结构,构建器会通过调用以下插件方法来检查输出的数量及其维度。
- getNbOutputs 用于指定输出张量的数量。
- getOutputDimensions 用于将输出的维度指定为输入维度的函数。
- supportsFormat 用于检查插件是否支持给定的数据格式。
- getOutputDataType 用于获取给定索引处输出的数据类型。 返回的数据类型必须具有插件支持的格式。
插件层可以支持以下数据格式:
- 「LINEAR」 单精度 (FP32)、半精度 (FP16)、整数 (INT8) 和整数 (INT32) 张量
- 「CHW32」 单精度 (FP32) 和整数 (INT8) 张量
- 「CHW2,HWC8,HWC16, DHWC8」 半精度 (FP16) 张量
- 「CHW4」 半精度 (FP16) 和整数 (INT8) 张量
格式由 PluginFormatType 计算。
不就地计算所有数据并且除了输入和输出张量之外还需要内存空间的插件可以使用 getWorkspaceSize 方法指定额外的内存要求,构建器调用该方法来确定和预分配暂存空间。
在构建和推理期间,插件层被配置和执行,可能多次。 在构建时,为了发现最佳配置,层被配置、初始化、执行和终止。在为插件选择最佳格式后,再次配置插件,然后初始化一次并在推理应用程序的生命周期内根据需要执行多次,最后在引擎销毁时终止。 这些步骤由构建器和引擎使用以下插件方法控制:
- 「configurePlugin」 传达输入和输出的数量、维度和所有输入和输出的数据类型、所有输入和输出的广播信息、选择的插件格式和最大批量大小。此时,插件设置其内部状态并为给定配置选择最合适的算法和数据结构。此 API 中不允许资源分配,因为它会导致资源泄漏。
- 「initialize」 此时配置已知,正在创建推理引擎,因此插件可以设置其内部数据结构并准备执行。
- 「enqueue」 封装插件的实际算法和内核调用,并提供运行时批处理大小、指向输入、输出和暂存空间的指针,以及用于内核执行的 CUDA 流。
- 「terminate」 引擎上下文被销毁,必须释放插件持有的所有资源。
- 「clone」 每次创建包含此插件层的新构建器、网络或引擎时都会调用此方法。 它必须返回一个具有正确参数的新插件对象。
- 「destroy」 用于销毁每次创建新插件对象时分配的插件对象和其他内存。 每当构建器或网络或引擎被销毁时都会调用它。
- 「set/getPluginNamespace」 该方法用于设置该插件对象所属的库命名空间(默认可以为"")。 同一个插件库中的所有插件对象都应该有相同的命名空间。
IPluginV2Ext 支持可以处理广播输入和输出的插件。 必须为此功能实现以下方法。
- 「canBroadcastInputAcrossBatch」 为每个输入调用此方法,其张量在批次中进行语义广播。 如果它返回 true(意味着插件可以支持广播),TensorRT 不会复制输入张量。 插件应在批处理中共享一个副本。 如果它返回 false,TensorRT 会复制输入张量,使其看起来像一个非广播张量。
- 「isOutputBroadcastAcrossBatch」 这是为每个输出索引调用的。 插件应该返回 true 给定索引处的输出,并在整个批次中广播。
- 「IPluginV2IOExt」 这是由构建器在 initialize() 之前调用的。 它为该层提供了根据 I/O PluginTensorDesc 和最大批量大小进行算法选择的机会。
注意:基于 IPluginV2 的插件是在引擎级别共享的,而不是执行上下文级别,因此这种可能被多个线程同时使用的插件必须以线程安全的方式管理它们的资源。 基于 IPluginV2Ext 和派生接口的插件在创建 ExecutionContext 时被克隆,因此不是必需的。
9.4.3 IPluginCreator API 说明
IPluginCreator 类中的以下方法用于从插件注册表中查找和创建适当的插件:
- 「getPluginName」 这将返回插件名称,并且应该与 IPluginExt::getPluginType 的返回值相匹配。
- 「getPluginVersion」 返回插件版本。 对于所有内部 TensorRT 插件,这默认为 1。
- 「getFieldNames」 要成功创建插件,需要了解插件的所有字段参数。 此方法返回 PluginFieldCollection 结构,其中填充了 PluginField 条目以反映字段名称和 PluginFieldType(数据应指向 nullptr)。
- 「createPlugin」 此方法用于使用 PluginFieldCollection 参数创建插件。 PluginField 条目的数据字段应该被填充以指向每个插件字段条目的实际数据。传递给 createPlugin 函数的数据应该由调用者分配,并最终在程序销毁时由调用者释放。它返回的插件对象的所有权被传递给调用者并且也必须被销毁。
- 「deserializePlugin」 该方法由 TensorRT 引擎根据插件名称和版本在内部调用。 它应该返回用于推理的插件对象。该函数中创建的插件对象在引擎销毁时被TensorRT引擎销毁。
- 「set/getPluginNamespace」 该方法用于设置本创建者实例所属的命名空间(默认可以是"")
9.5 自定义图层插件的最佳实践
9.5.1 插件编码指南
1、内存分配
必须释放插件中分配的内存,以确保不会发生内存泄漏。 如果在initialize() 函数中获取资源,则必须在terminate() 函数中释放它们。最好是在插件类析构函数或 destroy() 方法中释放所有其他内存分配。9.1 使用C++API增加自定义层对此进行了详细概述,并提供了一些使用插件时的最佳实践说明。
2、添加检查以确保配置正确并验证输入
意外插件行为的一个常见来源是不正确的配置(例如,无效的插件属性)和无效的输入。因此,对于预计插件无法工作的情况,在初始插件开发期间添加检查/断言是一种很好的做法。以下是可以添加检查的地方:
- createPlugin 插件属性检查
- configurePlugin 输入维度检查
- enqueue 输入值检查
3、对于创建新插件对象的方法,在出错时返回 Null
createPlugin、clone 和 deserializePlugin 应该创建并返回新的插件对象。在这些方法中,确保在出现任何错误或检查失败时返回空对象(C++ 中的 nullptr)。这可确保在插件配置不正确时不会返回非空插件对象。
4、避免在 clone() 中分配设备内存
由于在构建器中多次调用克隆,因此设备内存分配可能非常昂贵。一个好的做法是在初始化时进行持久内存分配,当插件准备好使用时(例如,在 configurePlugin 中)复制到设备,并在终止时释放。
9.5.2 在隐式/显式批处理网络中使用插件
TensorRT 允许以隐式批处理模式或显式批处理模式创建网络(参考Explicit Versus Implicit Batch)。记住以下有关隐式/显式批处理模式网络中的插件行为的信息很有用:
- 实现 IPluginV2DynamicExt 的插件只能添加到以显式批处理模式配置的网络。
- 可以将非 IPluginV2DynamicExt 插件添加到以隐式或显式批处理模式配置的网络。
重要:尽管非 IPluginV2DynamicExt 插件与显式批处理模式网络兼容,但它们的实现必须独立于预期使用的网络类型(隐式/显式批处理模式)。 因此,在显式批处理模式网络中使用此类插件时:
第一个输入的前导维度(在传递给插件之前)被推断为批次维度。
TensorRT 在将输入传递给插件之前弹出上面标识的第一个维度,并将其推到插件发出的任何输出的前面。这意味着不得在 getOutputDimensions 中指定批次维度。
9.5.3 将形状张量传递给插件
TensorRT插件API不支持形状张量直接输入插件,也不支持直接输出。但是,可以使用空张量解决此限制。 使用具有感兴趣维度和零维度的虚拟输入张量,以便输入几乎不占用空间。
例如,假设一个插件必须知道一个 2 元素 1D 形状张量值 [P,Q] 来计算其输出的形状,例如,以实现 IPluginV2DynamicExt::getOutputDimensions。不是传递形状张量 [P,Q],而是将插件设计为具有虚拟输入,该虚拟输入是维度为 [0,P,Q] 的执行张量。TensorRT 将告诉插件虚拟输入的维度,插件可以从中提取 [P,Q]。 因为张量是空的,它会占用很小的空间,刚好足以给它一个不同的地址。
在网络中,通过使用零步幅切片或重塑空张量来创建虚拟输入张量。 这是使用零步幅切片的机制。
// Shape tensor of interest. Assume it has the value [P,Q]. ITensor* pq = ...; // Create an empty-tensor constant with dimensions [0,1,1]. // Since it's empty, the type doesn't matter, but let's assume float. ITensor* c011 = network.addConstant({3, {0, 1, 1}}, {DataType::kFLOAT, nullptr, 0})->getOutput(0); // Create shape tensor that has the value [0,P,Q] static int32_t const intZero = 0; ITensor* z = network.addConstant({1, {1}}, {DataType::kINT32, &intZero, 1})->getOutput(0); ITensor* concatInputs[] = {z, pq}; IConcatenationLayer* zpq = network.addConcatenation(concatInputs, 2); zpq->setAxis(0); // Create zero-stride slice with output size [0,P,Q] Dims z3{3, {0, 0, 0}}; ISliceLayer* slice = network.addSlice(*c011, z3, z3, z3); slice->setInput(2, *zpq->getOutput(0));
使用 slice->getOutput(0) 作为插件的虚拟输入。
如果使用 IShuffleLayer 创建空张量,请务必关闭 reshape 维度中对零的特殊解释,即务必调用 setZeroIsPlaceholder(false)。