java实现任意矩阵Strassen算法
本例输入为两个任意尺寸的矩阵m*n,n*m,输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了Strassen算法。程序为自编,经过测试,请放心使用。基本算法是:
1.对于方阵(正方形矩阵),设其边长为m,找到最大的l,使得l=2^k(k为整数)并且l≤m。左上角边长为l的方形子矩阵采用Strassen算法计算,其余部分以及该子矩阵乘积中遗漏的项用蛮力法计算。
2.对于非方阵,依照行列相应添加0使其成为方阵。
StrassenMethodTest.java
packagematrixalgorithm;
importjava.util.Scanner;
publicclassStrassenMethodTest{
privateStrassenMethodstrassenMultiply;
StrassenMethodTest(){
strassenMultiply=newStrassenMethod();
}//endcons
publicstaticvoidmain(String[]args){
Scannerinput=newScanner(System.in);
System.out.println("Inputrowsizeofthefirstmatrix:");
intarow=input.nextInt();
System.out.println("Inputcolumnsizeofthefirstmatrix:");
intacol=input.nextInt();
System.out.println("Inputrowsizeofthesecondmatrix:");
intbrow=input.nextInt();
System.out.println("Inputcolumnsizeofthesecondmatrix:");
intbcol=input.nextInt();
double[][]A=newdouble[arow][acol];
double[][]B=newdouble[brow][bcol];
double[][]C=newdouble[arow][bcol];
System.out.println("InputdataformatrixA:");
/*Inallofthecodeslaterinthisproject,
rmeansrowwhilecmeanscolumn.
*/
for(intr=0;r<arow;r++){
for(intc=0;c<acol;c++){
System.out.printf("DataofA[%d][%d]:",r,c);
A[r][c]=input.nextDouble();
}//endinnerloop
}//endloop
System.out.println("InputdataformatrixB:");
for(intr=0;r<brow;r++){
for(intc=0;c<bcol;c++){
System.out.printf("DataofA[%d][%d]:",r,c);
B[r][c]=input.nextDouble();
}//endinnerloop
}//endloop
StrassenMethodTestalgorithm=newStrassenMethodTest();
C=algorithm.multiplyRectMatrix(A,B,arow,acol,brow,bcol);
//Displaythecalculationresult:
System.out.println("ResultfrommatrixC:");
for(intr=0;r<arow;r++){
for(intc=0;c<bcol;c++){
System.out.printf("DataofC[%d][%d]:%f\n",r,c,C[r][c]);
}//endinnerloop
}//endoutterloop
}//endmain
//Dealwithmatricesthatarenotsquare:
publicdouble[][]multiplyRectMatrix(double[][]A,double[][]B,
intarow,intacol,intbrow,intbcol){
if(arow!=bcol)//Invalidmultiplicatio
returnnewdouble[][]{{0}};
double[][]C=newdouble[arow][bcol];
if(arow<acol){
double[][]newA=newdouble[acol][acol];
double[][]newB=newdouble[brow][brow];
intn=acol;
for(intr=0;r<acol;r++)
for(intc=0;c<acol;c++)
newA[r][c]=0.0;
for(intr=0;r<brow;r++)
for(intc=0;c<brow;c++)
newB[r][c]=0.0;
for(intr=0;r<arow;r++)
for(intc=0;c<acol;c++)
newA[r][c]=A[r][c];
for(intr=0;r<brow;r++)
for(intc=0;c<bcol;c++)
newB[r][c]=B[r][c];
double[][]C2=multiplySquareMatrix(newA,newB,n);
for(intr=0;r<arow;r++)
for(intc=0;c<bcol;c++)
C[r][c]=C2[r][c];
}//endif
elseif(arow==acol)
C=multiplySquareMatrix(A,B,arow);
else{
intn=arow;
double[][]newA=newdouble[arow][arow];
double[][]newB=newdouble[bcol][bcol];
for(intr=0;r<arow;r++)
for(intc=0;c<arow;c++)
newA[r][c]=0.0;
for(intr=0;r<bcol;r++)
for(intc=0;c<bcol;c++)
newB[r][c]=0.0;
for(intr=0;r<arow;r++)
for(intc=0;c<acol;c++)
newA[r][c]=A[r][c];
for(intr=0;r<brow;r++)
for(intc=0;c<bcol;c++)
newB[r][c]=B[r][c];
double[][]C2=multiplySquareMatrix(newA,newB,n);
for(intr=0;r<arow;r++)
for(intc=0;c<bcol;c++)
C[r][c]=C2[r][c];
}//endelse
returnC;
}//endmethod
//Dealwithmatricesthataresquarematrices.
publicdouble[][]multiplySquareMatrix(double[][]A2,double[][]B2,intn){
double[][]C2=newdouble[n][n];
for(intr=0;r<n;r++)
for(intc=0;c<n;c++)
C2[r][c]=0;
if(n==1){
C2[0][0]=A2[0][0]*B2[0][0];
returnC2;
}//endif
intexp2k=2;
while(exp2k<=(n/2)){
exp2k*=2;
}//endloop
if(exp2k==n){
C2=strassenMultiply.strassenMultiplyMatrix(A2,B2,n);
returnC2;
}//endelse
//The"biggest"strassenmatrix:
double[][][]A=newdouble[6][exp2k][exp2k];
double[][][]B=newdouble[6][exp2k][exp2k];
double[][][]C=newdouble[6][exp2k][exp2k];
for(intr=0;r<exp2k;r++){
for(intc=0;c<exp2k;c++){
A[0][r][c]=A2[r][c];
B[0][r][c]=B2[r][c];
}//endinnerloop
}//endoutterloop
C[0]=strassenMultiply.strassenMultiplyMatrix(A[0],B[0],exp2k);
for(intr=0;r<exp2k;r++)
for(intc=0;c<exp2k;c++)
C2[r][c]=C[0][r][c];
intmiddle=exp2k/2;
for(intr=0;r<middle;r++){
for(intc=exp2k;c<n;c++){
A[1][r][c-exp2k]=A2[r][c];
B[3][r][c-exp2k]=B2[r][c];
}//endinnerloop
}//endoutterloop
for(intr=exp2k;r<n;r++){
for(intc=0;c<middle;c++){
A[3][r-exp2k][c]=A2[r][c];
B[1][r-exp2k][c]=B2[r][c];
}//endinnerloop
}//endoutterloop
for(intr=middle;r<exp2k;r++){
for(intc=exp2k;c<n;c++){
A[2][r-middle][c-exp2k]=A2[r][c];
B[4][r-middle][c-exp2k]=B2[r][c];
}//endinnerloop
}//endoutterloop
for(intr=exp2k;r<n;r++){
for(intc=middle;c<n-exp2k+1;c++){
A[4][r-exp2k][c-middle]=A2[r][c];
B[2][r-exp2k][c-middle]=B2[r][c];
}//endinnerloop
}//endoutterloop
for(inti=1;i<=4;i++)
C[i]=multiplyRectMatrix(A[i],B[i],middle,A[i].length,A[i].length,middle);
/*
Calculatethefinalresultsofgridsinthe"biggest2^ksquare,
accordingtotherulesofmatricemultiplication.
*/
for(introw=0;row<exp2k;row++){
for(intcol=0;col<exp2k;col++){
for(intk=exp2k;k<n;k++){
C2[row][col]+=A2[row][k]*B2[k][col];
}//endloop
}//endinnerloop
}//endoutterloop
//Usebruteforcetosolvetherest,willbeimprovedlater:
for(intcol=exp2k;col<n;col++){
for(introw=0;row<n;row++){
for(intk=0;k<n;k++)
C2[row][col]+=A2[row][k]*B2[k][row];
}//endinnerloop
}//endoutterloop
for(introw=exp2k;row<n;row++){
for(intcol=0;col<exp2k;col++){
for(intk=0;k<n;k++)
C2[row][col]+=A2[row][k]*B2[k][row];
}//endinnerloop
}//endoutterloop
returnC2;
}//endmethod
}//endclass
StrassenMethod.java
packagematrixalgorithm;
importjava.util.Scanner;
/**
 * Strassen's divide-and-conquer multiplication of square matrices.
 */
