Strassen矩阵乘法问题(Java)
1、前置介绍
矩阵乘法是线性代数中最常见的问题之一 ,它在数值计算中有广泛的应用。 设 A 和 B 是2个nXn矩阵,
它们的乘积 AB 同样是一个nXn矩阵。 A 和 B 的乘积矩阵 C 中元素C[i][j]定义为:
C[i][j] = \sum_{k=1}^{n}A[i][k]B[k][j]
采用传统方法,时间复杂度为:O(n3)
因为按照上述的定义来计算A和 B的乘积矩阵c,则每计算C的一个元素C[i][j], 需要做n次乘法运算和n-1次加法运算
。 因此,得到矩阵C的n2 个元素所需的计算时间为 O(n3) 。
为解决计算计算效率问题,Strassen算法由此出现,该算法基本思想是 分治
,将计算2个n阶矩阵乘积所需的计算时间改进到0(nlog7) = 0(n2.81)
我们知道,
C11=A11>*B11+A12*B21
矩阵A和B的示意图如下:
传统方法:
2个n阶方阵的乘积转换为8个n/2 阶方阵的乘积和4个n/2阶方阵的加法。
由此可得:
C11 = A11B11 + A12B21
C12 = A11B12 + A12B22
C21 = A21B11 + A22B21
C22 = A21B12 + A22B22
分治法:
为了降低时间复杂度,必须减少乘法的次数。
使用与上例类似的技术,将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵。由此可将方程C=AB重写为:
2个n阶方阵的乘积转换为7个n/2 阶方阵的乘积和18个n/2阶方阵的加减法。
伪代码如下:
// 递归维度分半算法:publicvoidSTRASSEN(n,A,B,C); { ifn=2thenMATRIX-MULTIPLY(A,B,C) //结束循环,计算两个2阶方阵的乘法else{ 将矩阵A和B分块; STRASSEN(n/2,A11,B12-B22,M1); STRASSEN(n/2,A11+A12,B22,M2); STRASSEN(n/2,A21+A22,B11,M3); STRASSEN(n/2,A22,B21-B11,M4); STRASSEN(n/2,A11+A22,B11+B22,M5); STRASSEN(n/2,A12-A22,B21+B22,M6); STRASSEN(n/2,A11-A21,B11+B12,M7);} } ```算法导论伪代码:![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/cef0ab294b824ecca6731c34edbcfa45.jpeg#pic_center)##3、代码实现```javapublicclassStrassenMatrixMultiply{ publicstaticvoidmain(String[] args) { int[] a=newint[] { 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4 }; int[] b=newint[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }; intlength=4; int[] c=sMM(a, b, length); for(inti=0; i<c.length; i++) { System.out.print(c[i] +" "); if((i+1) %length==0) //换行System.out.println(); } } publicstaticint[] sMM(int[] a, int[] b, intlength) { if(length==2) { returngetResult(a, b); } else { inttlength=length/2; // 把a数组分为四部分,进行分治递归int[] aa=newint[tlength*tlength]; int[] ab=newint[tlength*tlength]; int[] ac=newint[tlength*tlength]; int[] ad=newint[tlength*tlength]; // 把b数组分为四部分,进行分治递归int[] ba=newint[tlength*tlength]; int[] bb=newint[tlength*tlength]; int[] bc=newint[tlength*tlength]; int[] bd=newint[tlength*tlength]; // TODO 划分子矩阵for(inti=0; i<length; i++) { for(intj=0; j<length; j++) { /** 划分矩阵:* 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,* 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵*/if(i<tlength) { if(j<tlength) { aa[i*tlength+j] =a[i*length+j]; ba[i*tlength+j] =b[i*length+j]; } else { ab[i*tlength+ (j-tlength)] =a[i*length+j]; bb[i*tlength+ (j-tlength)] =b[i*length+j]; } } else { if(j<tlength) { //i 大于 tlength 时,需要减去 tlength,j同理//因为 b,c,d三个子矩阵有对应了父矩阵的后半部分ac[(i-tlength) *tlength+j] =a[i*length+j]; bc[(i-tlength) *tlength+j] =b[i*length+j]; } else { ad[(i-tlength) *tlength+ (j-tlength)] =a[i*length+j]; bd[(i-tlength) *tlength+ (j-tlength)] =b[i*length+j]; } } } } // TODO 分治递归int[] result=newint[length*length]; // temp:4个临时矩阵int[] t1=add(sMM(aa, ba, tlength), sMM(ab, bc, tlength)); int[] t2=add(sMM(aa, bb, tlength), sMM(ab, bd, tlength)); int[] t3=add(sMM(ac, ba, tlength), sMM(ad, bc, tlength)); int[] t4=add(sMM(ac, bb, tlength), sMM(ad, bd, tlength)); // TODO 归并结果for(inti=0; i<length; i++) { for(intj=0; j<length; j++) { if (i<tlength){ if(j<tlength) { result[i*length+j] =t1[i*tlength+j]; } else { result[i*length+j] =t2[i*tlength+ (j-tlength)]; } } else { if(j<tlength) { result[i*length+j] =t3[(i-tlength) *tlength+j]; } else { result[i*length+j] =t4[(i-tlength) *tlength+ (j-tlength)]; } } } } returnresult; } } publicstaticint[] getResult(int[] a, int[] b) { intp1=a[0] * (b[1] -b[3]); intp2= (a[0] +a[1]) *b[3]; intp3= (a[2] +a[3]) *b[0]; intp4=a[3] * (b[2] -b[0]); intp5= (a[0] +a[3]) * (b[0] +b[3]); intp6= (a[1] -a[3]) * (b[2] +b[3]); intp7= (a[0] -a[2]) * (b[0] +b[1]); intc00=p5+p4-p2+p6; intc01=p1+p2; intc10=p3+p4; intc11=p5+p1-p3-p7; returnnewint[] {c00, c01, c10, c11}; } publicstaticint[] add(int[] a, int[] b) { int[] c=newint[a.length]; for(inti=0; i<a.length; i++) { c[i] =a[i] +b[i]; } returnc; } // TODO 返回一个数是不是2的幂次方publicstaticbooleanadjust(intx) { return (x& (x-1)) ==0; } }
4、复杂度分析
传统方法和分治法的复杂度比较,如下图所示;
T(n) = 0(nlog7) = 0(n2.81)
5、参考资料
- 算法分析与设计(第四版)
- 算法导论第三版
- [博客园]