开发者社区> 问答> 正文

为什么numpy的einsum比numpy的内置函数快?

让我们以的三个数组开始dtype=np.double。使用numpy 1.7.1编译icc并链接到intel的numpy 1.7.1在intel CPU上执行计时mkl。一个AMD的CPU与编译numpy的1.6.1 gcc不mkl也被用来验证的时序。请注意,计时几乎与系统大小成线性比例,并且不是由于numpy函数if语句中的开销很小,这些差异将以微秒而非毫秒显示:

arr_1D=np.arange(500,dtype=np.double) large_arr_1D=np.arange(100000,dtype=np.double) arr_2D=np.arange(5002,dtype=np.double).reshape(500,500) arr_3D=np.arange(5003,dtype=np.double).reshape(500,500,500) 首先让我们看一下np.sum函数:

np.all(np.sum(arr_3D)==np.einsum('ijk->',arr_3D)) True

%timeit np.sum(arr_3D) 10 loops, best of 3: 142 ms per loop

%timeit np.einsum('ijk->', arr_3D) 10 loops, best of 3: 70.2 ms per loop 权力:

np.allclose(arr_3Darr_3Darr_3D,np.einsum('ijk,ijk,ijk->ijk',arr_3D,arr_3D,arr_3D)) True

%timeit arr_3Darr_3Darr_3D 1 loops, best of 3: 1.32 s per loop

%timeit np.einsum('ijk,ijk,ijk->ijk', arr_3D, arr_3D, arr_3D) 1 loops, best of 3: 694 ms per loop 外部产品:

np.all(np.outer(arr_1D,arr_1D)==np.einsum('i,k->ik',arr_1D,arr_1D)) True

%timeit np.outer(arr_1D, arr_1D) 1000 loops, best of 3: 411 us per loop

%timeit np.einsum('i,k->ik', arr_1D, arr_1D) 1000 loops, best of 3: 245 us per loop 以上所有的速度是的两倍np.einsum。这些应该是苹果与苹果的比较,因为一切都是专门的dtype=np.double。我希望这样的操作会加快速度:

np.allclose(np.sum(arr_2D*arr_3D),np.einsum('ij,oij->',arr_2D,arr_3D)) True

%timeit np.sum(arr_2D*arr_3D) 1 loops, best of 3: 813 ms per loop

%timeit np.einsum('ij,oij->', arr_2D, arr_3D) 10 loops, best of 3: 85.1 ms per loop Einsum似乎是至少两倍快np.inner,np.outer,np.kron和,np.sum不管axes选择。主要的例外是np.dot 从BLAS库调用DGEMM。那么,为什么np.einsum其他同等的numpy函数更快呢?

DGEMM案例的完整性:

np.allclose(np.dot(arr_2D,arr_2D),np.einsum('ij,jk',arr_2D,arr_2D)) True

%timeit np.einsum('ij,jk',arr_2D,arr_2D) 10 loops, best of 3: 56.1 ms per loop

%timeit np.dot(arr_2D,arr_2D) 100 loops, best of 3: 5.17 ms per loop 领先的理论来自@sebergs注释,它np.einsum可以利用SSE2,但是numpy的ufuncs直到numpy 1.8才会使用(请参阅更改日志)。我相信这是正确的答案,但无法确认。通过更改输入数组的dtype并观察速度差异以及并非每个人都观察到相同的时序趋势这一事实,可以找到一些有限的证明。 问题来源于stack overflow

展开
收起
保持可爱mmm 2020-02-08 21:55:38 682 0
1 条回答
写回答
取消 提交回答
  • 现在numpy 1.8已发布,根据文档,所有ufunc都应使用SSE2,我想再次检查一下Seberg关于SSE2的评论是否有效。

    为了执行测试,创建了一个新的python 2.7安装-在icc运行Ubuntu的AMD opteron内核上使用标准选项编译了numpy 1.7和1.8 。

    这是在1.8升级之前和之后进行的测试:

    import numpy as np import timeit

    arr_1D=np.arange(5000,dtype=np.double) arr_2D=np.arange(5002,dtype=np.double).reshape(500,500) arr_3D=np.arange(5003,dtype=np.double).reshape(500,500,500)

    print 'Summation test:' print timeit.timeit('np.sum(arr_3D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print timeit.timeit('np.einsum("ijk->", arr_3D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print '----------------------\n'

    print 'Power test:' print timeit.timeit('arr_3Darr_3Darr_3D', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print timeit.timeit('np.einsum("ijk,ijk,ijk->ijk", arr_3D, arr_3D, arr_3D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print '----------------------\n'

    print 'Outer test:' print timeit.timeit('np.outer(arr_1D, arr_1D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print timeit.timeit('np.einsum("i,k->ik", arr_1D, arr_1D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print '----------------------\n'

    print 'Einsum test:' print timeit.timeit('np.sum(arr_2D*arr_3D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print timeit.timeit('np.einsum("ij,oij->", arr_2D, arr_3D)', 'import numpy as np; from main import arr_1D, arr_2D, arr_3D', number=5)/5 print '----------------------\n' Numpy 1.7.1:

    Summation test: 0.172988510132 0.0934836149216

    Power test: 1.93524689674 0.839519000053

    Outer test: 0.130380821228 0.121401786804

    Einsum test: 0.979052495956 0.126066613197 numpy 1.8:

    Summation test: 0.116551589966 0.0920487880707

    Power test: 1.23683619499 0.815982818604

    Outer test: 0.131808176041 0.127472200394

    Einsum test: 0.781750011444 0.129271841049 我认为这是相当确定的,因为SSE在时序差异中起着很大的作用,应该注意的是,重复这些测试仅会使时序接近0.003s。其余的差异应包含在该问题的其他答案中。

    2020-02-08 21:55:58
    赞同 展开评论 打赏
问答分类:
问答标签:
问答地址:
问答排行榜
最热
最新

相关电子书

更多
低代码开发师(初级)实战教程 立即下载
冬季实战营第三期:MySQL数据库进阶实战 立即下载
阿里巴巴DevOps 最佳实践手册 立即下载