Skip to content

Latest commit

 

History

History
327 lines (262 loc) · 7.88 KB

张量相乘的最小开销问题.md

File metadata and controls

327 lines (262 loc) · 7.88 KB

张量相乘的最小开销问题


题目

by 吴世光

描述

规定一种张量乘法, 对于 $$K(K \geq 2)$$ 维张量, 前 $$K - 2$$ 维按张量广播乘法, 后两维按矩阵乘法. 现有 $$n$$$$K$$ 维张量按此相乘, 可以为其加上若干对括号, 求最少标量乘法次数.

提示: 张量个数小于 $$2048$$, 维数小于 $$32$$, 每个维度的大小小于 $$1000$$.

IO格式

输入:

n K
第1个张量第1维长度...
...

输出:

最少标量乘法次数

样例

输入:

3 3
1 1 2
1 2 3
10 3 4

输出:

126

解答

思路

这是一道典型的区间 DP.

对于动态规划问题, 我们的解题思路是: 设计状态函数、求状态函数初值、写出转移方程、化简转移方程、遍历求解.

状态函数

dp(i, j) 为从第 i 个张量一直连乘到第 j 个张量的最少标量乘法次数. 我们不难得出以下两个结论:

  • 我们的所求为 dp(0, n-1);
  • 我们已经得知的初值为 dp(i, i) = 0 for i = 0 ... n-1.

转移方程

对于一个添加括号的行为, 其结果是将一部分的乘法运算独立, 因此其最简单的就是将一连串的乘法分为两段 (构造最优子结构); 因此, 转移方程可以写为如下:

dp(i, j) = min(
    dp(i, k) + dp(k+1, j) + cost(i, j, k)
    for k = i ... j-1
)

这里, cost(i, j, k) 是将两部分的乘法结果合并所产生的花销, 关于它究竟是什么暂且按下不谈.

求解

现在我们来考虑一下 ij 的取值.

考虑如下的一张 dp 表, 我们已经填入了初值; 表中最右上角即为我们的目标解:

$$\begin{bmatrix} 0 & \infty & \infty & \cdots & \infty \\ & 0 & \infty & \ddots & \vdots \\ & & 0 & \ddots & \vdots \\ & & & \ddots & \infty \\ & & & & 0 \end{bmatrix}$$

根据状态方程, 对于一个待求的 dp(i, j), 我们需要知道在第 i 行和第 j 列上介于 dp(i, j)dp(i, i)dp(j, j) 的值:

$$\begin{bmatrix} \ddots \\ & 0 & \cdots & \infty \\ & & \ddots & \vdots \\ & & & 0 \\ & & & & \ddots \end{bmatrix}$$

因此, 我们将会遵循以下的遍历顺序:

  • s = 0 为对角线, s 可以赋予意义为 dp 表的斜行坐标, 不过它在问题中的真实意义为我们对问题拆分的小问题的步长.
  • for s = 1 ... n-1, 我们有 j = i + s for i = 0 ... n-1-s.

这个遍历为 $$O(n^2)$$ 的遍历.

合并

最后, 我们来看一下 cost(i, j, k). 对于两段已经乘好的张量, 我们将其合并, 我们可以得出如下两部分花销:

  • 一是前面 $$K - 2$$ 维的结果恒定, 并且张量乘法遵循广播, 所以由于维度导致的标量乘法次数系数是不变的: cost1(i, j) = prod(max(Tensor[i ... j]D(k)) for k = 1 ... K-2).
  • 二是在 $$l \times m$$$$m \times n$$ 矩阵乘的过程中, 这一部分的花销形如 $$lmn$$. $$m$$ 是变化的, 但是 $$l$$$$n$$ 是不变的: cost2(i, j) = Tensor[i]D(K-1) * Tensor[j]D(K), cost3(k) = Tensor[k]D(K).

故, cost(i, j, k) = cost1(i, j) * cost2(i, j) * cost3(k).

代码

初步思路

最初为了计算 cost(i, j, k), 我是希望将计算过程全部记录下来的, 但是这不仅费空间, 而且不是一个好的算法:

#include <cstdio>
#include <climits>
using namespace std;

#define prf printf
#define scf scanf
using ui = unsigned int;
using us = unsigned short;

ui n, K, Ks;

struct Tensor {
    Tensor() : dims(new us[Ks]), mat(new us[2]) {}
    ~Tensor() {
        delete[] dims;
        delete[] mat;
    }
    us *dims, *mat;
    void in() {
        for (ui i = 0; i < Ks; i++) scf("%hu", &dims[i]);
        scf("%hu%hu", &mat[0], &mat[1]);
    }
};

struct Cost {
    Cost() : val(UINT_MAX) {}
    Cost(ui _val) : val(_val) {}
    ui val;
    operator ui() { return val; }
    bool operator<(Cost& other) { return val < other.val; }
};

