-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
sum-of-total-strength-of-wizards.cpp
84 lines (82 loc) · 3.53 KB
/
sum-of-total-strength-of-wizards.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Time: O(n)
// Space: O(n)
// mono stack, prefix sum, optimized from solution2
//
// totalStrength returns, modulo 1e9 + 7, the sum over every non-empty
// contiguous subarray of min(subarray) * sum(subarray).
//
// Let P[k] = strength[0] + ... + strength[k - 1] (P[0] = 0), so subarray
// [l, r] sums to P[r + 1] - P[l], and let pp[k] = P[1] + ... + P[k] (pp[0] = 0).
// A monotonic stack hands every subarray to its rightmost minimum index mid.
// Every value in [lo, mid) is >= strength[mid], and every value in (mid, hi]
// is > strength[mid]. For the window [lo, hi] owned by mid, the sum over all
// l in [lo, mid], r in [mid, hi] of (P[r + 1] - P[l]) equals
//   (mid - lo + 1) * (pp[hi + 1] - pp[mid]) - (hi - mid + 1) * (pp[mid] - pp[lo - 1]),
// where pp[-1] is taken as 0.
class Solution {
public:
    int totalStrength(vector<int>& strength) {
        static const int MOD = 1e9 + 7;
        // Modular helpers; operands are expected to already lie in [0, MOD).
        const auto addm = [](const int64_t a, const int64_t b) -> int64_t { return (a + b) % MOD; };
        const auto subm = [](const int64_t a, const int64_t b) -> int64_t { return (a - b + MOD) % MOD; };
        const auto mulm = [](const int64_t a, const int64_t b) -> int64_t { return (a * b) % MOD; };

        const int n = static_cast<int>(size(strength));

        // pp[k] = P[1] + ... + P[k]; `running` holds P[k] while pp[k] is filled in.
        vector<int64_t> pp(n + 1, 0);
        int64_t running = 0;
        for (int k = 1; k <= n; ++k) {
            running = addm(running, strength[k - 1]);
            pp[k] = addm(pp[k - 1], running);
        }

        // strength[mid] times the total sum of all subarrays [l, r]
        // with lo <= l <= mid <= r <= hi.
        const auto contribution = [&](const int lo, const int mid, const int hi) -> int64_t {
            const int64_t rightSums = subm(pp[hi + 1], pp[mid]);                   // sum of P[r + 1], r in [mid, hi]
            const int64_t leftSums = subm(pp[mid], lo > 0 ? pp[lo - 1] : pp[0]);   // sum of P[l], l in [lo, mid]
            const int64_t windowTotal = subm(mulm(mid - lo + 1, rightSums), mulm(hi - mid + 1, leftSums));
            return mulm(strength[mid], windowTotal);
        };

        vector<int> mono{-1};  // -1 is a sentinel sitting below every real index
        int answer = 0;
        for (int i = 0; i <= n; ++i) {
            // Settle every index whose value is >= strength[i]; at i == n settle all of them.
            for (;;) {
                const int top = mono.back();
                if (top == -1) {
                    break;
                }
                if (i != n && strength[top] < strength[i]) {
                    break;
                }
                mono.pop_back();
                answer = addm(answer, contribution(mono.back() + 1, top, i - 1));
            }
            mono.push_back(i);
        }
        return answer;
    }
};
// Time: O(n)
// Space: O(n)
// mono stack, prefix sum
//
// totalStrength returns, modulo 1e9 + 7, the sum over every non-empty
// contiguous subarray of min(subarray) * sum(subarray).
// For each index y, the monotonic stack yields the window [x, z] in which y is
// the rightmost minimum: values in [x, y] are >= strength[y], and values in
// [y + 1, z] are > strength[y]. Over all subarrays [l, r] with
// x <= l <= y <= r <= z:
//   - strength[j], j in [x, y], is covered (j - x + 1) * (z - y + 1) times;
//   - strength[j], j in [y + 1, z], is covered (y - x + 1) * (z - j + 1) times.
// Both weighted sums are read in O(1) from the prefix/suffix tables below.
class Solution2 {
public:
    int totalStrength(vector<int>& strength) {
        static const int MOD = 1e9 + 7;
        const auto& add = [&](const int64_t a, const int64_t b) {
            return (a + b) % MOD;
        };
        const auto& sub = [&](const int64_t a, const int64_t b) {
            return (a - b + MOD) % MOD;
        };
        const auto& mult = [&](const int64_t a, const int64_t b) {
            return (a * b) % MOD;
        };
        // Signed length, to avoid mixing int and size_t in the index arithmetic below.
        const int n = static_cast<int>(size(strength));

        // prefix[k]  = sum_{j < k} strength[j]
        // prefix2[k] = sum_{j < k} strength[j] * (j + 1)
        vector<int64_t> prefix(n + 1), prefix2(n + 1);
        for (int i = 0; i < n; ++i) {
            prefix[i + 1] = add(prefix[i], strength[i]);
            prefix2[i + 1] = add(prefix2[i], strength[i] * static_cast<int64_t>(i + 1));
        }
        // suffix[k]  = sum_{j >= k} strength[j]
        // suffix2[k] = sum_{j >= k} strength[j] * (n - j)
        vector<int64_t> suffix(n + 1), suffix2(n + 1);
        for (int i = n - 1; i >= 0; --i) {
            suffix[i] = add(suffix[i + 1], strength[i]);
            // Widen to int64_t before multiplying, as prefix2 does: multiplying by a
            // size_t overflows on platforms where size_t is only 32 bits.
            suffix2[i] = add(suffix2[i + 1], strength[i] * static_cast<int64_t>(n - i));
        }

        vector<int> stk = {-1};  // -1 is a sentinel below every real index
        int result = 0;
        for (int i = 0; i <= n; ++i) {
            // i == n flushes every index still on the stack.
            while (stk.back() != -1 && (i == n || strength[stk.back()] >= strength[i])) {
                const int x = stk[size(stk) - 2] + 1;
                const int y = stk.back(); stk.pop_back();
                const int z = i - 1;
                // strength[j] >= strength[y] for j in [x, y]
                // strength[j] >  strength[y] for j in [y + 1, z]

                // (z - y + 1) * sum_{j=x..y} strength[j] * (j - x + 1)
                const int64_t left = mult(z - y + 1, sub(sub(prefix2[y + 1], prefix2[x]), mult(x, sub(prefix[y + 1], prefix[x]))));
                // (y - x + 1) * sum_{j=y+1..z} strength[j] * (z - j + 1),
                // using (z - j + 1) = (n - j) - (n - (z + 1)).
                const int64_t right = mult(y - x + 1, sub(sub(suffix2[y + 1], suffix2[z + 1]), mult(n - (z + 1), sub(suffix[y + 1], suffix[z + 1]))));
                result = add(result, mult(strength[y], add(left, right)));
            }
            stk.emplace_back(i);
        }
        return result;
    }
};