Strassen’s algorithm for matrix multiplication

本文涉及的产品
文档翻译,文档翻译 1千页
语种识别,语种识别 100万字符
图片翻译,图片翻译 100张
简介:

Strassen算法能够在 $\Theta(n^{\lg 7})$(约 $O(n^{2.81})$)时间内完成两个 $n \times n$ 矩阵的乘法,优于朴素算法的 $\Theta(n^3)$。


package chapter4;

import Utils.P;

/**
 * A rectangular matrix of {@code int} values that can also act as a window
 * (view) onto a region of another matrix's backing array.
 *
 * <p>Views created by {@link #subMatrix} share storage with the matrix they
 * were cut from, so writing through a view changes that matrix as well. All
 * coordinates accepted by the public methods are relative to the top-left
 * corner of the matrix or view they are called on.
 */
class Matrix {
    /** Backing storage; shared (never copied) between a matrix and its views. */
    private final int[][] data;
    /** Number of rows visible through this matrix or view. */
    int lengthX;
    /** Number of columns visible through this matrix or view. */
    int lengthY;
    /** Row index inside {@link #data} of this view's first row. */
    private final int xs;
    /** Column index inside {@link #data} of this view's first column. */
    private final int ys;

    /**
     * Wraps {@code data} as a matrix covering the whole array. The array is
     * used directly, not copied.
     *
     * @param data row-major values; the column count is taken from the first
     *             row, and an array with no rows yields a 0x0 matrix
     */
    public Matrix(int[][] data) {
        this.data = data;
        this.lengthX = data.length;
        // Guard the first-row lookup so a zero-row array does not blow up.
        this.lengthY = data.length == 0 ? 0 : data[0].length;
        this.xs = 0;
        this.ys = 0;
    }

    /**
     * Creates a view over {@code data}. Unlike {@link #subMatrix}, the
     * coordinates here are absolute indices into {@code data}.
     *
     * @param data backing array, used directly
     * @param xs   first row (inclusive)
     * @param xe   last row (exclusive)
     * @param ys   first column (inclusive)
     * @param ye   last column (exclusive)
     */
    Matrix(int[][] data, int xs, int xe, int ys, int ye) {
        this.data = data;
        this.xs = xs;
        this.ys = ys;
        this.lengthX = xe - xs;
        this.lengthY = ye - ys;
    }

    /** Allocates a zero-filled, standalone matrix of exactly the given shape. */
    private static Matrix zeros(int rows, int cols) {
        // The view constructor keeps the column count even when rows == 0.
        return new Matrix(new int[rows][cols], 0, rows, 0, cols);
    }

    /**
     * Returns a view of the rows {@code [xs, xe)} and columns {@code [ys, ye)}
     * of this matrix. The view shares storage with this matrix.
     *
     * @param xs first row (inclusive), relative to this matrix
     * @param xe last row (exclusive), relative to this matrix
     * @param ys first column (inclusive), relative to this matrix
     * @param ye last column (exclusive), relative to this matrix
     * @return the requested view
     * @throws IllegalArgumentException if the range does not lie inside this matrix
     */
    public Matrix subMatrix(int xs, int xe, int ys, int ye) {
        if (xs < 0 || xe < xs || xe > lengthX || ys < 0 || ye < ys || ye > lengthY) {
            throw new IllegalArgumentException("Sub-matrix rows [" + xs + ", " + xe
                    + ") columns [" + ys + ", " + ye + ") outside a "
                    + lengthX + "x" + lengthY + " matrix");
        }
        // Translate to absolute indices: this matrix may itself be a view, and
        // without the offsets a view of a view would address the wrong cells.
        return new Matrix(data, this.xs + xs, this.xs + xe, this.ys + ys, this.ys + ye);
    }

    /**
     * Reads the element at row {@code x}, column {@code y} of this matrix or view.
     */
    public int get(int x, int y) {
        return data[xs + x][ys + y];
    }

