前言
一个深度学习框架的初步实现为例,讨论如何在一个相对较大的项目中深入应用元编程,为系统优化提供更多的可能。
以下是本书的原文《C++模板元编程实战》,由李伟先生所著写。
一、循环执行的代码
有如下一个例子:给定一个无符号整数,求该整数所对应的二进制表示中 1 的个数。
同学们可以先想一想,如何在编译器实现这个代码,再看如下给出的示例
temolate <size_t input> constexpr size_t OnesCount = (input % 2) + OnesCount<(input / 2); template <> constexpr size_t OnesCount<0> = 0; constexpr size_t res = OnesCount<45>;
同学们可以想想是怎样实现的再看下方讲解!!!
递归展开过程
OnesCount<N> / \ (N % 2) OnesCount<(N / 2)> / \ / \ OnesCount<K> (K % 2) OnesCount<(K / 2)> / \ / \ OnesCount<J> ... / \ / \ OnesCount<2> (2 % 2) OnesCount<I> ... / ... (2 / 2) = 1 / \ OnesCount<1> OnesCount<0> ... / ... (1 % 2) (1 / 2) = 0
代码示例讲解
展示每次递归调用的结果:
OnesCount<45> = (45 % 2) + OnesCount<22> = 1 + OnesCount<22>
然后进一步展开 `OnesCount<22>`:
OnesCount<22> = (22 % 2) + OnesCount<11> = 0 + OnesCount<11>
继续展开 `OnesCount<11>`:
OnesCount<11> = (11 % 2) + OnesCount<5> = 1 + OnesCount<5>
然后展开 `OnesCount<5>`:
OnesCount<5> = (5 % 2) + OnesCount<2> = 1 + OnesCount<2>
继续展开 `OnesCount<2>`:
OnesCount<2> = (2 % 2) + OnesCount<1> = 0 + OnesCount<1>
最终展开到 `OnesCount<1>`:
OnesCount<1> = (1 % 2) + OnesCount<0> = 1 + OnesCount<0>
最后,`OnesCount<0>` 是递归的结束条件:
OnesCount<0> = 0
将上述结果依次代入原表达式可以得到最终结果:
OnesCount<1> = 1 + OnesCount<0> = 1 + 0 = 1 OnesCount<2> = 0 + OnesCount<1> = 0 + 1 = 1 OnesCount<5> = 1 + OnesCount<2> = 1 + 1 = 2 OnesCount<11> = 1 + OnesCount<5> = 1 + 2 = 3 OnesCount<22> = 0 + OnesCount<11> = 0 + 3 = 3 OnesCount<45> = 1 + OnesCount<22> = 1 + 3 = 4
所以,对于输入值 `45`,其二进制表示中包含 4 个位为 1。
1.1 数组处理
template <size_t...Inputs> constexpr size_t Accumulate = 0; template <size_t CurInput, size_t...Inputs> constexpr size_t Accumulate<CurIput, Inputs...> = CurInput + Accumulate<Inputs...>; constexpr size_t res = Accumulate<1, 2, 3, 4, 5>;
代码讲解
上述代码是一个使用可变模板参数和递归调用的示例。这个代码片段展示了如何使用模板元编程的方式计算一系列整数的累加和。
首先,我们定义了一个模板 `Accumulate`,它是一个递归模板,并设置基本情况的初始值为 0。这个模板接受一个可变数量的模板参数 `Inputs`。
template <size_t... Inputs> constexpr size_t Accumulate = 0;
然后,我们定义了另一个模板部分特化,用于递归展开 `Accumulate`。这个模板的第一个模板参数 `CurInput` 是当前要累加的值,后面的模板参数 `Inputs` 是剩余的参数序列。
template <size_t CurInput, size_t... Inputs> constexpr size_t Accumulate<CurInput, Inputs...> = CurInput + Accumulate<Inputs...>;
我们使用递归调用来展开参数序列 `Inputs`,每次递归调用将会取出序列中的第一个值 `CurInput`,并与累加值进行相加。递归展开将会一直进行到参数序列为空的情况,即达到了模板的基本情况 `Accumulate = 0`。
最后,我们使用具体的数值调用 `Accumulate` 模板来计算结果。在这个例子中,我们使用参数序列 `{1, 2, 3, 4, 5}` 来调用模板。
constexpr size_t res = Accumulate<1, 2, 3, 4, 5>;
计算过程
Accumulate<1, 2, 3, 4, 5> = 1 + Accumulate<2, 3, 4, 5> = 1 + 2 + Accumulate<3, 4, 5> = 1 + 2 + 3 + Accumulate<4, 5> = 1 + 2 + 3 + 4 + Accumulate<5> = 1 + 2 + 3 + 4 + 5 + Accumulate<> = 1 + 2 + 3 + 4 + 5 + 0 = 15
所以,根据给定的参数序列,`Accumulate<1, 2, 3, 4, 5>` 的结果是 `15`。
1.2 C++17 fold expression简便写法
折叠表达式提供了一种更简洁的方法来实现对参数序列的累加操作。
template <size_t... values> constexpr size_t fun() { return (0 + ... + values); } constexpr size_t res = fun<1, 2, 3, 4, 5>();
折叠表达式提供了一种更简洁的方法来实现对参数序列的累加操作。
0 + 1 + 2 + 3 + 4 + 5 = 15
1.3 C++17 fold expression 介绍
当我们使用可变参数模板时,`fold expression` 提供了一种更简洁的语法来对参数序列执行各种操作,比如求和、求积、逻辑与/或等。
折叠表达式(fold expression)的一般语法形式如下:
(操作符 ... op) // 从左至右展开 (op ... 操作符) // 从右至左展开
其中,`操作符` 是要执行的操作,可以是二元操作符,也可以是逗号表达式,而 `op` 是要折叠的参数序列。
下面是一些常见的折叠表达式用法:
取和操作:
template <typename... Ts> bool all(Ts... args) { return (args && ...); // 对逻辑与操作符进行折叠 } template <typename... Ts> bool any(Ts... args) { return (args || ...); // 对逻辑或操作符进行折叠 } template <typename... Ts> auto sum(Ts... args) { return (args + ...); // 对加法操作符进行折叠 }
取乘积操作:
template <typename... Ts> auto multiply(Ts... args) { return (args * ...); // 对乘法操作符进行折叠 }
字符串拼接:
template <typename... Ts> std::string concatenate(Ts... args) { return (std::string("") + ... + args); // 对字符串拼接进行折叠 }
在折叠表达式中,操作符将会在参数序列中的每个参数之间进行运算,一直折叠到最终生成一个值。展开的顺序可以是从左到右或从右到左,具体取决于折叠表达式的写法。
例如,对于折叠表达式 `(args && ...)`,它将会计算 `args` 参数序列中所有参数的逻辑与操作;而对于折叠表达式 `(args + ...) + init`,它将会从左到右依次累加 `args` 参数序列,最后再加上初始值 `init`。
二、小心:实例化爆炸与编译崩溃
书中原图如下
代码示例
template <size_t A> struct Wrap_ { template <size_t ID, typename TDummy = void> struct imp { constexpr static size_t value = ID + imp<ID - 1>:;value }; template <typename TDummy> struct imp<0, TDummy> { constexpr static size_t value = 0; }; template <size_t ID> constexpr static size_t value = imp<A + ID>::value; }; int main() { std::cerr << Wrap_<3>::value<2> << std::endl; std::cerr << Wrap_<10>::value<2> << std::endl; }
这段代码定义了一个模板类Wrap_,其中模板参数A表示要进行求和的起始值。Wrap_内部定义了一个内嵌的模板结构imp,用于执行求和操作。
imp模板结构有两个模板参数:ID表示当前要求和的值,TDummy是一个占位类型参数。imp内部有一个静态成员变量value,表示求和结果。
第一个部分是递归定义的imp结构模板,当ID不为0时,使用递归计算ID + imp<ID - 1>::value作为当前的value值。
第二个部分是递归的终止条件,当ID为0时,value被定义为0。
在Wrap_模板类内部,还定义了一个模板结构变量value,它是通过使用imp模板结构来获得计算结果的简便方式。使用模板参数A与ID的和作为imp的模板参数。
在main函数中,通过调用Wrap_类模板,并指定A的值为3和10,以及ID的值为2,分别输出了对应的求和结果。输出结果分别为 3 + 2 = 5 和 10 + 2 = 12。
2.1 代码问题
看如下原图解释
2.2 名字空间污染?
template <size_t ID> struct imp { constexpr static size_t value = ID + imp<ID - 1>::value; }; template <> struct imp<0> { constexpr static size_t value = 0; }; template <size_t A> struct Wrap_ { template <size_t ID> constexpr static size_t value = imp<A + ID>::value; };
在后面的实现中,特化了imp模板结构体,并将其定义为一个完全特化版本,这个特化版本会对相同的命名空间造成污染。
在这种情况下,如果尝试在相同的命名空间中引入另一个名为imp的构造,将会发生名称冲突,导致编译错误。这是因为已经存在一个完全特化的imp模板结构体,编译器无法区分它们。
三、分支选择与短路逻辑
以下内容请结合书中原文一起看
修改后的代码
template <bool cur, typename TNext> constexpr static bool AndValue = false; template <typename TNext> constexpr static bool AndValue<true, TNext> = TNext::value; template <size_t N> struct AllOdd_ { constexpr static bool is_cur_odd = is_odd<N>; constexpr static bool value = AndValue<is_cur_odd, AllOdd_N - 1>>; };
代码讲解
上面的代码展示了一个模板元编程的示例,用于检查给定范围内的所有整数是否都为奇数。让我们一步步详细讲解这些代码的含义和作用。
首先,我们定义了一个`AndValue`模板变量模板,该模板接受两个模板参数:`cur`表示当前条件的值,`TNext`表示下一个条件的类型。它初始化为`false`,表示默认情况下条件不匹配。
template <bool cur, typename TNext> constexpr static bool AndValue = false;
然后,我们对`AndValue`进行了部分特化,当`cur`为`true`时,它的值将由`TNext::value`决定。
template <typename TNext> constexpr static bool AndValue<true, TNext> = TNext::value;
接下来,我们定义了一个`AllOdd_`结构体模板,该模板接受一个非负整数`N`作为模板参数。它包含两个静态成员变量:`is_cur_odd`表示当前值`N`是否为奇数,`value`表示在给定范围内所有整数是否都为奇数。
template <size_t N> struct AllOdd_ { constexpr static bool is_cur_odd = is_odd<N>; constexpr static bool value = AndValue<is_cur_odd, AllOdd_<N - 1>>; };
在这个结构体中,我们使用了一个递归的方式来检查给定范围内的所有整数是否都为奇数。我们使用`is_odd<N>`来判断当前值`N`是否为奇数,并将结果赋值给`is_cur_odd`。
然后,我们使用`AndValue<is_cur_odd, AllOdd_<N - 1>>`来检查当前值是否为奇数,并将结果与范围内其他值的奇偶性结果进行逻辑与操作。递归的终止条件是`N`为0,此时我们认为0是一个奇数。通过不断减小`N`的值,我们最终检查了给定范围内的所有整数。
需要注意的是,上面的代码中没有提供完整的定义`is_odd`和`AllOdd_N`,这些是必要的辅助函数和模板参数,用于判断一个整数是否为奇数并限定给定范围。
总结
前面所讲的内容,例如输入一个类型,返回相应的指针类型的元函数是最低级的;在此之上则是包含了之前所讲的顺序、分支与现在所讲的循环逻辑的元函数,如果你感觉都掌握了,后面将开始讲解更高级的元编程方式,奇特的递归模板式就是其中之一。如果对前面还没有完全掌握,除了结合原文读文章外还可以先去掌握基础的模板以及动手写代码示例。