树形动态规划问题解析&树上背包问题研究

First Post:

Last Update:

1. 树上DP概述

树形 DP,即在树上进行的 DP。由于树固有的递归性质,树形 DP 一般都是递归进行的。

大部分的树形 DP 都是线性的,并且由于树本身就是有序的,所以具有十分良好的性质,例如子结构性质等。

树形 DP 在算法竞赛中考察多样,但是简单的可以分为:

  1. 树上线性 DP。
  2. 换根 DP。

树上线性 DP 也分为很多种,不同的题有不同的考法,实际上,所有的 DP 都能在树上考,但是蓝桥比赛中,一般就几种考法:

  1. 树上决策,例如选最大值,最小值。
  2. 树上背包。
  3. 换根,换根 dp 是树上的一类特殊性质。

接下来,我们将通过几个问题来描述这三种问题的解法。

2. 树上决策问题

树上决策问题,往往是子节点向父节点转移时,只取最优的解,这一点与线性 DP 十分相似。

看一道例题:

2. 1 生命之树-真题

图片描述

这题看着挺玄乎,其实并没有那么复杂。

我们观察题目要求:给定一棵树,选出一个非空集合,使得对于任意两个元素 $a, b$,都存在一个序列 $a, v_1, …v_k, b$ 是这个集合里的元素,并且相邻两个点之间有一条边。

本来可以一句话说清楚的事情,但是偏偏要给出数学定义,所以要考察大家的归纳整理能力。

实际上,就是要在树中选出一个连通块即可,并且满足连通块的和值最大。

为什么呢?

我们观察一幅图,相信大家能理解了:

图片描述

绿色的代表我们选择的点集合。这些点是连通的,所以满足要求。

如果换成这个样子:

图片描述

这样就不满足题目要求了。

所以大家可以体会出来,题目的要求,其实就是找一个树上的连通块。

那么我们的问题就变成了在树上找最大的连通块了。

树形 DP,终究还是 DP,所以需要划分子问题。

我们常用的方法是,将子节点为根的子树,看成子问题,然后合并到当前根

将节点从深到浅(子树从小到大)的顺序作为 DP 的阶段,在 DP 的表示中,通常第一维代表节点的编号,后续维度按照问题进行设计。

首先我们需要解决一个问题,树上的连通块是什么?有什么性质可以利用。

答案是:树上的联通块也是树,他一定有根。所以我们要是找到这个根,或者枚举这个根,就可以找到答案。

我们设计的状态如下:

$dp_i$ 表示,对于节点为 $i$ 的子树,我们找到的以 $i$ 为根的连通块和值最大是 $dp_i$。

那么我们的转移的意义就是:对于 $i$ 来说,由于 $i$ 一定存在连通块中,所以,我们要找到他的儿子中,哪些是和 $i$ 连着的。

有一种贪心方案,对于 $i$ 的儿子 $v \in son(i)$ ,如果 $dp_v \ge 0$,我们就将他接入父亲即可。

所以,我们的转移方程就是: $$ dp_i = w_i + \sum _{dp_j \ge 0 & j \in son(i)} dp_j $$ 代码如下:

  • C++
1
#include <iostream> #include <vector> using namespace std; const int N = 1e5+100; typedef long long ll; vector<int> G[N]; int w[N]; ll dp[N], ans = -1e18; int n; void dfs(int u, int f) { dp[u] = w[u]; for (int v : G[u]) { if (v == f) continue; dfs(v, u); if (dp[v] > 0) { dp[u] += dp[v]; } } ans = max(ans, dp[u]); } int main() { cin >> n; for (int i = 1; i <= n; ++i) { cin >> w[i]; } int u, v; for (int i = 1; i < n; ++i) { cin >> u >> v; G[u].push_back(v); G[v].push_back(u); } dfs(1, 0); cout << ans << endl; return 0; }
  • Java