    /**
     * Copies every element of {@code mt} into this matrix, aligning the
     * top-left corners.
     *
     * @param mt source matrix; must not be larger than this matrix
     * @throws IllegalArgumentException if {@code mt} has more rows or columns
     *                                  than this matrix
     */
    public void set(Matrix mt) {
        if (mt.lengthX > lengthX || mt.lengthY > lengthY) {
            // Without this check a too-large source would spill into cells
            // of the backing array that lie outside this view.
            throw new IllegalArgumentException("Cannot copy a " + mt.lengthX + "x"
                    + mt.lengthY + " matrix into a " + lengthX + "x" + lengthY + " matrix");
        }
        for (int i = 0; i < mt.lengthX; i++) {
            for (int j = 0; j < mt.lengthY; j++) {
                set(mt.get(i, j), i, j);
            }
        }
    }

    /**
     * Writes {@code d} at row {@code x}, column {@code y} of this matrix or view.
     */
    public void set(int d, int x, int y) {
        this.data[xs + x][ys + y] = d;
    }

    /**
     * Element-wise difference; identical to {@link #sub(Matrix)}.
     *
     * @param m matrix to subtract; must have the same shape as this matrix
     * @return a new standalone matrix holding {@code this - m}
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix minus(Matrix m) {
        return sub(m);
    }

    /**
     * Element-wise difference.
     *
     * @param m matrix to subtract; must have the same shape as this matrix
     * @return a new standalone matrix holding {@code this - m}
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix sub(Matrix m) {
        requireSameShape(m);
        Matrix ret = zeros(lengthX, lengthY);
        for (int i = 0; i < lengthX; i++) {
            for (int j = 0; j < lengthY; j++) {
                ret.set(get(i, j) - m.get(i, j), i, j);
            }
        }
        return ret;
    }

    /**
     * Element-wise sum.
     *
     * @param m matrix to add; must have the same shape as this matrix
     * @return a new standalone matrix holding {@code this + m}
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix add(Matrix m) {
        requireSameShape(m);
        Matrix ret = zeros(lengthX, lengthY);
        for (int i = 0; i < lengthX; i++) {
            for (int j = 0; j < lengthY; j++) {
                ret.set(get(i, j) + m.get(i, j), i, j);
            }
        }
        return ret;
    }

    /** Rejects an operand whose row or column count differs from this matrix. */
    private void requireSameShape(Matrix m) {
        if (m.lengthX != lengthX || m.lengthY != lengthY) {
            throw new IllegalArgumentException("Matrix shape mismatch: " + lengthX + "x"
                    + lengthY + " vs " + m.lengthX + "x" + m.lengthY);
        }
    }

    /**
     * Prints the visible elements row by row through the project's {@code P}
     * helper, each element followed by a single space.
     */
    public void print() {
        for (int i = 0; i < lengthX; i++) {
            for (int j = 0; j < lengthY; j++) {
                P.rint(get(i, j));
                P.rint(" ");
            }
            P.rintln();
        }
    }
}

/**
 * Integer matrix multiplication based on Strassen's divide-and-conquer
 * scheme: a product of two n x n matrices is built from seven recursive
 * products of n/2 x n/2 matrices instead of eight.
 */
public class Strassen {

