爱因斯坦求和约定 含代码

简介: 爱因斯坦求和约定 含代码

一、简介

       爱因斯坦求和约定(Einstein summation convention)是一种标记的约定, 又称为爱因斯坦标记法(Einstein notation), 可以基于一些约定简写格式表示多维线性代数数组操作,让表达式更加简洁明了,比如通过省略求和符号

       我们先来看两个概念,自由标和哑标:

1.自由标

       自由标是在表达式的两边都出现,并且不遵循求和约定的指标。自由标用于指示表达式结果中保留的维度。在爱因斯坦求和约定中,自由标的顺序决定了结果张量的维度顺序。下图中i是自由标、j是哑标。


     爱因斯坦和表示为

2.哑标

       哑标是在表达式的同一边出现两次的指标,遵循求和约定,即对这个指标进行求和。它在张量运算中仅仅是起到一个辅助的作用,并不影响最终结果的形状,相同的哑标表示对应位置进行求和。下图中的i和j都是哑标。



       总的来说,自由标用于表示张量运算的结果的维度,而哑标则是进行求和操作时的辅助指标

二、torch实现

       Einsum在torch、tf和numpy中都有实现,而且用方式差不多,这里我们以torch为例,使用torch.einsum方法。

   总体思想是用一些下标标记输入的每个维度,并定义哪些下标是输出的一部分。然后,通过将操作中下标不属于输出的维度的元素先乘积再求和来计算输出。下面是一些例子,还是很好理解的。

   值得注意的是torch.einsum会自动调整张量的乘法顺序以匹配所需的乘法操作,并且会自动处理张量的维度匹配。因此,无论参数的顺序如何,都会得到相同的结果

1.计算迹

torch.einsum('ii', torch.randn(4, 4))
# tensor(-1.2104)

       ii表示对第一个维度和第二个维度取相同的索引值,并对所有这些索引值的元素进行求和。在一个方阵中,就是对其对角线上的元素求和。没有显式的输出就是先求和再输出。

2.取矩阵对角线

torch.einsum('ii->i', torch.randn(4, 4))
# tensor([-0.1034,  0.7952, -0.2433,  0.4545])

       ii表示对张量的第一个维度和第二个维度取相同的索引值,并对所有这些索引值的元素进行操作。而 ->i 表示我们希望得到的输出张量的形状是一个一维张量,其中包含对每个索引值进行操作后的结果。

3.计算外积

x = torch.randn(5)
y = torch.randn(4)
torch.einsum('i,j->ij', x, y)
# tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
#         [-0.3744,  0.9381,  1.2685, -1.6070],
#         [ 0.7208, -1.8058, -2.4419,  3.0936],
#         [ 0.1713, -0.4291, -0.5802,  0.7350],
#         [ 0.5704, -1.4290, -1.9323,  2.4480]])

       i 和 j 表示两个张量 x 和 y 的维度。标记中的箭头 ->ij 表示我们希望得到的输出张量的形状是一个二维张量,其中第一个维度的大小与 x 张量的大小相同,第二个维度的大小与 y 张量的大小相同。


       i,j->ij表示对两个张量 x 和 y 进行乘法操作,并返回一个形状为 (len(x), len(y)) 的二维张量,其中的每个元素是两个输入张量对应位置元素的乘积,即外积。

4.batch矩阵乘法

As = torch.randn(3,2,5)
Bs = torch.randn(3,5,4)
torch.einsum('bij,bjk->bik', As, Bs)
# tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
#          [-1.6706, -0.8097, -0.8025, -2.1183]],
# 
#         [[ 4.2239,  0.3107, -0.5756, -0.2354],
#          [-1.4558, -0.3460,  1.5087, -0.8530]],
# 
#         [[ 2.8153,  1.8787, -4.3839, -1.2112],
#          [ 0.3728, -2.1131,  0.0921,  0.8305]]])

       bij 和 bjk 分别代表两个输入张量 As 和 Bs 的维度。bik 表示我们希望得到的输出张量的形状。


       bij,bjk->bik表示对两个张量 As 和 Bs 进行乘法操作,并返回一个形状为 (b, i, k) 的张量,其中 b 是批量大小,i 是 As 张量的第二个维度大小,k 是 Bs 张量的第三个维度大小。


   一行代码,将转置和乘法放在一起,确实很方便。

5.带有子列表和省略号

