Strassen矩阵乘法问题(Java)

简介: Strassen矩阵乘法问题(Java)

Strassen矩阵乘法问题(Java)


a53fa7633514475fa766316fab7a2e3e.jpeg




1、前置介绍


矩阵乘法是线性代数中最常见的问题之一 ,它在数值计算中有广泛的应用。 设 AB 是2个nXn矩阵,

它们的乘积 AB 同样是一个nXn矩阵。 AB 的乘积矩阵 C 中元素C[i][j]定义为:

C[i][j] = \sum_{k=1}^{n}A[i][k]B[k][j]

1.jpeg


1.png


采用传统方法,时间复杂度为:O(n3)


因为按照上述的定义来计算A和 B的乘积矩阵c,则每计算C的一个元素C[i][j], 需要做n次乘法运算和n-1次加法运算 。 因此,得到矩阵C的n2 个元素所需的计算时间为 O(n3) 。


为解决计算计算效率问题,Strassen算法由此出现,该算法基本思想是 分治 ,将计算2个n阶矩阵乘积所需的计算时间改进到0(nlog7) = 0(n2.81)


我们知道,

C11=A11>*B11+A12*B21


2e.jpeg


矩阵A和B的示意图如下:


3a.jpeg


传统方法: 


4.png


2个n阶方阵的乘积转换为8个n/2 阶方阵的乘积和4个n/2阶方阵的加法。


由此可得:


C11 = A11B11 + A12B21


C12 = A11B12 + A12B22


C21 = A21B11 + A22B21


C22 = A21B12 + A22B22



分治法: 

为了降低时间复杂度,必须减少乘法的次数。


使用与上例类似的技术,将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵。由此可将方程C=AB重写为:

5.jpeg


2个n阶方阵的乘积转换为7个n/2 阶方阵的乘积和18个n/2阶方阵的加减法。


伪代码如下:

// 递归维度分半算法:publicvoidSTRASSEN(n,A,B,C);
{  
ifn=2thenMATRIX-MULTIPLY(A,B,C)
//结束循环,计算两个2阶方阵的乘法else{
将矩阵A和B分块;
STRASSEN(n/2,A11,B12-B22,M1);
STRASSEN(n/2,A11+A12,B22,M2); 
STRASSEN(n/2,A21+A22,B11,M3);
STRASSEN(n/2,A22,B21-B11,M4);
STRASSEN(n/2,A11+A22,B11+B22,M5);
STRASSEN(n/2,A12-A22,B21+B22,M6);
STRASSEN(n/2,A11-A21,B11+B12,M7);}
}                
```算法导论伪代码:![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/cef0ab294b824ecca6731c34edbcfa45.jpeg#pic_center)##3、代码实现```javapublicclassStrassenMatrixMultiply{
publicstaticvoidmain(String[] args)
    {
int[] a=newint[]
        {
1, 1, 1, 1,
2, 2, 2, 2,
3, 3, 3, 3,
4, 4, 4, 4        };
int[] b=newint[]
        {
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4        };
intlength=4;
int[] c=sMM(a, b, length);
for(inti=0; i<c.length; i++)
        {
System.out.print(c[i] +" ");
if((i+1) %length==0) //换行System.out.println();
        }
    }
publicstaticint[] sMM(int[] a, int[] b, intlength) {
if(length==2) {
returngetResult(a, b);
        }
else {
inttlength=length/2;
// 把a数组分为四部分,进行分治递归int[] aa=newint[tlength*tlength];
int[] ab=newint[tlength*tlength];
int[] ac=newint[tlength*tlength];
int[] ad=newint[tlength*tlength];
// 把b数组分为四部分,进行分治递归int[] ba=newint[tlength*tlength];
int[] bb=newint[tlength*tlength];
int[] bc=newint[tlength*tlength];
int[] bd=newint[tlength*tlength];
// TODO 划分子矩阵for(inti=0; i<length; i++) {
for(intj=0; j<length; j++) {
/** 划分矩阵:* 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,* 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵*/if(i<tlength) {
if(j<tlength) {
aa[i*tlength+j] =a[i*length+j];
ba[i*tlength+j] =b[i*length+j];
                        } else {
ab[i*tlength+ (j-tlength)] =a[i*length+j];
bb[i*tlength+ (j-tlength)] =b[i*length+j];
                        }
                    } else {
if(j<tlength) {
//i 大于 tlength 时,需要减去 tlength,j同理//因为 b,c,d三个子矩阵有对应了父矩阵的后半部分ac[(i-tlength) *tlength+j] =a[i*length+j];
bc[(i-tlength) *tlength+j] =b[i*length+j];
                        } else {
ad[(i-tlength) *tlength+ (j-tlength)] =a[i*length+j];
bd[(i-tlength) *tlength+ (j-tlength)] =b[i*length+j];
                        }
                    }
                }
            }
// TODO 分治递归int[] result=newint[length*length];
// temp:4个临时矩阵int[] t1=add(sMM(aa, ba, tlength), sMM(ab, bc, tlength));
int[] t2=add(sMM(aa, bb, tlength), sMM(ab, bd, tlength));
int[] t3=add(sMM(ac, ba, tlength), sMM(ad, bc, tlength));
int[] t4=add(sMM(ac, bb, tlength), sMM(ad, bd, tlength));
// TODO 归并结果for(inti=0; i<length; i++) {
for(intj=0; j<length; j++) {
if (i<tlength){
if(j<tlength) {
result[i*length+j] =t1[i*tlength+j];
                        } else {
result[i*length+j] =t2[i*tlength+ (j-tlength)];
                        }
                    } else {
if(j<tlength) {
result[i*length+j] =t3[(i-tlength) *tlength+j];
                        } else {
result[i*length+j] =t4[(i-tlength) *tlength+ (j-tlength)];
                        }
                    }
                }
            }
returnresult;
        }
    }