    /**
     * Multiplies {@code a} by {@code b}.
     *
     * <p>The Strassen split is applied while both operands are square, of the
     * same even size; at any level where that no longer holds (odd size,
     * 1x1, or non-square operands) the product is computed directly.
     *
     * <p>NOTE: the recursion takes quadrants of quadrants, so it relies on
     * {@code Matrix.subMatrix} interpreting its coordinates relative to the
     * matrix it is called on.
     *
     * @param a left operand
     * @param b right operand; its row count must equal {@code a}'s column count
     * @return a new matrix holding the product {@code a * b}
     * @throws IllegalArgumentException if the inner dimensions do not match
     */
    public static Matrix strassen(Matrix a, Matrix b) {
        if (a.lengthY != b.lengthX) {
            throw new IllegalArgumentException("Cannot multiply a " + a.lengthX + "x"
                    + a.lengthY + " matrix by a " + b.lengthX + "x" + b.lengthY + " matrix");
        }

        int n = a.lengthX;
        boolean squareOperands = a.lengthY == n && b.lengthY == n;
        // Quadrants only line up when every dimension is the same even number;
        // n < 2 also covers the 1x1 base case and stops the recursion.
        if (!squareOperands || n < 2 || n % 2 != 0) {
            return multiplyDirectly(a, b);
        }

        int half = n / 2;
        Matrix c = new Matrix(new int[n][n]);

        // Quadrant views; each operand is split by its own bounds.
        Matrix a11 = a.subMatrix(0, half, 0, half);
        Matrix a12 = a.subMatrix(0, half, half, n);
        Matrix a21 = a.subMatrix(half, n, 0, half);
        Matrix a22 = a.subMatrix(half, n, half, n);
        Matrix b11 = b.subMatrix(0, half, 0, half);
        Matrix b12 = b.subMatrix(0, half, half, n);
        Matrix b21 = b.subMatrix(half, n, 0, half);
        Matrix b22 = b.subMatrix(half, n, half, n);
        // Views into the result, so set(...) below writes straight into c.
        Matrix c11 = c.subMatrix(0, half, 0, half);
        Matrix c12 = c.subMatrix(0, half, half, n);
        Matrix c21 = c.subMatrix(half, n, 0, half);
        Matrix c22 = c.subMatrix(half, n, half, n);

        // The ten quadrant sums/differences.
        Matrix s1 = b12.sub(b22);
        Matrix s2 = a11.add(a12);
        Matrix s3 = a21.add(a22);
        Matrix s4 = b21.sub(b11);
        Matrix s5 = a11.add(a22);
        Matrix s6 = b11.add(b22);
        Matrix s7 = a12.sub(a22);
        Matrix s8 = b21.add(b22);
        Matrix s9 = a11.sub(a21);
        Matrix s10 = b11.add(b12);

        // The seven recursive half-size products.
        Matrix p1 = strassen(a11, s1);
        Matrix p2 = strassen(s2, b22);
        Matrix p3 = strassen(s3, b11);
        Matrix p4 = strassen(a22, s4);
        Matrix p5 = strassen(s5, s6);
        Matrix p6 = strassen(s7, s8);
        Matrix p7 = strassen(s9, s10);

        // Recombine the products into the four quadrants of the result.
        c11.set(p5.add(p4).sub(p2).add(p6));
        c12.set(p1.add(p2));
        c21.set(p3.add(p4));
        c22.set(p5.add(p1).sub(p3).sub(p7));
        return c;
    }

    /** Row-by-column product for operands that cannot be halved evenly. */
    private static Matrix multiplyDirectly(Matrix a, Matrix b) {
        Matrix product = new Matrix(new int[a.lengthX][b.lengthY]);
        for (int i = 0; i < a.lengthX; i++) {
            for (int j = 0; j < b.lengthY; j++) {
                int sum = 0;
                for (int k = 0; k < a.lengthY; k++) {
                    sum += a.get(i, k) * b.get(k, j);
                }
                product.set(sum, i, j);
            }
        }
        return product;
    }

    /**
     * Demo: multiplies two fixed 2x2 matrices and prints the product
     * (rows {@code 18 14} and {@code 62 66}).
     */
    public static void main(String[] args) {
        Matrix mt = new Matrix(new int[][] { { 1, 3 }, { 7, 5 } });
        Matrix mt2 = new Matrix(new int[][] { { 6, 8 }, { 4, 2 } });
        Strassen.strassen(mt, mt2).print();
    }
}


目录
相关文章
|
机器学习/深度学习
上三角矩阵(Upper Triangular Matrix)
上三角矩阵(Upper Triangular Matrix)是一种特殊形式的矩阵,其非零元素仅位于主对角线及其上方(主对角线以下的元素全部为零)。在数学和工程领域中,上三角矩阵通常用于线性代数和微积分等问题。以下是一些关于上三角矩阵的特点和应用:
1435 0
uva442 Matrix Chain Multiplication
uva442 Matrix Chain Multiplication
43 0
|
存储 索引
LeetCode 73. Set Matrix Zeroes
给定一个m * n 的矩阵,如果当前元是0,则把此元素所在的行,列全部置为0. 额外要求:是否可以做到空间复杂度O(1)?
103 0
LeetCode 73. Set Matrix Zeroes
UVA442 矩阵链乘 Matrix Chain Multiplication
UVA442 矩阵链乘 Matrix Chain Multiplication
ValueError: Sample larger than population or is negative
ValueError: Sample larger than population or is negative
203 0
[LeetCode] Sparse Matrix Multiplication
Problem Description: Given two sparse matrices A and B, return the result of AB. You may assume that A's column number is equal to B's row number.
1031 0