Strassen矩阵乘法问题(Java)

简介: Strassen矩阵乘法问题(Java)

Strassen矩阵乘法问题(Java)


a53fa7633514475fa766316fab7a2e3e.jpeg




1、前置介绍


矩阵乘法是线性代数中最常见的问题之一 ,它在数值计算中有广泛的应用。 设 AB 是2个nXn矩阵,

它们的乘积 AB 同样是一个nXn矩阵。 AB 的乘积矩阵 C 中元素C[i][j]定义为:

C[i][j] = \sum_{k=1}^{n}A[i][k]B[k][j]

1.jpeg


1.png


采用传统方法,时间复杂度为:O(n3)


因为按照上述的定义来计算A和 B的乘积矩阵c,则每计算C的一个元素C[i][j], 需要做n次乘法运算和n-1次加法运算 。 因此,得到矩阵C的n2 个元素所需的计算时间为 O(n3) 。


为解决计算计算效率问题,Strassen算法由此出现,该算法基本思想是 分治 ,将计算2个n阶矩阵乘积所需的计算时间改进到0(nlog7) = 0(n2.81)


我们知道,

C11=A11>*B11+A12*B21


2e.jpeg


矩阵A和B的示意图如下:


3a.jpeg


传统方法: 


4.png


2个n阶方阵的乘积转换为8个n/2 阶方阵的乘积和4个n/2阶方阵的加法。


由此可得:


C11 = A11B11 + A12B21


C12 = A11B12 + A12B22


C21 = A21B11 + A22B21


C22 = A21B12 + A22B22



分治法: 

为了降低时间复杂度,必须减少乘法的次数。


使用与上例类似的技术,将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵。由此可将方程C=AB重写为:

5.jpeg


2个n阶方阵的乘积转换为7个n/2 阶方阵的乘积和18个n/2阶方阵的加减法。


伪代码如下:

// 递归维度分半算法:publicvoidSTRASSEN(n,A,B,C);
{  
ifn=2thenMATRIX-MULTIPLY(A,B,C)
//结束循环,计算两个2阶方阵的乘法else{
将矩阵A和B分块;
STRASSEN(n/2,A11,B12-B22,M1);
STRASSEN(n/2,A11+A12,B22,M2); 
STRASSEN(n/2,A21+A22,B11,M3);
STRASSEN(n/2,A22,B21-B11,M4);
STRASSEN(n/2,A11+A22,B11+B22,M5);
STRASSEN(n/2,A12-A22,B21+B22,M6);
STRASSEN(n/2,A11-A21,B11+B12,M7);}
}                
```算法导论伪代码:![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/cef0ab294b824ecca6731c34edbcfa45.jpeg#pic_center)##3、代码实现```javapublicclassStrassenMatrixMultiply{
publicstaticvoidmain(String[] args)
    {
int[] a=newint[]
        {
1, 1, 1, 1,
2, 2, 2, 2,
3, 3, 3, 3,
4, 4, 4, 4        };
int[] b=newint[]
        {
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4        };
intlength=4;
int[] c=sMM(a, b, length);
for(inti=0; i<c.length; i++)
        {
System.out.print(c[i] +" ");
if((i+1) %length==0) //换行System.out.println();
        }
    }
publicstaticint[] sMM(int[] a, int[] b, intlength) {
if(length==2) {
returngetResult(a, b);
        }
else {
inttlength=length/2;
// 把a数组分为四部分,进行分治递归int[] aa=newint[tlength*tlength];
int[] ab=newint[tlength*tlength];
int[] ac=newint[tlength*tlength];
int[] ad=newint[tlength*tlength];
// 把b数组分为四部分,进行分治递归int[] ba=newint[tlength*tlength];
int[] bb=newint[tlength*tlength];
int[] bc=newint[tlength*tlength];
int[] bd=newint[tlength*tlength];
// TODO 划分子矩阵for(inti=0; i<length; i++) {
for(intj=0; j<length; j++) {
/** 划分矩阵:* 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,* 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵*/if(i<tlength) {
if(j<tlength) {
aa[i*tlength+j] =a[i*length+j];
ba[i*tlength+j] =b[i*length+j];
                        } else {
ab[i*tlength+ (j-tlength)] =a[i*length+j];
bb[i*tlength+ (j-tlength)] =b[i*length+j];
                        }
                    } else {
if(j<tlength) {
//i 大于 tlength 时,需要减去 tlength,j同理//因为 b,c,d三个子矩阵有对应了父矩阵的后半部分ac[(i-tlength) *tlength+j] =a[i*length+j];
bc[(i-tlength) *tlength+j] =b[i*length+j];
                        } else {
ad[(i-tlength) *tlength+ (j-tlength)] =a[i*length+j];
bd[(i-tlength) *tlength+ (j-tlength)] =b[i*length+j];
                        }
                    }
                }
            }
// TODO 分治递归int[] result=newint[length*length];
// temp:4个临时矩阵int[] t1=add(sMM(aa, ba, tlength), sMM(ab, bc, tlength));
int[] t2=add(sMM(aa, bb, tlength), sMM(ab, bd, tlength));
int[] t3=add(sMM(ac, ba, tlength), sMM(ad, bc, tlength));
int[] t4=add(sMM(ac, bb, tlength), sMM(ad, bd, tlength));
// TODO 归并结果for(inti=0; i<length; i++) {
for(intj=0; j<length; j++) {
if (i<tlength){
if(j<tlength) {
result[i*length+j] =t1[i*tlength+j];
                        } else {
result[i*length+j] =t2[i*tlength+ (j-tlength)];
                        }
                    } else {
if(j<tlength) {
result[i*length+j] =t3[(i-tlength) *tlength+j];
                        } else {
result[i*length+j] =t4[(i-tlength) *tlength+ (j-tlength)];
                        }
                    }
                }
            }
returnresult;
        }
    }
