# 题目大意

一个车间有 nn 台设备,构成以 11 为根结点的一棵树,结点 ii 有权值 wiw_i

  • 叶节点的权值 wiw_i 表示每单位时间将产出 wiw_i 单位的材料并送往父结点
  • 根结点的权值 wiw_i 表示每单位时间内能打包多少单位成品
  • 其他结点的权值 wiw_i 表示每单位时间最多能加工 wiw_i 单位的材料并送往父结点。

由于存在某些结点每单位时间收到的材料超过了当前结点的加工能力上限,现计划删除一些结点使得所有结点都能正常运行
请问删除一些结点后,根结点每单位时间内最多能打包多少单位的成品?

# 数据范围

  • 对于 100%100\% 的评测用例,2n10002 \leq n \leq 10001wi10001\leq w_i \leq 1000.

# 题解

dp[u][x] 表示在以 uu 为根的子树中,是否能选出若干节点,使得权值和恰好等于 xx
从根节点开始往下搜

  • 当搜到叶结点时, dp[叶结点][0]dp[叶结点][w[u]] 显然可以
  • 当搜到的是非叶结点(暂记为 xx 结点)时,对 xx 结点能走到的所有点,遍历所有可能的组合
    C++
    1
    2
    3
    4
    5
    6
    for (int s = 0; s <= w[u]; ++s)
    if (cur[s])
    // 遍历当前遍历到的一棵子树能带来的 s + k 的新收益
    for (int k = 0; k + s <= w[u]; ++k)
    if (dp[v][k])
    nxt[s + k] = true;

    dp[x结点][目前x结点能凑成的值 + x结点走到的下一层点能提供的值] 是应该被更新的
    两层循环,时间复杂度 O(w[u]2)O({w[u]}^2)

# 完整代码

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <bits/stdc++.h>
using namespace std;

const int N = 1e3 + 9;
int n, ans = 0;
vector<int> w(N);
vector<vector<int>> g(N);
vector<vector<bool>> dp(N, vector<bool>(N));

// p 是用来防止走回头路的,剩下一个 st 数组
// dp[u][x] 表示在以 u 为根的子树中,是否能选出若干节点,使得权值和恰好等于 x
void dfs(int u, int p) {
if (g[u].size() == 1 && p != -1) { // 叶子节点判断
// 以叶子结点为根的子树,没有继续往下了,所以只能凑到 0 和 w[u]
dp[u][0] = true;
dp[u][w[u]] = true;
return;
}

vector<bool> cur(N), nxt(N);
/*
cur[i]表示在遍历当前结点的某些子树后,能否凑出权值和 x
nxt 是临时存储更新后的可行权值和
*/
cur[0] = true; // 凑 0 显然可以

for (int v : g[u]) {
if (v != p) {
dfs(v, u);

nxt.assign(N, 0);
nxt[0] = true;

// 遍历当前根能凑成的所有权值
for (int s = 0; s <= w[u]; ++s)
if (cur[s])
// 遍历当前遍历到的一棵子树能带来的 s + k 的新收益
for (int k = 0; k + s <= w[u]; ++k)
if (dp[v][k])
nxt[s + k] = true;

cur = nxt;
}
}

dp[u] = cur;
}

int main() {
cin >> n;

for (int i = 1; i <= n; i++)
cin >> w[i];

for (int i = 1, u, v; i < n; i++)
cin >> u >> v, g[u].push_back(v), g[v].push_back(u);

dfs(1, -1);

for (int x = w[1]; x >= 0; x--)
if (dp[1][x]) {
cout << x << endl;
break;
}

return 0;
}

dfsdfs 搜一遍所有点复杂度 O(N)O(N),总复杂度 O(NW2)O(N · W^2)
由题目数据范围有,将达到 1e91e9 的级别(洛谷数据水,蹭 ACAC 了),应该利用 bitsetbitset 优化

# bitset 优化

优化后,时间复杂度为 O(NW)O(N · W)

C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
vector<int> w(N);
vector<vector<int>> g(N);
vector<bitset<N>> dp(N);

void dfs(int u, int p) {
if (g[u].size() == 1 && p != -1) {
dp[u][0] = 1, dp[u][w[u]] = 1;
return;
}

bitset<N> cur, mask; cur[0] = 1;
for (int i = 0; i <= w[u]; ++i)
mask.set(i);

for (int v : g[u])
if (v != p) {
dfs(v, u);

bitset<N> nxt;
nxt[0] = 1;

for (int s = cur._Find_first(); s <= w[u]; s = cur._Find_next(s))
nxt |= (dp[v] << s);

nxt &= mask, cur = nxt;
}

dp[u] = cur;
}

相较原先代码,变化的地方在于更新及添加了一个 maskmask 数组

C++
1
2
for (int s = cur._Find_first(); s <= w[u]; s = cur._Find_next(s))
nxt |= (dp[v] << s);

  • 更新改变解释
    s = cur._Find_first() + s = cur._Find_next(s) 可以快速找 11 的位置,利用 bitsetbitset,减少了原先的循环

    C++
    1
    2
    3
    for (int k = 0; k + s <= w[u]; ++k)
    if (dp[v][k])
    nxt[s + k] = true;

    例如: dp[v]{0,0,1,1} ,原先能凑出 0011
    s = 2 ,那么 dp[v] << s 即为 {1, 1, 0, 0} ,现在能凑出 2233

  • 为什么要多一个数组
    上面解释到,利用 dp[v] << s 来简化原先 dp[v][k] = 1 时, nxt[k + s] = 1 的循环
    而左移的时候,是可能超出 w[u] 的范围的,所以要先初始化好一个 mask[0wu]=1mask[0 \to w_u] = 1bitsetbitset
    每次算完 nxtnxt 后,要先和 maskmask 与一下再赋值给 curcur