1
import java.util.*; public class Main { private static final int N = (int) (1e5 + 100); private static long[] dp; private static int[] w; private static List<List<Integer>> G; private static long ans = Long.MIN_VALUE; private static int n; private static void dfs(int u, int f) { dp[u] = w[u]; for (int v : G.get(u)) { if (v == f) continue; dfs(v, u); if (dp[v] > 0) { dp[u] += dp[v]; } } ans = Math.max(ans, dp[u]); } public static void main(String[] args) { Scanner scanner = new Scanner(System.in); n = scanner.nextInt(); w = new int[N]; G = new ArrayList<>(); for (int i = 0; i < N; i++) { G.add(new ArrayList<>()); } dp = new long[N]; for (int i = 0; i < n; i++) { w[i] = scanner.nextInt(); } for (int i = 0; i < n - 1; i++) { int u = scanner.nextInt() - 1; // 0-indexed in Java int v = scanner.nextInt() - 1; G.get(u).add(v); G.get(v).add(u); } dfs(0, -1); System.out.println(ans); scanner.close(); } }
  • Python
1
import sys sys.setrecursionlimit(100000) n = int(input()) aList = [0] + [int(i) for i in input().split()] tree = [[]for i in range(n+1)] ans = 0 dp = [0 for i in range(n+1)] for i in range(n-1): m, n =map(int, input().split()) tree[m].append(n) tree[n].append(m) def dfs(u,f): global ans dp[u] = aList[u] for i in tree[u]: if i !=f: dp[i] = dfs(i, u) if dp[i]>0: dp[u] += dp[i] ans=max(ans, dp[u]) return dp[u] dfs(1, 0) print(ans)

3. 树上背包问题

树上背包问题,本质上还是背包,可以看成在树上进行的背包。

每次转移都是在父亲与儿子之间进行了一次经典背包转移。

3.1 小明的背包6

图片描述

这个是典型的依赖背包问题。

并且依赖关系构成了一棵树。

我们看样例:

1
2
3
4
5
6
7
<span>6 </span><span>15</span>
<span>3 </span><span>4</span> <span>0</span>
<span>2 </span><span>3</span> <span>1</span>
<span>2 </span><span>5</span> <span>1</span>
<span>3 </span><span>5</span> <span>1</span>
<span>4 </span><span>8</span> <span>2</span>
<span>3 </span><span>9</span> <span>2</span>

图片描述

依赖关系如上图所示:上图的含义是如果只有购买了 $1$ 号物品,才能购买 $2, 3, 4$ 号物品。

记住,我们的目标是划分子问题,也就是说,只要保证了一个子问题的划分是正确的,那么由于树的优良递归性质,其他的也会是正确的。

复习一下普通的背包问题,用 $dp_i$ 表示,在使用了 $i$ 空间的情况下的最大价值。

但是在树问题中,由于第一维度是节点的编号,所以我们用 $dp_{i,j}$ 表示对于 $i$ 子树来说,使用了 $j$ 空间的最大价值。

当然题目中有要求,必须满足依赖关系,所以,我们需要重新定义: $dp_{i,j}$ 表示对于 $i$ 子树来说,使用了 $j$ 空间且满足依赖关系的最大价值。

如何满足呢?

我们只需要保证每一个 $dp_{i,j}$ 都选了 $i$ 节点即可。

我们可以在背包中预留出节点 $i$ 的空间即可。

代码如下:

  • C ++
1
#include <iostream> #include <algorithm> #include <vector> using namespace std; const int N = 1e2+20; vector<int> G[N]; int n, V; int v[N], w[N]; int dp[N][N]; void dfs(int u) { for (int i = v[u]; i <= V; ++i) dp[u][i] = w[u]; for (int i : G[u]) { dfs(i); for (int j = V; j >= v[u] + v[i]; --j) { for (int k = v[i]; k <= j - v[u]; ++k) // 剩余的空间 dp[u][j] = max(dp[u][j - k] + dp[i][k], dp[u][j]); } } } int main() { cin >> n >> V; int s; for (int i = 1; i <= n; ++i) { cin >> v[i] >> w[i] >> s; G[s].push_back(i); } dfs(0); cout << dp[0][V] << '\n'; }
  • Java
1
import java.util.ArrayList; import java.util.List; import java.util.Scanner; public class Main { private static int V; private static int[][] dp; private static List<List<Integer>> G; private static int[] v; private static int[] w; private static void dfs(int u) { for (int i = v[u]; i <= V; ++i) { dp[u][i] = w[u]; } for (int child : G.get(u)) { dfs(child); for (int j = V; j >= v[u] + v[child]; --j) { for (int k = v[child]; k <= j - v[u]; ++k) { dp[u][j] = Math.max(dp[u][j - k] + dp[child][k], dp[u][j]); } } } } public static void main(String[] args) { Scanner scanner = new Scanner(System.in); int n = scanner.nextInt(); V = scanner.nextInt(); G = new ArrayList<>(); for (int i = 0; i <= n; ++i) { G.add(new ArrayList<>()); } v = new int[n + 1]; w = new int[n + 1]; dp = new int[n + 1][V + 1]; for (int i = 1; i <= n; ++i) { v[i] = scanner.nextInt(); w[i] = scanner.nextInt(); int s = scanner.nextInt(); G.get(s).add(i); } dfs(0); System.out.println(dp[0][V]); scanner.close(); } }
  • Python
1
class Solution: def dfs(self, u, dp, G, v, w, V): for i in range(v[u], V + 1): dp[u][i] = w[u] for child in G[u]: self.dfs(child, dp, G, v, w, V) for j in range(V, v[u] + v[child] - 1, -1): for k in range(v[child], j - v[u] + 1): dp[u][j] = max(dp[u][j - k] + dp[child][k], dp[u][j]) def main(self): n, V = map(int, input().split()) G = [[] for _ in range(n + 1)] # 0-indexed in Python v = [0] * (n + 1) w = [0] * (n + 1) for i in range(1, n + 1): v[i], w[i], s = map(int, input().split()) G[s].append(i) dp = [[0] * (V + 1) for _ in range(n + 1)] self.dfs(0, dp, G, v, w, V) print(dp[0][V]) # Run the main function solution = Solution() solution.main()

4. 换根 DP 问题

换根 DP,面对的问题通常是“不定根”问题,也就是说,对于一棵树,他的根不一定是 $1$ 号点,可能是任意某个点。

或者在某些问题中,我们需要尝试计算以每个点为根的情况,最后维护出最大值。

我们先看一副图,来理解所谓的“换根”。

图片描述

我们将原来以 $1$ 为根换成了以 $2$ 为根。那么树的形态也就发生了变化。

如果每次都是选择一个点作为根进行处理,那么总的时间复杂度为 $O(n^2)$,但是如果我们能发现性质,我们可以将复杂度降为 $O(n)$。

即换一次根的复杂度为 $O(1)$,下面,我们将讲述这种方法。

在一般的问题中,我们常常是利用dfs来不断的将根转换为根的子节点。

我们会发现一些事情:

图片描述

我们一次转换的过程,其实有很大一部分并没有发生变化,体现在 DP 转移中,就是这些点的 DP 值也不会发生改变。

实际上改变的只有改变身份的两个点,其他的点都不会发生变化。

在换根的问题中,一般的步骤如下:

  1. 以 $1$ 为根进行一遍扫描,并且处理出必要的信息,例如深度、DP 值等。
  2. 开始以 $1$ 进行换根,并且向下递归,在递归之前,需要将自己变成子节点的身份。
  3. 进入新的根后,按照根的身份,重新进行转移。并且维护答案。

4.1 卖树

图片描述

本题需要计算以每个点为根的情况下,产生的盈利。

如果我们确定了一个点为根,我们很容易算出答案,如果确定了根,问题就变成了求最大深度,这个问题只需要一遍DFS就可以完成。

1
void dfs(int u, int f, int dt) { // 求出以1为根的原始信息 dep[u] = dt; Mdp[u] = 0; // Mdp即为当前点为根的最大深度 for (int v : G[u]) { if (v == f) continue; dfs(v, u, dt + 1); Mdp[u] = max(Mdp[v] + 1, Mdp[u]); } }

因为节点数量太多,我们无法承受 $O(n^2)$ 的复杂度,所以我们需要进行换根,

基本思想如上述一致:

  1. 我们需要先算出以 $1$ 为根的信息,包括以每个节点为子树的最大深度,从 $1$ 转移到 $i$ 节点的代价。
  2. 我们从 $1$ 号点开始换根,每次只将根的身份换给儿子,然后进入递归,进入之前,我们需要将当前点的身份改为子节点。
  3. 进行新的根,由于原来的转移已经失效,所以需要重新转移。并且维护答案,然后重复2步骤。
  • C++
1
#include <iostream> #include <vector> using namespace std; const int N = 1e5+10; vector<int> G[N]; int n, k, c; int dep[N], Mdp[N]; typedef long long ll; ll ans = 0; void dfs(int u, int f, int dt) { // 求出以1为根的原始信息 dep[u] = dt; Mdp[u] = 0; for (int v : G[u]) { if (v == f) continue; dfs(v, u, dt + 1); Mdp[u] = max(Mdp[v] + 1, Mdp[u]); } } void dfs2(int u, int f) { // 开始换根 /** * 重新转移 */ int tmpf = 0, Mx1 = 0, Mx2 = 0; for (int v : G[u]) { tmpf = max(tmpf, Mdp[v] + 1); } // 维护答案 ans = max(1ll * tmpf * k - 1ll * dep[u] * c, ans); // 根变儿子步骤 int pre = Mdp[u]; for (int v : G[u]) { if (Mdp[v] + 1 > Mx1) { Mx2 = Mx1; Mx1 = Mdp[v] + 1; } else if (Mdp[v] + 1 > Mx2) { Mx2 = Mdp[v] + 1; } } for (int v : G[u]) { if (v == f) continue; // 由于根要变成儿子,所以要改变原来的转移值 if (Mdp[v] + 1 == Mx1) Mdp[u] = Mx2; else Mdp[u] = Mx1; dfs2(v, u); } // 还原原始的值。 Mdp[u] = pre; } void sol() { for (int i = 1; i <= n; ++i) G[i].clear(); ans = 0; cin >> n >> k >> c; int u, v; for (int i = 1; i < n; ++i) { cin >> u >> v; G[u].push_back(v); G[v].push_back(u); } dfs(1, 0, 0); dfs2(1, 0); cout << ans << '\n'; } int main() { ios::sync_with_stdio(0); int T; cin >> T; while (T --) { sol(); } return 0; }
  • Python
1
from collections import defaultdict import sys sys.setrecursionlimit(100000) N = 100010 G = defaultdict(list) n, k, c = 0, 0, 0 dep = [0] * N Mdp = [0] * N ans = 0 def dfs(u, f, dt): # 求出以1为根的原始信息 global dep, Mdp dep[u] = dt Mdp[u] = 0 for v in G[u]: if v == f: continue dfs(v, u, dt + 1) Mdp[u] = max(Mdp[v] + 1, Mdp[u]) def dfs2(u, f): # 开始换根 global ans, dep, Mdp tmpf = 0 Mx1 = 0 Mx2 = 0 # 重新转移 for v in G[u]: tmpf = max(tmpf, Mdp[v] + 1) # 维护答案 ans = max(ans, tmpf * k - dep[u] * c) # 根变儿子步骤 pre = Mdp[u] for v in G[u]: if Mdp[v] + 1 > Mx1: Mx2 = Mx1 Mx1 = Mdp[v] + 1 elif Mdp[v] + 1 > Mx2: Mx2 = Mdp[v] + 1 for v in G[u]: if v == f: continue # 由于根要变成儿子,所以要改变原来的转移值 if Mdp[v] + 1 == Mx1: Mdp[u] = Mx2 else: Mdp[u] = Mx1 dfs2(v, u) # 还原原始的值。 Mdp[u] = pre def sol(): global n, k, c, ans, G, dep, Mdp n, k, c = map(int, input().split()) G.clear() ans = 0 for _ in range(n - 1): u, v = map(int, input().split()) G[u].append(v) G[v].append(u) dfs(1, 0, 0) dfs2(1, 0) print(ans) T = int(input()) for _ in range(T): sol()
  • Java
1
import java.util.*; import java.io.*; public class Main { static final int N = 100010; static List<Integer>[] G; static int n, k, c; static int[] dep, Mdp; static long ans; static void dfs(int u, int f, int dt) { // 求出以1为根的原始信息 Mdp[u] = 0; for (int v : G[u]) { if (v == f) continue; dfs(v, u, dt + 1); Mdp[u] = Math.max(Mdp[v] + 1, Mdp[u]); } } static void dfs2(int u, int f) { // 开始换根 int tmpf = 0, Mx1 = 0, Mx2 = 0; /** * 重新转移 */ for (int v : G[u]) { tmpf = Math.max(tmpf, Mdp[v] + 1); } // 维护答案 ans = Math.max(ans, (long) tmpf * k - (long) dep[u] * c); // 根变儿子步骤 int pre = Mdp[u]; for (int v : G[u]) { if (Mdp[v] + 1 > Mx1) { Mx2 = Mx1; Mx1 = Mdp[v] + 1; } else if (Mdp[v] + 1 > Mx2) { Mx2 = Mdp[v] + 1; } } for (int v : G[u]) { if (v == f) continue; // 由于根要变成儿子,所以要改变原来的转移值 if (Mdp[v] + 1 == Mx1) { Mdp[u] = Mx2; } else { Mdp[u] = Mx1; } dfs2(v, u); } // 还原原始的值。 Mdp[u] = pre; } static void sol(Scanner scanner) { for (int i = 1; i <= n; i++) G[i].clear(); ans = 0; n = scanner.nextInt(); k = scanner.nextInt(); c = scanner.nextInt(); int u, v; for (int i = 1; i < n; i++) { u = scanner.nextInt(); v = scanner.nextInt(); G[u].add(v); G[v].add(u); } dfs(1, 0, 0); dfs2(1, 0); System.out.println(ans); } public static void main(String[] args) { G = new ArrayList[N]; for (int i = 0; i < N; i++) { G[i] = new ArrayList<>(); } dep = new int[N]; Mdp = new int[N]; Scanner scanner = new Scanner(System.in); int T = scanner.nextInt(); while (T-- > 0) { sol(scanner); } } }

5. 作业

题目 链接
取气球(算法赛) https://www.lanqiao.cn/problems/17024/learning/
左孩子右兄弟(21 年省赛) https://www.lanqiao.cn/problems/1451/learning/