publicstaticint[] getResult(int[] a, int[] b) {
intp1=a[0] * (b[1] -b[3]);
intp2= (a[0] +a[1]) *b[3];
intp3= (a[2] +a[3]) *b[0];
intp4=a[3] * (b[2] -b[0]);
intp5= (a[0] +a[3]) * (b[0] +b[3]);
intp6= (a[1] -a[3]) * (b[2] +b[3]);
intp7= (a[0] -a[2]) * (b[0] +b[1]);
intc00=p5+p4-p2+p6;
intc01=p1+p2;
intc10=p3+p4;
intc11=p5+p1-p3-p7;
returnnewint[] {c00, c01, c10, c11};
    }
publicstaticint[] add(int[] a, int[] b) {
int[] c=newint[a.length];
for(inti=0; i<a.length; i++) {
c[i] =a[i] +b[i];
     }
returnc;
    }
// TODO 返回一个数是不是2的幂次方publicstaticbooleanadjust(intx) {
return (x& (x-1)) ==0;
    }
}


4、复杂度分析


传统方法和分治法的复杂度比较,如下图所示;

6b.png


Snipaste_2023-01-04_13-44-48.jpeg

T(n) = 0(nlog7) = 0(n2.81)


5、参考资料

  • 算法分析与设计(第四版)
  • 算法导论第三版
  • [博客园]
目录
相关文章
|
5月前
|
Java 程序员
【Java编程实现 9 * 9 乘法表格打印四种形态,七种打法】
【Java编程实现 9 * 9 乘法表格打印四种形态,七种打法】
29 0
|
Java
Java实现最小二乘法线性拟合,传感与检测,单臂半桥全桥实验,江南大学自动化
Java实现最小二乘法线性拟合,传感与检测,单臂半桥全桥实验,江南大学自动化
163 0
Java实现最小二乘法线性拟合,传感与检测,单臂半桥全桥实验,江南大学自动化
|
算法 Java
java float乘法不正确的解决办法
java float乘法不正确的解决办法
|
算法 Java
51 Nod 1028 大数乘法 V2【Java大数乱搞】
1028 大数乘法 V2 基准时间限制:2 秒 空间限制:131072 KB 分值: 80 难度:5级算法题 给出2个大整数A,B,计算A*B的结果。 Input 第1行:大数A 第2行:大数B (A,B的长度 = 0) Output 输出A * B Input示...
1117 0
|
Java
51 Nod 1027 大数乘法【Java大数乱搞】
1027 大数乘法 基准时间限制:1 秒 空间限制:131072 KB 分值: 0 难度:基础题 给出2个大整数A,B,计算A*B的结果。 Input 第1行:大数A 第2行:大数B (A,B的长度 = 0) Output 输出A * B Input示例 123...
1113 0
|
Java
java报告(一)编程打印一个三角形的乘法口诀表
 编程打印一个三角形的乘法口诀表(注意对齐),并练习对程序进行单步运行、断点调试等。   实验要求: 1. 在实验报告中给出程序运行结果截图。 2. 源程序代码附到实验报告的最后。 3. 认真填写实验报告并妥善存档,在下次上机实验课之前发送电子版实验报告至 wsycup@foxmail.com。
1034 0
|
人工智能 Java
Java工作利器之常用工具类(二)——数字工具类-大数乘法、加法、减法运算
上篇分享了一下数字转汉字的小功能,这里再分享一下大数相乘、相加、相减的功能。其他的不做过多的铺垫了,我先讲一下各个功能的计算原理。 Ⅰ. 乘法运算 为什么先说乘法运算——因为我先做了乘法运算。
1217 0
|
19小时前
|
Java
Java一分钟:线程协作:wait(), notify(), notifyAll()
【5月更文挑战第11天】本文介绍了Java多线程编程中的`wait()`, `notify()`, `notifyAll()`方法,它们用于线程间通信和同步。这些方法在`synchronized`代码块中使用,控制线程执行和资源访问。文章讨论了常见问题,如死锁、未捕获异常、同步使用错误及通知错误,并提供了生产者-消费者模型的示例代码,强调理解并正确使用这些方法对实现线程协作的重要性。
9 3
|
19小时前
|
安全 算法 Java
Java一分钟:线程同步:synchronized关键字
【5月更文挑战第11天】Java中的`synchronized`关键字用于线程同步,防止竞态条件,确保数据一致性。本文介绍了其工作原理、常见问题及避免策略。同步方法和同步代码块是两种使用形式,需注意避免死锁、过度使用导致的性能影响以及理解锁的可重入性和升级降级机制。示例展示了同步方法和代码块的运用,以及如何避免死锁。正确使用`synchronized`是编写多线程安全代码的核心。
10 2