struct Dp {
    Dp(ui _n) : n(_n), __arr(new Cost[_n * _n]), __res(new Tensor[_n * _n]) {}
    ~Dp() {
        delete[] __arr;
        delete[] __res;
    }
    ui n;
    Cost* __arr;
    Tensor* __res;
    Cost& operator()(ui i, ui j) { return __arr[i * n + j]; }
    Tensor& result(ui i, ui j) { return __res[i * n + j]; }
};

inline ui cnt(Tensor& a, Tensor& b, Tensor& res) {
    ui ret = 1;
    for (ui i = 0; i < Ks; i++)
        ret *= res.dims[i] = a.dims[i] == 1 ? b.dims[i] : a.dims[i];
    ret *= (res.mat[0] = a.mat[0]) * b.mat[0] * (res.mat[1] = b.mat[1]);
    return ret;
}

int main() {
    scf("%u%u", &n, &K);
    Ks = K - 2;
    Dp dp(n);
    for (ui i = 0; i < n; i++) {
        dp.result(i, i).in();
        dp(i, i) = 0;
    }
    ui rend = n - 1;
    for (ui r = 1; r <= rend; r++) {
        ui iend = n - 1 - r;
        for (ui i = 0; i <= iend; i++) {
            ui j = i + r;
            Cost& cur = dp(i, j);
            Tensor& curRes = dp.result(i, j);
            for (ui k = i; k < j; k++) {
                Cost cost = dp(i, k) + dp(k + 1, j) +
                            cnt(dp.result(i, k), dp.result(k + 1, j), curRes);
                if (cost < dp(i, j)) cur = cost;
            }
        }
    }
    prf("%u", ui(dp(0, n - 1)));
    return 0;
}

最终提交

#include <cstdio>
#include <climits>
using namespace std;

#define prf printf
#define scf scanf
using ui = unsigned int;
using us = unsigned short;

ui n, K, Ks;

struct Tensor {
    Tensor() : dims(new us[Ks]), mat(new us[2]) {}
    ~Tensor() {
        delete[] dims;
        delete[] mat;
    }
    us *dims, *mat;
    void in() {
        for (ui i = 0; i < Ks; i++) scf("%hu", &dims[i]);
        scf("%hu%hu", &mat[0], &mat[1]);
    }
};

struct Cost {
    Cost() : val(UINT_MAX) {}
    Cost(ui _val) : val(_val) {}
    ui val;
    operator ui() { return val; }
    bool operator<(Cost& other) { return val < other.val; }
};

struct Dp {
    Dp(ui _n) : n(_n), __arr(new Cost[_n * _n]), __ten(new Tensor[_n]) {}
    ~Dp() {
        delete[] __arr;
        delete[] __ten;
    }
    ui n;
    Cost* __arr;
    Tensor* __ten;
    Cost& operator()(ui i, ui j) { return __arr[i * n + j]; }
    Tensor& result(ui i) { return __ten[i]; }
};

inline ui cnt(Dp& dp, ui i, ui j) {
    ui ret = 1;
    for (ui kk = 0; kk < Ks; kk++) {
        for (ui ii = i; ii <= j; ii++) {
            ui dim = dp.result(ii).dims[kk];
            if (dim != 1) {
                ret *= dim;
                break;
            }
        }
    }
    return ret;
}

inline ui cnt(Dp& dp, ui i, ui j, ui k) {
    return dp.result(i).mat[0] * dp.result(k).mat[1] * dp.result(j).mat[1];
}

int main() {
    scf("%u%u", &n, &K);
    Ks = K - 2;
    Dp dp(n);
    for (ui i = 0; i < n; i++) {
        dp.result(i).in();
        dp(i, i) = 0;
    }
    ui rend = n - 1;
    for (ui r = 1; r <= rend; r++) {
        ui iend = n - 1 - r;
        for (ui i = 0; i <= iend; i++) {
            ui j = i + r;
            Cost& cur = dp(i, j);
            ui curCntKs = cnt(dp, i, j);
            for (ui k = i; k < j; k++) {
                Cost cost =
                    dp(i, k) + dp(k + 1, j) + curCntKs * cnt(dp, i, j, k);
                if (cost < cur) cur = cost;
            }
        }
    }
    prf("%u", ui(dp(0, n - 1)));
    return 0;
}

时空消耗

1	Accepted	0 ms	832 KB
2	Accepted	0 ms	832 KB
3	Accepted	60 ms	1716 KB
4	Accepted	0 ms	828 KB
5	Accepted	60 ms	1716 KB
6	Accepted	0 ms	836 KB
7	Accepted	64 ms	1712 KB
8	Accepted	164 ms	1712 KB
9	Accepted	2832 ms	13564 KB
10	Accepted	2844 ms	13564 KB