public class StrassenMethod {

    /**
     * Multiplies two {@code n x n} matrices with Strassen's algorithm.
     *
     * <p>Each level splits both operands into 2x2 quadrants and builds the product from seven
     * recursive block multiplications instead of eight. When {@code n} is not a power of two
     * the operands are first zero-padded to the next power of two, because halving an odd
     * size would silently drop the last row and column.
     *
     * @param A2 left operand; only its leading {@code n x n} entries are read
     * @param B2 right operand; only its leading {@code n x n} entries are read
     * @param n size of the operands
     * @return a new {@code n x n} matrix holding {@code A2 * B2}; empty when {@code n <= 0}
     */
    public double[][] strassenMultiplyMatrix(double[][] A2, double[][] B2, int n) {
        if (n <= 0) {
            return new double[0][0]; // without this guard n == 0 would recurse forever
        }
        if (n == 1) {
            return new double[][] {{A2[0][0] * B2[0][0]}};
        }
        if ((n & (n - 1)) != 0) {
            // Not a power of two: pad with zeros, multiply, then keep the leading n x n block.
            int padded = Integer.highestOneBit(n) << 1;
            double[][] product =
                    strassenMultiplyMatrix(resize(A2, n, padded), resize(B2, n, padded), padded);
            return resize(product, n, n);
        }

        // Slice matrices into 2*2 parts:
        int half = n / 2;
        double[][] a11 = quadrant(A2, 0, 0, half);
        double[][] a12 = quadrant(A2, 0, half, half);
        double[][] a21 = quadrant(A2, half, 0, half);
        double[][] a22 = quadrant(A2, half, half, half);
        double[][] b11 = quadrant(B2, 0, 0, half);
        double[][] b12 = quadrant(B2, 0, half, half);
        double[][] b21 = quadrant(B2, half, 0, half);
        double[][] b22 = quadrant(B2, half, half, half);

        // The seven block products.
        double[][] p1 = strassenMultiplyMatrix(a11, minusMatrix(b12, b22, half), half);
        double[][] p2 = strassenMultiplyMatrix(addMatrix(a11, a12, half), b22, half);
        double[][] p3 = strassenMultiplyMatrix(addMatrix(a21, a22, half), b11, half);
        double[][] p4 = strassenMultiplyMatrix(a22, minusMatrix(b21, b11, half), half);
        double[][] p5 = strassenMultiplyMatrix(
                addMatrix(a11, a22, half), addMatrix(b11, b22, half), half);
        double[][] p6 = strassenMultiplyMatrix(
                minusMatrix(a12, a22, half), addMatrix(b21, b22, half), half);
        double[][] p7 = strassenMultiplyMatrix(
                minusMatrix(a11, a21, half), addMatrix(b11, b12, half), half);

        // C11 = P5 + P4 - P2 + P6, C12 = P1 + P2, C21 = P3 + P4, C22 = P5 + P1 - P3 - P7.
        double[][] C2 = new double[n][n];
        place(C2, addMatrix(minusMatrix(addMatrix(p5, p4, half), p2, half), p6, half), 0, 0);
        place(C2, addMatrix(p1, p2, half), 0, half);
        place(C2, addMatrix(p3, p4, half), half, 0);
        place(C2, minusMatrix(minusMatrix(addMatrix(p5, p1, half), p3, half), p7, half),
                half, half);
        return C2;
    }

