前去学习了一下启发式合并当然并查集的还没学呢。。

去做了一下网上的例题梦幻布丁,发现set的功能真是挺强大的(有机会好好学学)。

我比较蒙逼的就是dalao们的fa数组,后来了解了,解释在代码里,挺浅显易懂的emm大概?

#include<bits/stdc++.h>
using namespace std;
const int N=100005;
int n,m,ans;
int fa[1000005],v[100005];
set<int> a[1000005];
inline long long read() {
    long long x = 0;
    long long f = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-')
            f = -f;
        c = getchar();
    }
    while (c <= '9' && c >= '0') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}
void solve(int x,int y){
    for(set<int>::iterator i=a[x].begin();i!=a[x].end();i++){
        if(v[*i-1]==y)ans--;
        if(v[*i+1]==y)ans--;
        a[y].insert(*i);
    }
    for(set<int>::iterator i=a[x].begin();i!=a[x].end();i++){
        v[*i]=y;
    }
    a[x].clear();
}
int main(){
     n=read();m=read();
    for(int i=1;i<=n;i++) v[i]=read();
    for(int i=1;i<=n;i++){
        fa[v[i]]=v[i];
        if(v[i]!=v[i-1]) ans++;  //有颜色间隔
        a[v[i]].insert(i);  //当前下标扔对应颜色的框里 
    }
    for(int i=1;i<=m;i++) {
        int o=read();
        if(o==2)printf("%d\n",ans);
        else{
            int x=read(),y=read();
            if(x==y)continue;
            //cout<<fa[x]<<endl<<fa[y]<<endl; 
            if(a[fa[x]].size()>a[fa[y]].size()){
               swap(fa[x],fa[y]); //fa[i]是用于存储值为i的数位置存在了S{fa[i]}集合中,所以合并时也不要忘记了修改fa[i] 
            } //避免下次合并时S{y}为空 
            x=fa[x],y=fa[y];
            solve(x,y);
        }
    }
} 
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 5006
#define all 5000
#define LL long long
using namespace std;
int n,T,tot[2],num[maxn],w[maxn],p[maxn],lnk[2][maxn],son[2][maxn*2],nxt[2][maxn*2];
LL f[maxn][maxn];
bool vis[maxn];
void add(int t,int x,int y){
    nxt[t][++tot[t]]=lnk[t][x];son[t][tot[t]]=y;lnk[t][x]=tot[t];
}
void dfs(int x){
    vis[x]=0;num[x]=1;
    for(int j=lnk[0][x];j;j=nxt[0][j]) if(vis[son[0][j]]){
        add(1,x,son[0][j]);dfs(son[0][j]);num[x]+=num[son[0][j]];
    }
}
void merge(int t,int x){
    for(int j=all;j>=w[x];j--)f[t][j]=max(f[t][j],f[t][j-w[x]]+p[x]);
    for(int j=lnk[1][x];j;j=nxt[1][j])merge(t,son[1][j]);
}
void solve(int x){
    int Max=0,t=0;vis[x]=0;
    for(int j=lnk[1][x];j;j=nxt[1][j]){
        if(num[son[1][j]]>Max)Max=num[son[1][j]],t=son[1][j];
        solve(son[1][j]);
    }
    if(t)memcpy(f[x],f[t],sizeof(f[x]));
    for(int j=all;j>=w[x];j--)f[x][j]=max(f[x][j],f[x][j-w[x]]+p[x]);
    for(int j=lnk[1][x];j;j=nxt[1][j]) if(son[1][j]!=t)merge(x,son[1][j]);
}
int main(){
    freopen("A.in","r",stdin);
    freopen("A.out","w",stdout);
    scanf("%d",&n);
    for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),add(0,x,y),add(0,y,x);
    for(int i=1;i<=n;i++)scanf("%d%d",&p[i],&w[i]);
    memset(vis,1,sizeof(vis));
    dfs(1);solve(1);
    scanf("%d",&T);
    while(T--){
        int x,s;
        scanf("%d%d",&x,&s);
        printf("%lld\n",f[x][s]);
    }
    return 0;
}
01-10 00:21