publicstaticint[] getResult(int[] a, int[] b) {
intp1=a[0] * (b[1] -b[3]);
intp2= (a[0] +a[1]) *b[3];
intp3= (a[2] +a[3]) *b[0];
intp4=a[3] * (b[2] -b[0]);
intp5= (a[0] +a[3]) * (b[0] +b[3]);
intp6= (a[1] -a[3]) * (b[2] +b[3]);
intp7= (a[0] -a[2]) * (b[0] +b[1]);
intc00=p5+p4-p2+p6;
intc01=p1+p2;
intc10=p3+p4;
intc11=p5+p1-p3-p7;
returnnewint[] {c00, c01, c10, c11};
    }
publicstaticint[] add(int[] a, int[] b) {
int[] c=newint[a.length];
for(inti=0; i<a.length; i++) {
c[i] =a[i] +b[i];
     }
returnc;
    }
// TODO 返回一个数是不是2的幂次方publicstaticbooleanadjust(intx) {
return (x& (x-1)) ==0;
    }
}


4、复杂度分析


传统方法和分治法的复杂度比较,如下图所示;

6b.png


Snipaste_2023-01-04_13-44-48.jpeg

T(n) = 0(nlog7) = 0(n2.81)


5、参考资料

  • 算法分析与设计(第四版)
  • 算法导论第三版
  • [博客园]
目录
相关文章
|
5月前
|
存储 Java
创建一个乘法练习题生成器 using Java
创建一个乘法练习题生成器 using Java
【Java每日一题,dfs】矩阵查找字符串
【Java每日一题,dfs】矩阵查找字符串
|
3月前
|
算法 Java
LeetCode经典算法题:矩阵中省份数量经典题目+三角形最大周长java多种解法详解
LeetCode经典算法题:矩阵中省份数量经典题目+三角形最大周长java多种解法详解
51 6
|
5月前
|
Java
基本矩阵运算的Java实现
基本矩阵运算的Java实现
33 0
|
Java C++
【Java】剑指offer(29)顺时针打印矩阵
【Java】剑指offer(29)顺时针打印矩阵
|
6月前
|
Rust 索引
Rust 编程小技巧摘选(5)
Rust 编程小技巧摘选(5)
76 0
Rust 编程小技巧摘选(5)
|
6月前
|
Java Go C++
Java每日一练(20230417) N 皇后、搜索二维矩阵、发奖金问题
Java每日一练(20230417) N 皇后、搜索二维矩阵、发奖金问题
44 0
Java每日一练(20230417) N 皇后、搜索二维矩阵、发奖金问题
|
11月前
|
算法 Java
240. 搜索二维矩阵 II -- 力扣 --JAVA
编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。 每列的元素从上到下升序排列。
55 0
|
机器学习/深度学习 存储 算法
Java每日一练(20230515) 阶乘后的零、矩阵置零、两数相除
Java每日一练(20230515) 阶乘后的零、矩阵置零、两数相除
110 0
|
Java
矩阵重叠(Java实现)
矩阵重叠(Java实现)
148 1
矩阵重叠(Java实现)