    /** Copies the {@code size x size} block of {@code src} starting at the given offsets. */
    private static double[][] quadrant(double[][] src, int rowOffset, int colOffset, int size) {
        double[][] out = new double[size][size];
        for (int r = 0; r < size; r++) {
            System.arraycopy(src[rowOffset + r], colOffset, out[r], 0, size);
        }
        return out;
    }

    /** Writes the square {@code block} into {@code dst} at the given offsets. */
    private static void place(double[][] dst, double[][] block, int rowOffset, int colOffset) {
        int size = block.length;
        for (int r = 0; r < size; r++) {
            System.arraycopy(block[r], 0, dst[rowOffset + r], colOffset, size);
        }
    }

    /**
     * Returns a new zero-filled {@code size x size} matrix whose leading {@code copy x copy}
     * block is copied from {@code src}; used both to pad and to trim.
     */
    private static double[][] resize(double[][] src, int copy, int size) {
        double[][] out = new double[size][size];
        for (int r = 0; r < copy; r++) {
            System.arraycopy(src[r], 0, out[r], 0, copy);
        }
        return out;
    }

    /** Adds two {@code n x n} matrices element-wise into a new matrix. */
    private double[][] addMatrix(double[][] A, double[][] B, int n) {
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                C[r][c] = A[r][c] + B[r][c];
            }
        }
        return C;
    }

    /** Subtracts {@code B} from {@code A} element-wise into a new {@code n x n} matrix. */
    private double[][] minusMatrix(double[][] A, double[][] B, int n) {
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                C[r][c] = A[r][c] - B[r][c];
            }
        }
        return C;
    }
}
希望本文所述对大家学习java程序设计有所帮助。