As = torch.randn(3,2,5)
Bs = torch.randn(3,5,4)
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
# tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
#          [-1.6706, -0.8097, -0.8025, -2.1183]],
# 
#         [[ 4.2239,  0.3107, -0.5756, -0.2354],
#          [-1.4558, -0.3460,  1.5087, -0.8530]],
# 
#         [[ 2.8153,  1.8787, -4.3839, -1.2112],
#          [ 0.3728, -2.1131,  0.0921,  0.8305]]])

       [..., 0, 1] 表示对 As 进行切片操作。省略号 ... 表示我们不关心其他的维度,而 [0, 1] 表示我们选择 As 张量的最后两个维度。


       [..., 1, 2] 表示对 Bs 进行切片操作。同样,省略号 ... 表示其他的维度不变,而 [1, 2] 表示我们选择 Bs 张量的最后两个维度。


       [..., 0, 2] 表示我们希望得到的输出张量的形状。同样,省略号 ... 表示其他的维度不变,而 [0, 2] 表示我们选择输出张量的倒数第二个维度和最后一个维度。


       所以下面的代码表示对输入张量 As 和 Bs 进行一系列切片操作,并对结果进行乘法和求和,最后返回一个输出张量,其形状与输入张量的形状相同,但最后两个维度的顺序交换了位置。

6.变换维度

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])

       ...ij表示输入张量 A 的维度。省略号 ... 表示可以匹配任意数量的维度,而 'ij' 表示张量中的最后两个维度。


       ->...ji表示我们希望得到的输出张量的形状,其维度与输入张量的维度相同,但是最后两个维度交换了位置。


       ...ij->...ji 表示对输入张量 A 进行转置操作。

7.双线性变换,类似于torch.nn.functional.bilinear

l = torch.randn(2,5)
A = torch.randn(3,5,4)
r = torch.randn(2,4)
torch.einsum('bn,anm,bm->ba', l, A, r)
# tensor([[-0.3430, -5.2405,  0.4494],
#         [ 0.3311,  5.5201, -3.0356]])

       这个比较复杂,计算步骤如下:


       a.bn和bm中的b表示l和r在这个维度上是相同的,所以会对这个维度进行广播操作,得到中间值:维度(2,5,4),即(b,n,m)


       b.A维度(a,n,m)中的n和m与中间值的n和m相对应,表示在这两个维度上进行乘法操作。


       c.对n和m维度上的结果进行求和,得到最终的输出张量,其形状由->ba指定,即(2, 3)。


       这个例子有点绕,在实际工作中也不会经常遇到,还是建议大家把逻辑写的可读性强一点,这样以后的你会感激现在的自己。


       爱因斯坦求和约定就介绍到这里,点个关注不迷路(#^.^#)!

相关文章
|
3月前
|
机器学习/深度学习 存储 人工智能
矩阵乘法运算:在这看似枯燥的数字组合中,究竟蕴含着怎样令人称奇的奥秘?
【8月更文挑战第19天】矩阵乘法不仅是数学概念,还在工程、图像处理及AI等领域发挥核心作用。例如,通过矩阵乘法可精确实现图像变换;在神经网络中,它帮助模型学习和优化以识别图像和理解语言。两个矩阵A(m×n)与B(n×p)相乘得C(m×p),其中C[i,j]为A的第i行与B的第j列元素乘积之和。尽管面临维度匹配等挑战,矩阵乘法仍在持续推动技术创新。下次享受智能服务时,不妨想想背后的矩阵乘法吧。
68 3
|
3月前
|
算法
聊聊一个面试中经常出现的算法题:组合运算及其实际应用例子
聊聊一个面试中经常出现的算法题:组合运算及其实际应用例子
|
5月前
高等数学II-知识点(1)——原函数的概念、不定积分、求原函数的两种常用方法 (凑微分法、第二换元法)、分部积分法、有理函数原函数求法、典型三角函数原函数求法
高等数学II-知识点(1)——原函数的概念、不定积分、求原函数的两种常用方法 (凑微分法、第二换元法)、分部积分法、有理函数原函数求法、典型三角函数原函数求法
104 1
|
Java
【附录】概率基本性质与法则的推导证明
本文从概率论三大公理出发,推导证明概率基本法则。
149 0
【附录】概率基本性质与法则的推导证明
|
算法
绪论以及递归式上界函数的证明
绪论以及递归式上界函数的证明
|
算法 索引 Python
从一道简单算法题里面解释什么叫做 O(1)
从一道简单算法题里面解释什么叫做 O(1)
116 0
|
机器学习/深度学习
划重点!通俗解释协方差与相关系数
划重点!通俗解释协方差与相关系数
540 1
划重点!通俗解释协方差与相关系数
|
存储 算法
算法~简单的计算器(验证数学表达式是否合法~“状态机思想”)
算法~简单的计算器(验证数学表达式是否合法~“状态机思想”)
329 0
算法~简单的计算器(验证数学表达式是否合法~“状态机思想”)
|
机器学习/深度学习 移动开发
【计算理论】可判定性 ( 对角线方法 | 证明自然数集 N 与实数集 R 不存在一一对应关系 )
【计算理论】可判定性 ( 对角线方法 | 证明自然数集 N 与实数集 R 不存在一一对应关系 )
341 0