We are given a sequence (chain) <A1, A2, ..., An> of n matrices to be multiplied, and we wish to compute the product A1.A2...An.
Matrix multiplication is associative, so all parenthesizations yield the same product. A product of matrices is fully parenthesized if it is either a single matrix or the product of two fully parenthesized matrix products, surrounded by parentheses.
For example, if the chain of matrices is <A1,A2,A3,A4>, then we can fully parenthesize the product A1.A2.A3.A4 in five distinct ways:
(A1.(A2.(A3.A4)))
(A1.((A2.A3).A4))
((A1.A2).(A3.A4))
((A1.(A2.A3)).A4)
(((A1.A2).A3).A4)
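The five forms above can be generated mechanically from the definition of a fully parenthesized product. The following is a minimal Python sketch; the function name `parenthesizations` and the use of strings such as "A1" are assumptions made for illustration only.

```python
def parenthesizations(names):
    """Yield every full parenthesization of the chain `names` as a string."""
    if len(names) == 1:
        # A single matrix is fully parenthesized by definition.
        yield names[0]
        return
    # Otherwise split the chain into two non-empty parts at every position.
    for k in range(1, len(names)):
        for left in parenthesizations(names[:k]):
            for right in parenthesizations(names[k:]):
                yield f"({left}.{right})"


for form in parenthesizations(["A1", "A2", "A3", "A4"]):
    print(form)
```

For a chain of four matrices this yields five strings, in the same order as the list above.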
How we parenthesize a chain of matrices can have a dramatic impact on the cost of evaluating the product. Consider first the cost of multiplying two matrices with the standard algorithm:
MATRIX-MULTIPLY(A, B)
    if A.columns != B.rows
        error "incompatible dimensions"
    else
        let C be a new A.rows x B.columns matrix
        for i = 1 to A.rows
            for j = 1 to B.columns
                c_ij = 0
                for k = 1 to A.columns
                    c_ij = c_ij + a_ik * b_kj // accumulate the inner product
        return C
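As a rough Python rendering of the standard triple-loop algorithm, the sketch below represents matrices as lists of rows and also counts the scalar multiplications it performs. The name `matrix_multiply` and the extra counter are assumptions added for illustration; they are not part of the pseudocode.

```python
def matrix_multiply(a, b):
    """Multiply matrices given as lists of rows.

    Returns the product and the number of scalar multiplications performed.
    """
    if not a or not b or len(a[0]) != len(b):
        raise ValueError("incompatible dimensions")
    rows, inner, cols = len(a), len(b), len(b[0])
    c = [[0] * cols for _ in range(rows)]
    mults = 0
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                # Accumulate the inner product of row i of a and column j of b.
                c[i][j] += a[i][k] * b[k][j]
                mults += 1
    return c, mults


product, mults = matrix_multiply([[1, 2, 3], [4, 5, 6]],
                                 [[7, 8], [9, 10], [11, 12]])
print(product, mults)  # [[58, 64], [139, 154]] 12
```

The counter equals rows x inner x cols, here 2 x 3 x 2 = 12.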
We state the matrix-chain multiplication problem as follows: given a chain <A1,...,An> of n matrices, where for i=1..n, matrix A_i has dimension p_i-1 x p_i, fully parenthesize the product A1.A2...An in a way that minimizes the number of scalar multiplications.
We can multiply two matrices A and B only if they are compatible: the number of columns of A (A.columns) must equal the number of rows of B (B.rows). If A is a p x q matrix and B is a q x r matrix, the resulting matrix C is a p x r matrix. The time to compute C is dominated by the number of scalar multiplications, which is p * q * r.
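To see how much the order matters, here is a small worked example in Python. The dimensions 10 x 100, 100 x 5 and 5 x 50 are illustrative values, and `multiply_cost` is a hypothetical helper that simply applies the p * q * r formula.

```python
def multiply_cost(p, q, r):
    """Scalar multiplications needed to multiply a p x q matrix by a q x r matrix."""
    return p * q * r


# A1 is 10 x 100, A2 is 100 x 5, A3 is 5 x 50.
# ((A1.A2).A3): first a 10 x 5 result, then multiply it by A3.
cost_left = multiply_cost(10, 100, 5) + multiply_cost(10, 5, 50)
# (A1.(A2.A3)): first a 100 x 50 result, then multiply A1 by it.
cost_right = multiply_cost(100, 5, 50) + multiply_cost(10, 100, 50)

print(cost_left, cost_right)  # 7500 75000
```

Both orders produce the same 10 x 50 matrix, but the first needs ten times fewer scalar multiplications.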
Before solving the matrix-chain multiplication problem by dynamic programming, let us convince ourselves that exhaustively checking all possible parenthesizations does not yield an efficient algorithm.
Denote the number of alternative parenthesizations of a sequence of n matrices by P(n). When n=1, we have just one matrix and therefore only one way to fully parenthesize it, so P(1) = 1. When n >= 2, a fully parenthesized matrix product is the product of two fully parenthesized matrix subproducts, and the split between the two subproducts may occur between the kth and (k+1)st matrices for any k = 1, ..., n-1. Thus, we obtain the recurrence:
P(n) = 1 // if n = 1
P(n) = Sum(from k=1 to n-1): P(k) * P(n-k) // if n >= 2
A solution to a similar recurrence is the sequence of Catalan numbers, which grows as Omega(4^n / n^(3/2)). The number of solutions is thus exponential in n, and the brute-force method of exhaustive search makes for a poor strategy when determining how to optimally parenthesize a matrix chain.
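The recurrence for P(n) can be evaluated directly to get a feel for the growth. This is a minimal sketch; the function name `count_parenthesizations` is an assumption, and memoization via `functools.lru_cache` is used only so that the count itself is cheap to compute.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_parenthesizations(n):
    """Number of full parenthesizations of a chain of n matrices, P(n)."""
    if n == 1:
        return 1
    # Sum over every split point k between the kth and (k+1)st matrices.
    return sum(count_parenthesizations(k) * count_parenthesizations(n - k)
               for k in range(1, n))


print([count_parenthesizations(n) for n in range(1, 7)])  # [1, 1, 2, 5, 14, 42]
print(count_parenthesizations(20))  # 1767263190
```

A chain of only 20 matrices already has more than a billion parenthesizations.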
Let's adopt the notation A_i_j, where i <= j, for the matrix that results from evaluating the product A_i.A_i+1...A_j. Observe that if the problem is nontrivial (i < j), then to parenthesize the product, we must split the product between A_k and A_k+1, for some integer k in the range i <= k < j. That is, for some value of k, we first compute the matrices A_i_k and A_k+1_j, and then multiply them together to produce the final product A_i_j. The cost of parenthesizing this way is the cost of computing the matrix A_i_k, plus the cost of computing A_k+1_j, plus the cost of multiplying them together.
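The optimal substructure of this problem is as follows. Suppose that an optimal parenthesization of A_i_j splits the product between A_k and A_k+1. Then the way we parenthesize the "prefix" subchain A_i_k within this optimal parenthesization of A_i_j must itself be an optimal parenthesization of A_i_k: if a cheaper one existed, substituting it would lower the total cost of A_i_j, a contradiction. The same holds for the subchain A_k+1_j.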
Now we use our optimal substructure to show that we can construct an optimal solution to the problem from optimal solutions to subproblems.
We define the cost of an optimal solution recursively in terms of the optimal solutions to subproblems.
For the matrix-chain multiplication problem, we pick as our subproblems the problems of determining the minimum cost of parenthesizing A_i_j for 1 <= i <= j <= n.
Let m[i,j] be the minimum number of scalar multiplications needed to compute the matrix A_i_j; for the full problem, the cost of the cheapest way to compute A_1_n would thus be m[1,n].
We can define m[i,j] recursively as follows:
If i = j, the problem is trivial: the chain consists of just one matrix A_i_i = A_i, so no scalar multiplications are necessary to compute the product. Thus, m[i,i] = 0.
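To compute m[i,j] when i < j, we take advantage of the structure of an optimal solution from step 1. Let us assume that to optimally parenthesize, we split the product between A_k and A_k+1, where i <= k < j. Then m[i,j] equals the minimum cost of computing the subproducts A_i_k and A_k+1_j, plus the cost of multiplying these two matrices together. Recalling that each matrix A_i is p_i-1 x p_i, we see that computing the matrix product A_i_k * A_k+1_j takes p_i-1 * p_k * p_j scalar multiplications. Thus m[i,j] = m[i,k] + m[k+1,j] + p_i-1 * p_k * p_j. This equation assumes that we know the value of k, which we do not. Since k can take only j - i possible values, we check them all and keep the best:
m[i,j] = 0 // if i == j
m[i,j] = min(i <= k < j): m[i,k] + m[k+1,j] + p_i-1 * p_k * p_j // if i < j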
The values m[i,j] give only the costs of optimal solutions to subproblems, not the information needed to construct an optimal solution. To help us do so, we define s[i,j] to be a value of k at which we split the product A_i_j in an optimal parenthesization.
We can write a recursive algorithm based on the previous recurrence to compute the minimum cost m[1,n] for multiplying A_1_n, but it would take exponential time and be no better than the brute-force method.
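A direct recursive evaluation of the recurrence might look like the following Python sketch. The function name `recursive_matrix_chain` and the sample dimension list are assumptions; `p` is taken to be the list of dimensions with matrix A_i of size p[i-1] x p[i]. The same subproblems are solved over and over, which is where the exponential running time comes from.

```python
def recursive_matrix_chain(p, i, j):
    """Minimum scalar multiplications to compute A_i..A_j, straight from the recurrence.

    p is the dimension list: matrix A_i has size p[i-1] x p[i].
    """
    if i == j:
        return 0
    best = float("inf")
    for k in range(i, j):
        # Cost of both subproducts plus the final multiplication at split k.
        q = (recursive_matrix_chain(p, i, k)
             + recursive_matrix_chain(p, k + 1, j)
             + p[i - 1] * p[k] * p[j])
        best = min(best, q)
    return best


p = [30, 35, 15, 5, 10, 20, 25]  # six matrices, illustrative dimensions
print(recursive_matrix_chain(p, 1, len(p) - 1))  # 15125
```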
Observe that we have relatively few distinct subproblems: one subproblem for each choice of i and j satisfying 1 <= i <= j <= n, or C(n, 2) + n in total (where C(n, 2) is the binomial coefficient "n choose 2"), which is Theta(n^2).
Instead of computing the solution to the recurrence recursively, we compute the optimal cost by using a bottom-up approach.
The following procedure assumes that matrix A_i has dimensions p_i-1 x p_i for i = 1...n. Its input is a sequence p = <p_0, ..., p_n>, where p.length = n + 1.
MATRIX-CHAIN-ORDER(p) // O(n^3) time and Theta(n^2) space
    n = p.length - 1
    let m[1..n, 1..n] and s[1..n-1, 2..n] be new tables
    for i = 1 to n
        m[i,i] = 0 // end of variable initialization
    for l = 2 to n // l is the chain length
        for i = 1 to n - l + 1
            j = i + l - 1
            m[i,j] = infinity
            for k = i to j - 1
                q = m[i,k] + m[k+1,j] + p_i-1 * p_k * p_j
                if q < m[i,j]
                    m[i,j] = q
                    s[i,j] = k
    return m and s
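A Python sketch of the same bottom-up computation is given below. To keep the 1-based indices of the recurrence, it allocates (n+1) x (n+1) lists and leaves row and column 0 unused; that layout, the name `matrix_chain_order`, and the sample dimensions are assumptions made for illustration.

```python
import math


def matrix_chain_order(p):
    """Bottom-up computation of the cost table m and the split table s.

    p is the dimension list: matrix A_i has size p[i-1] x p[i].
    Tables are indexed from 1; index 0 is unused padding.
    """
    n = len(p) - 1
    m = [[0] * (n + 1) for _ in range(n + 1)]  # m[i][i] = 0 already
    s = [[0] * (n + 1) for _ in range(n + 1)]
    for length in range(2, n + 1):           # chain length
        for i in range(1, n - length + 2):
            j = i + length - 1
            m[i][j] = math.inf
            for k in range(i, j):            # try every split point
                q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j]
                if q < m[i][j]:
                    m[i][j] = q
                    s[i][j] = k
    return m, s


m, s = matrix_chain_order([30, 35, 15, 5, 10, 20, 25])
print(m[1][6], s[1][6])  # 15125 3
```

Shorter chains are filled in before longer ones, so every m[i][k] and m[k+1][j] is already available when m[i][j] is computed.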
Although we now know the optimal number of scalar multiplications, we do not yet know how to multiply the matrices. The table s gives us that information: each entry s[i,j] records a value of k such that an optimal parenthesization of A_i_j splits the product between A_k and A_k+1.
PRINT-OPTIMAL-PARENS(s, i, j)
    if i == j
        print "A"_i
    else
        print "("
        PRINT-OPTIMAL-PARENS(s, i, s[i,j])
        PRINT-OPTIMAL-PARENS(s, s[i,j] + 1, j)
        print ")"
We invoke this function as PRINT-OPTIMAL-PARENS(s, 1, n).
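The reconstruction step can be sketched in Python as follows. For easier testing this version returns a string instead of printing; the name `optimal_parens`, the dictionary representation of s, and the small hand-derived split table (for the illustrative dimensions <10, 100, 5, 50>) are all assumptions.

```python
def optimal_parens(s, i, j):
    """Return an optimal parenthesization of A_i..A_j as a string.

    s maps a pair (i, j) with i < j to the split point k chosen for A_i..A_j.
    """
    if i == j:
        return f"A{i}"
    k = s[(i, j)]
    return "(" + optimal_parens(s, i, k) + optimal_parens(s, k + 1, j) + ")"


# Split table worked out by hand for dimensions <10, 100, 5, 50>.
s = {(1, 2): 1, (2, 3): 2, (1, 3): 2}
print(optimal_parens(s, 1, 3))  # ((A1A2)A3)
```

Each call emits one pair of parentheses per multiplication, so the output is a fully parenthesized product.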