`
shuofenglxy
  • 浏览: 189614 次
  • 性别: Icon_minigender_1
  • 来自: 北京
社区版块
存档分类
最新评论

矩阵链乘法算法讲解

阅读更多

矩阵链乘是一个计算性问题,是动态规划的适用范例。

动态规划要满足以下三个条件:

1 最优化原理(最优子结构性质)

最优化原理可这样阐述:一个最优化策略具有这样的性质,不论过去状态和决策如何,对前面的决策所形成的状态而言,余下的诸决策必须构成最优策略。简而言之,一个最优化策略的子策略总是最优的。一个问题满足最优化原理又称其具有最优子结构性质

 

2.无后向性

将各阶段按照一定的次序排列好之后,对于某个给定的阶段状态,它以前各阶段的状态无法直接影响它未来的决策,而只能通过当前的这个状态。换句话说,每个状态都是过去历史的一个完整总结。这就是无后向性 ,又称为无后效性

如果用前面的记号来描述无后向性,就是:对于确定的xk ,无论p1,k-1 如何,最优子策略pkn* 是唯一确定的,这种性质称为无后向性。

 

3.子问题的重叠性

动态规划将原来具有指数级复杂度的搜索算法改进成了具有多项式时间的算法。其中的关键在于解决冗余 ,这是动态规划算法的根本目的。动态规划实质上是一种以空间换时间的技术,它在实现的过程中,不得不存储产生过程中的各种状态,所以它的空间复杂度要大于其它的算法。

 

具体矩阵链乘分析:

 矩阵乘法有这样的特点: A(m,n),B(n,q)两个矩阵相乘产生一个C(m,q),总的计算次数为m*n*q。而且,矩阵链中,任何两个连续矩阵提前相乘对产生的最终结果没有影响,只是对产生的运算次数有影响。所以,基于这点,通过估算运算次数,可以选出运算次数最少的方案进行矩阵乘法运算,节约运算次数。

 

具体思路参见代码及注释:

矩阵类:

package matrixchain;

import java.util.Random;

public class Matrix {

	private int m;//矩阵行
	
	private int n;//矩阵列
	
	private double [][] matrix ;//存放矩阵元素
	private  static int count=0;//统计一次运算中调用的乘法运算次数
	
	public static int getCount() {
		return count;
	}

	public static void setCount(int count) {
		Matrix.count = count;
	}

	public  Matrix(int m,int n){
		this.setM(m);
		this.setN(n);
		setMatrix(new double[m][n]);
		
	}
	
	public  static  Matrix mutiplyMatrix(Matrix A,Matrix B){
		if(A.getN()!=B.getM()){
			System.out.println("不合矩阵相乘条件");
			System.exit(0);
		}
		Matrix result = new Matrix(A.getM(),B.getN());
		for(int i = 0;i<result.getM();i++)
			for(int j = 0;j<result.getN();j++){
				double temp=0;
				for(int p = 0;p<A.getN();p++){
					temp += A.getMatrix()[i][p]*B.getMatrix()[p][j];
					count++;
					}
				result.getMatrix()[i][j] = temp;
			}
		return result;
	}

	public static Matrix generateElement(Matrix input){
		Random random = new Random(47);
		for(int i = 0;i<input.getM();i++)
			for(int j = 0;j<input.getN();j++)
				input.getMatrix()[i][j] = random.nextDouble()*1000;
		
		return input;
	}
	public void setM(int m) {
		this.m = m;
	}

	public int getM() {
		return m;
	}

	public void setN(int n) {
		this.n = n;
	}

	public int getN() {
		return n;
	}

	public void setMatrix(double[][] matrix) {
		this.matrix = matrix;
	}

	public double [][] getMatrix() {
		return matrix;
	}
}
 
矩阵链类:
package matrixchain;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class MatrixChain {
	/**
	 * 存放矩阵
	 */
	private List<Matrix> matrixChain;
	
	/**
	 * 存放矩阵子链计算次数
	 */
	private  double[][] m;
	/**
	 * 存放子链最优划分位置
	 */
	private int[][] s;
	
	private static int[] sizeLimit;
	
	
	public MatrixChain(int[] sizeLimit){
		int n = sizeLimit.length-1;
		this.setMatrixChain(generateMatrixChain(sizeLimit));
		this.setM(new double[n][n]);
		this.setS(new int[n][n]);
	}
	
	/**
	 * 生成矩阵链
	 * @param sizeLimit 这是矩阵的行列数组 根据接着的两个元素表示一个矩阵的行与列 
	 * 矩阵链为如下形式:A0A1.....An-1
	 * @return
	 */
	private static  List<Matrix> generateMatrixChain(int[] sizeLimit){
		List<Matrix> resultChain = new ArrayList<Matrix>();
		for(int i =0;i<sizeLimit.length-1;i++){
			Matrix temp = new Matrix(sizeLimit[i],sizeLimit[i+1]);
			Matrix.generateElement(temp);
			resultChain.add(temp);
		} 
		return resultChain;
	}
	/**
	 * 生成numbs个整数表示n-1个矩阵的行与列大小
	 * 此处的15 并没有特殊含义,只是为了保证数组中元素值大小都不为0
	 * @param nums
	 */
	public static int[] generateMatrixDetails(int nums){
		Random random = new Random(47);
		sizeLimit = new int[nums];
		for(int i=0;i<nums;i++){
				sizeLimit[i]= (int)random.nextInt(nums);
			}
		return sizeLimit;
	}
	/**
	 * 普通 的矩阵链直接顺序乘法
	 * @param matrixChain
	 * @return
	 */
	public Matrix mutiplyMatrixChain(List<Matrix> matrixChain){
		checkMatrixChain(matrixChain);
		
		Matrix result =null;
		for(int i=0;i<matrixChain.size()-1;i++){
			if(i==0)
				result = Matrix.mutiplyMatrix(matrixChain.get(0), matrixChain.get(1));
			else
				result = Matrix.mutiplyMatrix(result, matrixChain.get(i+1));
		}
		return result;
	}
	/**
	 * 优化的矩阵链乘
	 * @param matrixChain
	 * @return
	 */
	public Matrix optimizedMutiplyMatixChain(List<Matrix> matrixChain){
		checkMatrixChain(matrixChain);
		
		Matrix result =null;
		long start = System.nanoTime();
		calMinMutiplyTimes(matrixChain,sizeLimit);
		long end = System.nanoTime();
		System.out.println("OptimizedMutiplyMatixChain Method calculate postion totally costs  "+(end-start)+"  nanoseconds");
		result = mutiplyMatrixChainOptimized(matrixChain,s,0,matrixChain.size()-1);
		return result;
	}
	
	 private Matrix mutiplyMatrixChainOptimized(List<Matrix> matrixChain,int[][]s,int i,int j){

         Matrix x,y;

         if (j>i){

                x=mutiplyMatrixChainOptimized(matrixChain,s,i,s[i][j]);

                y=mutiplyMatrixChainOptimized(matrixChain,s,s[i][j]+1,j);

                return Matrix.mutiplyMatrix(x, y);
                

         }

        return matrixChain.get(i); 

  }
	private void initMArray(){
		for(int i =0;i<m.length;i++)
			m[i][i] = 0;
		
		for(int i=0;i<m.length;i++)
			for(int j =i+1;j<m.length;j++)
				m[i][j] = 999999999999999999999.0;
	}
	
	private void initSArray(){
		for(int i=0;i<m.length;i++)
				s[i][i] = 0;
	}
	
	/**
	 * 计算最小乘法次数,将子链最小次数放在m数组中,将加括号位置放在s数组中
	 * @param matrixChain
	 * @param sizeLimit
	 * 算法思想:
	 * 1 计算出m[p][q]的起始值并将括号位置放在p位置
	 * 2在p q之间取出括号位置,计算出比原值小的m[p][q]值替换掉原值,并替换掉括号位置
	 */
	
	private void calMinMutiplyTimes(List<Matrix> matrixChain, int[] sizeLimit) {

		initMArray();
		initSArray();

		int n = matrixChain.size();
		for (int l = 2; l < n; l++) {

			for (int i = 1; i < (n - l + 1); i++) {

				int j = i + l - 1;

				for (int k = i; k < j; k++) {

					double q = m[i][k] + m[k + 1][j] + sizeLimit[i - 1]
							* sizeLimit[k] * sizeLimit[j];

					if (q < m[i][j]) {

						m[i][j] = q;

						s[i][j] = k;

					}

				}

			}

		}

	}
	
	/**
	 * 校验输入的矩阵链是否符合乘法运算条件
	 * @param matrixChain
	 */
	private void checkMatrixChain(List<Matrix> matrixChain){
		if(matrixChain.size()<2){
			System.out.println("矩阵个数小于2,不能做乘法");
			System.exit(0);
		}
	}
	
	
	public void setM(double[][] m) {
		this.m = m;
	}
	public double[][] getM() {
		return m;
	}
	public void setS(int[][] s) {
		this.s = s;
	}
	public int[][] getS() {
		return s;
	}
	public void setMatrixChain(List<Matrix> matrixChain) {
		this.matrixChain = matrixChain;
	}
	public List<Matrix> getMatrixChain() {
		return matrixChain;
	}
	
	public int[] getSizeLimit(){
		return MatrixChain.sizeLimit;
	}
}
 
测试类:
package matrixchain;
/**
 * 
 * @author shuofenglxy
 */
public class MatrixChainMuitiplyTest {

	
	public static void main(String[]args){
		
		MatrixChain demo = new MatrixChain(MatrixChain.generateMatrixDetails(200));
		
		
		/**优化的的矩阵链取最佳加括号位置相乘
		 * 统计消耗时间和共计的加法次数
		 */
		long startNomal = System.nanoTime();
		demo.optimizedMutiplyMatixChain(demo.getMatrixChain());
		long endNomal = System.nanoTime();
		System.out.println("OptimizedMutiplyMatixChain Method totally costs  "+(endNomal-startNomal)+"  nanoseconds");
		System.out.println("OptimizedMutiplyMatixChain Method caluclating times is: "+Matrix.getCount()+" times");
		
		//将统计运算次数总数归0
		Matrix.setCount(0);
		
		System.out.println("---------------------");
		
		/**正常的矩阵链直接相乘
		 * 统计消耗时间和共计的加法次数
		 */
		startNomal = System.nanoTime();
		demo.mutiplyMatrixChain(demo.getMatrixChain());
		endNomal = System.nanoTime();
		System.out.println("NomalMethod totally costs  "+(endNomal-startNomal)+"  nanoseconds");
		System.out.println("NomalMethod caluclating times is: "+Matrix.getCount()+" times");
	}
}
 

测试结果:

OptimizedMutiplyMatixChain Method calculate postion totally costs  19624306  nanoseconds
OptimizedMutiplyMatixChain Method totally costs  385362418  nanoseconds
OptimizedMutiplyMatixChain Method caluclating times is: 52070102 times
---------------------
NomalMethod totally costs  768270995  nanoseconds
NomalMethod caluclating times is: 111755908 times

 

结果分析:

可以看到,通过提前计算好加括号位置,虽然耗费一定的时间,但大大减少的运算次数,从而节省了矩阵链乘的总时间,这里相比原来的直接相乘,空间代价增加,但时间代价减少,也是经典的空间换时间思路。

参考资料:

http://hi.baidu.com/lewutian/blog/item/f46bf609b7c55ec53ac7633c.html

算法导论相关章节

 

1
1
分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics