树形DP有一个独特的优化,就是通过递归,枚举目前有效的元素个数,求dp[ i ][ j ] (表示 选取以i为根的子树中有选取j个元素的最大取值)

(搭配 siz 数组表示当前该节点的总共子孙数)

1.hdu1561(树形依赖背包裸题)

 注意 siz 数组的运用,以及 u 点选择的节点数时要逆向枚举,就像01背包

 复杂度看似O(n^3),实际是 O( n^2 ) 左右。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;

const int maxn = 250;
vector<int> g[maxn];
int dp[maxn][maxn];
int val[maxn];
int siz[maxn];
int n,m;

//dp[i][j] 表示 选取以i为根的子树中有选取j个元素的最大取值
void dfs(int u){
    siz[u]=1;
    dp[u][1] = val[u];
    for(int i=0; i<g[u].size(); i++){
        int v = g[u][i];
        dfs(v);                                //这里的siz[u]不包括siz[v] ,并且是把效率很低的2^n举法用01背包来做
        for(int i=siz[u]; i>=1; i--){         //这里就像01背包里,避免由这个点的情况递推这个点的更佳情况
            for(int j=1; j<=siz[v]&&i+j<=m; j++){  //就比如要避免刚刚还说是从v取3个点推出的最优
                dp[u][i+j] = max(dp[u][i+j], dp[u][i]+dp[v][j]); //后面又从前面的dp值而只从j中取1个点得出错误的更优解
            }
        }
        siz[u] += siz[v];
    }
}

int main(){
    while(scanf("%d%d",&n,&m)!=EOF){
        if(n==0&&m==0) break;
        for(int i=0; i<=n; i++){
            for(int j=0; j<=n; j++){
                dp[i][j] = 0;
            }
            g[i].clear();
        }
        int t;
        for(int i=1; i<=n; i++){
            scanf("%d%d",&t,val+i);
            g[t].push_back(i);
        }
        m++;
        dfs(0);
        printf("%d\n", dp[0][m]);
    }
}
View Code

2.codeforces 815C  (树形dp)

这个选取树上物品可以不需要有父子关系的,但使用优惠券和父子关系有关,所以可以把 dp数组多增加一维,表示是否能够使用优惠券。

只需要设置默认值为 inf ,再这样初始化:

dp[u][0][0]=0;

dp[u][1][0]=c[u];

dp[u][1][1]=c[u]-d[u];

就可以在枚举时考虑到 0 这个元素。

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;

const int maxn = 5005;
int dp[maxn][maxn][2];    //dp[i][j]表示以i为根的子树中取j个元素的最大值
vector<int> g[maxn];      //再来一维表示是否购买根节点i这个元素,也就是用不用优惠券
int val[maxn],d[maxn];
int siz[maxn];

void dfs(int u){
    siz[u] = 1;
    dp[u][1][1] = val[u]-d[u];
    dp[u][1][0] = val[u];
    dp[u][0][0] = 0;
    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        dfs(v);      //这里的siz[u]不包括siz[v]  
        for(int i=siz[u]; i>=0; i--){    //这里的0是为了处理可以不取
            for(int j=0; j<=siz[v]; j++){
                dp[u][i+j][0] = min(dp[u][i+j][0], dp[u][i][0]+dp[v][j][0]);
                dp[u][i+j][1] = min(dp[u][i+j][1], dp[u][i][1]+min(dp[v][j][0],dp[v][j][1]));
            }
        }
        siz[u] += siz[v];
    }
}

int main(){
    int n,b;
    scanf("%d%d",&n,&b);
    scanf("%d%d",val+1,d+1);
    for(int i=2; i<=n; i++){
        int t;
        scanf("%d%d%d",val+i,d+i,&t);
        g[t].push_back(i);
    }
    memset(dp,0x3f,sizeof(dp));
    dfs(1);
    int ans=n;
    while(dp[1][ans][1]>b&&dp[1][ans][0]>b){
        ans--;
        //printf("%d %d\n",dp[1][ans][1],dp[1][ans][0] );
    }
    printf("%d\n", ans);
}
View Code
01-25 18:48