WQS 二分优化 DP

First Post:

Last Update:

Word Count:
746

Read Time:
3 min

Page View: loading...

WQS 二分,在不同的文章里,也常称作带权二分、凸优化 DP、凸完全单调性 DP、Lagrange 乘子法等,在国外也称作 Aliens Trick。它最早由王钦石(wqs)在《浅析一类二分方法》一文中总结。

WQS 二分的典型特征是限制“恰好取 个”的最优化问题。

应用 WQS 二分要求函数具有凸性;例如,设“从 个中恰好取 个”的最优解为 ,则一般要求 关于 是凸的。

二分

假如将 作为 轴, 作为 轴,则函数图像可能形如:

wqs

这里使用下凸函数为例。

我们其实不知道这个函数具体长什么样子,只知道它是凸的,我们尝试使用一条直线去切这个凸壳,检查切点的位置:

  • 若切点的横坐标恰为所求的 ,则问题解决;
  • 否则,由于凸壳的斜率是单调的,因此可以二分斜率。

可以结合上图理解。

现在的问题是,如何求出切点的位置。

考虑这个东西的含义,假设斜率为 的直线 恰好经过点 ,则截距

可以理解为每次选择都花费了 的代价。

注意到直线切到凸壳时截距取最小值,因此在做完转化后求一个不带限制的最小化 DP 即可。

易错点

凸壳上可能有相邻一系列点的斜率相同,这时 dp 和二分的一些端点未处理好可能会把正确的斜率排除到二分区间以外。

我代码中的习惯是,dp 时使用最小的转移点,这样计算出的分割数是最小的;二分时将小于等于 的区间归结到 上并将大于 的区间扔掉。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Luogu P4983
#include <cstdint>
#include <deque>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

static constexpr int64_t INF = 1e18;

int main() {
cin.tie(nullptr)->sync_with_stdio(false);
int64_t n, m;
cin >> n >> m;
vector<int64_t> S(n + 1);
for (int64_t i = 1; i <= n; ++i)
cin >> S[i], S[i] += S[i - 1];

auto check = [&](int64_t k) -> pair<int64_t, int64_t> {
vector<int64_t> F(n + 1), G(n + 1);
auto gX = [&](int64_t j) { return S[j]; };
auto gY = [&](int64_t j) { return F[j] + S[j] * S[j] - 2 * S[j]; };
auto gF = [&](int64_t i, int64_t j) {
return F[j] + k + (S[i] - S[j] + 1) * (S[i] - S[j] + 1);
};
deque<int64_t> q;
q.emplace_back(0);
for (int64_t i = 1, j = 0; i <= n; ++i) {
while (j + 1 < q.size() && gF(i, q[j + 1]) < gF(i, q[j]))
++j;
if (j >= q.size())
j = q.size() - 1;
F[i] = gF(i, q[j]), G[i] = G[q[j]] + 1;
while (q.size() >= 2) {
int64_t j1 = q[q.size() - 2], j2 = q[q.size() - 1];
int64_t k1 = (gY(j2) - gY(j1)) * (gX(i) - gX(j2));
int64_t k2 = (gY(i) - gY(j2)) * (gX(j2) - gX(j1));
if (k1 < k2)
break;
q.pop_back();
}
q.emplace_back(i);
}
return {F[n], G[n]};
};

int64_t l = 0, r = INF;
while (l < r) {
int64_t mid = l + ((r - l) >> 1);
auto [f, g] = check(mid);
if (g <= m)
r = mid;
else
l = mid + 1;
}
auto [f, g] = check(l);
cout << f - l * m << '\n';
return 0;
}