Matrix Chain Multiplication

We are given a sequence (chain) <A1, A2, ..., An> of n matrices to be multiplied, and we wish to compute the product A1.A2...An.

Matrix multiplication is associative, and so all parenthesizations yield the same product. A product of matrices is fully parenthesized if it is either a single matrix or the product of two fully parenthesized matrix products, surrounded by parentheses.

For example, if the chain of matrices is <A1,A2,A3,A4>, then we can fully parenthesize the product A1.A2.A3.A4 in five distinct ways.

  • (A1.(A2.(A3.A4)))
  • (A1.((A2.A3).A4))
  • ((A1.A2).(A3.A4))
  • ((A1.(A2.A3)).A4)
  • (((A1.A2).A3).A4)

How we parenthesize a chain of matrices can have a dramatic impact on the cost of evaluating the product. Consider first the cost of multiplying two matrices with the standard algorithm:

MATRIX-MULTIPLY(A, B)
    if A.columns != B.rows
        error "incompatible dimensions"
    else
        let C be a new A.rows x B.columns matrix
        for i = 1 to A.rows
            for j = 1 to B.columns
                c_ij = 0
                for k = 1 to A.columns
                    c_ij = c_ij + a_ik * b_kj   // accumulate the dot product
        return C

Optimization problem

We state the matrix-chain multiplication problem as follows: given a chain <A1,...,An> of n matrices, where for i = 1..n, matrix A_i has dimension p_i-1 x p_i, fully parenthesize the product A1.A2...An in a way that minimizes the number of scalar multiplications.

Cost

We can multiply two matrices A and B only if they are compatible: the number of columns of A must equal the number of rows of B. If A is a p x q matrix and B is a q x r matrix, the resulting matrix C is a p x r matrix. The time to compute C is dominated by the number of scalar multiplications, which is p * q * r.
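To make the impact of parenthesization concrete, here is a small Python sketch that applies the p * q * r cost to a chain of three matrices. The dimensions (10 x 100, 100 x 5, 5 x 50) and the helper name `multiply_cost` are chosen for illustration only and are not part of the original text.

```python
def multiply_cost(p, q, r):
    """Scalar multiplications needed to multiply a p x q matrix by a q x r matrix."""
    return p * q * r


# Assumed example dimensions: A1 is 10 x 100, A2 is 100 x 5, A3 is 5 x 50.

# ((A1.A2).A3): compute A1.A2 (a 10 x 5 matrix), then multiply it by A3.
cost_left = multiply_cost(10, 100, 5) + multiply_cost(10, 5, 50)

# (A1.(A2.A3)): compute A2.A3 (a 100 x 50 matrix), then multiply A1 by it.
cost_right = multiply_cost(100, 5, 50) + multiply_cost(10, 100, 50)

print(cost_left)   # 7500
print(cost_right)  # 75000
```

With these dimensions, the first parenthesization is ten times cheaper than the second, even though both yield the same product.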

Counting number of parenthesizations

Before solving the matrix-chain multiplication problem by dynamic programming, let us convince ourselves that exhaustively checking all possible parenthesizations does not yield an efficient algorithm.

Denote the number of alternative parenthesizations of a sequence of n matrices by P(n). When n = 1, we have just one matrix and therefore P(1) = 1. When n >= 2, a fully parenthesized matrix product is the product of two fully parenthesized matrix subproducts, and the split between the two subproducts may occur between the kth and (k+1)st matrices for any k = 1, ..., n-1. Thus, we obtain the recurrence:

P(n) = 1 // if n = 1
P(n) = Sum(from k=1 to n-1): P(k) * P(n-k) // if n >= 2

A solution to a similar recurrence is the sequence of Catalan numbers, which grows as Omega(4^n / n^(3/2)). The number of solutions is thus exponential in n, and the brute-force method of exhaustive search makes for a poor strategy when determining how to optimally parenthesize a matrix chain.
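The recurrence for P(n) can be evaluated directly. The following Python sketch (the function names `count_parens` and `catalan` are illustrative) computes P(n) and checks it against the closed form for the Catalan numbers, using the known identity P(n) = C(n-1), where C(n) = binomial(2n, n) / (n + 1).

```python
from functools import lru_cache
from math import comb


@lru_cache(maxsize=None)
def count_parens(n):
    """Number of full parenthesizations of a chain of n matrices, P(n)."""
    if n == 1:
        return 1
    # Split between the kth and (k+1)st matrices for k = 1, ..., n-1.
    return sum(count_parens(k) * count_parens(n - k) for k in range(1, n))


def catalan(n):
    """nth Catalan number via the closed form binomial(2n, n) / (n + 1)."""
    return comb(2 * n, n) // (n + 1)


# P(n) coincides with the (n-1)st Catalan number.
for n in range(1, 11):
    assert count_parens(n) == catalan(n - 1)

print([count_parens(n) for n in range(1, 8)])  # [1, 1, 2, 5, 14, 42, 132]
```

Note that P(4) = 5 agrees with the five parenthesizations of <A1,A2,A3,A4> listed earlier.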

Applying dynamic programming

Step 1: Structure of an optimal parenthesization

Let's adopt the notation A_i_j, where i <= j, for the matrix that results from evaluating the product A_i.A_i+1...A_j. Observe that if the problem is nontrivial (i < j), then to parenthesize the product, we must split the product between A_k and A_k+1, for some integer k in the range i <= k < j. That is, for some value k, we first compute the matrices A_i_k and A_k+1_j, and then multiply them together to produce the final product A_i_j. The cost of parenthesizing this way is the cost of computing the matrix A_i_k plus the cost of computing A_k+1_j, plus the cost of multiplying them together.

The optimal substructure of this problem is as follows.

The way we parenthesize the "prefix" subchain A_i_k within the optimal parenthesization of A_i_j must itself be an optimal parenthesization of A_i_k; otherwise, substituting a cheaper parenthesization of A_i_k would lower the total cost. The same holds for the subchain A_k+1_j.

Now we use our optimal substructure to show that we can construct an optimal solution to the problem from optimal solutions to subproblems.

Step 2: A recursive solution

We define the cost of an optimal solution recursively in terms of the optimal solutions to subproblems.

For the matrix-chain multiplication problem, we pick as our subproblems the problems of determining the minimum cost of parenthesizing A_i_j for 1 <= i <= j <= n.

Let m[i,j] be the minimum number of scalar multiplications needed to compute the matrix A_i_j; for the full problem, the lowest-cost way to compute A_1_n would thus be m[1,n].

We can define m[i,j] recursively as follows:

If i = j, the problem is trivial: the chain consists of just one matrix A_i_i = A_i, so no scalar multiplications are necessary to compute the product. Thus, m[i,i] = 0.

To compute m[i,j] when i < j, we take advantage of the structure of an optimal solution from step 1. Let us assume that to optimally parenthesize, we split the product between A_k and A_k+1. Recalling that each matrix A_i is p_i-1 x p_i, we see that computing the matrix product A_i_k * A_k+1_j takes p_i-1 * p_k * p_j scalar multiplications. Thus m[i,j] = m[i,k] + m[k+1,j] + p_i-1 * p_k * p_j. Since we do not know the best k in advance, we take the minimum over all possible values:

m[i,j] = 0 // if i == j
m[i,j] = min(i <= k < j): m[i,k] + m[k+1,j] + p_i-1 * p_k * p_j // if i < j

The m[i,j] values give only the costs of optimal solutions to subproblems; they do not say how to construct an optimal solution. To help us do so, we define s[i,j] to be a value of k at which we split the product A_i_j in an optimal parenthesization.

Step 3: Computing optimal costs

We could write a recursive algorithm based on the previous recurrence to compute the minimum cost m[1,n] for multiplying A_1_n, but it would take exponential time, which is no better than the brute-force method.
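As a rough illustration of why the plain recursion is expensive, the Python sketch below implements the recurrence directly and counts how many calls it makes. The function name `recursive_matrix_chain`, the one-element `counter` list, and the sample dimensions are assumptions made for this example.

```python
def recursive_matrix_chain(p, i, j, counter):
    """Minimum scalar multiplications for A_i..A_j, straight from the recurrence.

    p is the dimension sequence <p_0, ..., p_n>; counter is a one-element
    list used only to count how many calls are made.
    """
    counter[0] += 1
    if i == j:
        return 0
    best = float("inf")
    for k in range(i, j):  # try every split point i <= k < j
        q = (recursive_matrix_chain(p, i, k, counter)
             + recursive_matrix_chain(p, k + 1, j, counter)
             + p[i - 1] * p[k] * p[j])
        best = min(best, q)
    return best


p = [10, 100, 5, 50]  # assumed example: three matrices
calls = [0]
print(recursive_matrix_chain(p, 1, len(p) - 1, calls))  # 7500
print(calls[0])  # 9
```

For a chain of n matrices this version makes 3^(n-1) calls, because the same subproblems are solved again and again.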

Observe that we have relatively few distinct subproblems: one subproblem for each choice of i and j satisfying 1 <= i <= j <= n, or (n choose 2) + n in all, which is Theta(n^2).

Instead of computing the solution to the recurrence recursively, we compute the optimal cost by using a bottom-up approach.

The following procedure assumes that matrix A_i has dimensions p_i-1 x p_i for i = 1...n. Its input is a sequence p = <p_0, ..., p_n>, where p.length = n + 1.

MATRIX-CHAIN-ORDER(p)       // O(n^3) time and Theta(n^2) space
    n = p.length - 1
    let m[1..n, 1..n] and s[1..n-1, 2..n] be new tables
    for i = 1 to n
        m[i,i] = 0  // chains of length 1 cost nothing
    for l = 2 to n  // l is chain length
        for i = 1 to n - l + 1
            j = i + l - 1
            m[i,j] = infinity
            for k = i to j - 1
                q = m[i,k] + m[k+1, j] + p_i-1 * p_k * p_j
                if q < m[i,j]
                    m[i,j] = q  // best cost found so far
                    s[i,j] = k  // split point that achieves it
    return m and s
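A possible Python rendering of the bottom-up procedure is sketched below. It keeps the 1-based indexing of the pseudocode by leaving row and column 0 of each table unused, which is a choice made here for readability; the sample dimension sequence is an assumed example.

```python
import math


def matrix_chain_order(p):
    """Bottom-up computation of the cost table m and the split table s.

    p is the dimension sequence <p_0, ..., p_n>; matrix A_i is p[i-1] x p[i].
    Tables are indexed from 1; index 0 is unused padding.
    """
    n = len(p) - 1
    m = [[0] * (n + 1) for _ in range(n + 1)]  # m[i][i] = 0 already
    s = [[0] * (n + 1) for _ in range(n + 1)]
    for length in range(2, n + 1):             # chain length
        for i in range(1, n - length + 2):
            j = i + length - 1
            m[i][j] = math.inf
            for k in range(i, j):              # split between A_k and A_k+1
                q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j]
                if q < m[i][j]:
                    m[i][j] = q
                    s[i][j] = k
    return m, s


p = [30, 35, 15, 5, 10, 20, 25]  # assumed example: six matrices
m, s = matrix_chain_order(p)
print(m[1][6])  # 15125
print(s[1][6])  # 3
```

The three nested loops each run at most n times, which gives the O(n^3) running time, and the two tables account for the Theta(n^2) space.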

Step 4: Constructing optimal solution

Although we now have the optimal number of scalar multiplications, we do not yet know how to multiply the matrices. The table s gives us that information: each entry s[i,j] records a value of k such that an optimal parenthesization of A_i_j splits the product between A_k and A_k+1.

PRINT-OPTIMAL-PARENS(s, i, j)
    if i == j
        print "A"_i     // a single matrix, printed with its index
    else
        print "("
        PRINT-OPTIMAL-PARENS(s, i, s[i,j])
        PRINT-OPTIMAL-PARENS(s, s[i,j] + 1, j)
        print ")"

We invoke this function as PRINT-OPTIMAL-PARENS(s, 1, n).
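As a minimal end-to-end sketch in Python, the reconstruction can return the parenthesization as a string instead of printing it piece by piece. The function names and the sample dimensions are assumptions for this example; `matrix_chain_order` is repeated in compact form so the block runs on its own.

```python
import math


def matrix_chain_order(p):
    """Compact bottom-up computation of the split table s (1-based indices)."""
    n = len(p) - 1
    m = [[0] * (n + 1) for _ in range(n + 1)]
    s = [[0] * (n + 1) for _ in range(n + 1)]
    for length in range(2, n + 1):
        for i in range(1, n - length + 2):
            j = i + length - 1
            m[i][j] = math.inf
            for k in range(i, j):
                q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j]
                if q < m[i][j]:
                    m[i][j], s[i][j] = q, k
    return m, s


def optimal_parens(s, i, j):
    """Return an optimal parenthesization of A_i..A_j as a string."""
    if i == j:
        return f"A{i}"
    k = s[i][j]  # optimal split: (A_i..A_k)(A_k+1..A_j)
    return "(" + optimal_parens(s, i, k) + optimal_parens(s, k + 1, j) + ")"


p = [30, 35, 15, 5, 10, 20, 25]  # assumed example: six matrices
m, s = matrix_chain_order(p)
print(optimal_parens(s, 1, len(p) - 1))  # ((A1(A2A3))((A4A5)A6))
```

Returning a string rather than printing makes the result easier to check in tests; the recursion itself follows the pseudocode above.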