https://www.luogu.org/problemnew/show/P3676

这题被我当成动态dp去做了,码了4k,搞了一个换根的动态dp

 #include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
struct E
{
int to,nxt;
}e[];
int f1[],ne;
struct P1
{
int len;ll a,b,c,d,e,f;
//长度,(点权)和,后缀和之和,后缀和的平方之和,(答案)和
//前缀和之和,前缀和的平方之和
};
struct P2
{
ll a,b;
//点权和,答案和
};
ll a[];
int sz[],hson[],ff[];
int b[],pl[];
int n,m;
inline void merge(P1 &c,const P1 &a,const P1 &b)
{
c.len=a.len+b.len;
c.a=a.a+b.a;
c.b=b.b+a.b+b.a*a.len;
c.c=b.c+b.a*b.a*a.len+*a.b*b.a+a.c;
c.d=a.d+b.d;
c.e=a.e+b.e+a.a*b.len;
c.f=a.f+a.a*a.a*b.len+*b.e*a.a+b.f;
}
inline void initnode(P1 &c,const P2 &a)
{
c.len=;c.a=c.b=c.e=a.a;c.c=c.f=a.a*a.a;c.d=a.b;
}
namespace S
{
#define lc (num<<1)
#define rc (num<<1|1)
P1 d[];
inline void upd(int num){merge(d[num],d[lc],d[rc]);}
P1 x;int L;
void _setx(int l,int r,int num)
{
if(l==r)
{
d[num]=x;
return;
}
int mid=(l+r)>>;
if(L<=mid) _setx(l,mid,lc);
else _setx(mid+,r,rc);
upd(num);
}
P1 getx(int L,int R,int l,int r,int num)
{
if(L<=l&&r<=R) return d[num];
int mid=(l+r)>>;
if(L<=mid&&mid<R)
{
P1 x;
merge(x,getx(L,R,l,mid,lc),getx(L,R,mid+,r,rc));
return x;
}
else if(L<=mid)
return getx(L,R,l,mid,lc);
else if(mid<R)
return getx(L,R,mid+,r,rc);
else
exit(-);
}
}
void dfs1(int u,int fa)
{
sz[u]=;
for(int v,k=f1[u];k;k=e[k].nxt)
if(e[k].to!=fa)
{
v=e[k].to;
ff[v]=u;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[hson[u]]) hson[u]=v;
}
}
P2 d1[];//d1[i]维护i节点及其轻儿子的贡献
P2 d2[];//d2[i]维护i节点(是重链顶)所在重链的值
int tp[],dwn[];//链顶,链底
inline void updd1(int x)
{
initnode(S::x,d1[x]);S::L=pl[x];S::_setx(,n,);
}
void dfs2(int u,int fa)
{
d1[u].a=a[u];
b[++b[]]=u;pl[u]=b[];
tp[u]=(u==hson[fa])?tp[fa]:u;
if(hson[u]) dfs2(hson[u],u);
dwn[u]=hson[u]?dwn[hson[u]]:u;
int v,k;
for(k=f1[u];k;k=e[k].nxt)
if(e[k].to!=fa&&e[k].to!=hson[u])
{
v=e[k].to;
dfs2(v,u);
d1[u].b+=d2[v].b;
d1[u].a+=d2[v].a;
}
updd1(u);
if(u==tp[u])
{
P1 t=S::getx(pl[u],pl[dwn[u]],,n,);
d2[u].a=t.a;d2[u].b=t.d+t.c;
}
}
inline ll getsize(int x)
{
return S::getx(pl[x],pl[dwn[x]],,n,).a;
}
int main()
{
int i,x,y,idx;ll z,ans,szall;P1 t;
scanf("%d%d",&n,&m);
for(i=;i<n;++i)
{
scanf("%d%d",&x,&y);
e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
}
for(i=;i<=n;++i) scanf("%lld",a+i);
dfs1(,);
dfs2(,);
while(m--)
{
scanf("%d",&idx);
if(idx==)
{
scanf("%d%lld",&x,&z);
d1[x].a-=a[x];a[x]=z;d1[x].a+=z;
while(x)
{
updd1(x);
x=tp[x];y=ff[x];
t=S::getx(pl[x],pl[dwn[x]],,n,);
d1[y].a-=d2[x].a;d1[y].b-=d2[x].b;
d2[x].a=t.a;d2[x].b=t.d+t.c;
d1[y].a+=d2[x].a;d1[y].b+=d2[x].b;
x=y;
}
//printf("3t%d\n",d2[1].b);
}
else
{
scanf("%d",&x);
ans=d2[].b;
szall=getsize();
if(x!=tp[x])
{
y=tp[x];
z=d1[y].a;
d1[y].a+=szall-getsize(y);
updd1(y);
if(y!=dwn[y])
{
t=S::getx(pl[y]+,pl[dwn[y]],,n,);
ans-=t.c;
}
if(x!=dwn[y])
{
t=S::getx(pl[x]+,pl[dwn[y]],,n,);
ans+=t.c;
}
t=S::getx(pl[y],pl[x]-,,n,);
ans+=t.f;
d1[y].a=z;
updd1(y);
x=y;
}
while(x!=)
{
y=ff[x];
z=getsize(x);
ans-=z*z;
z=szall-z;
ans+=z*z;
x=y;
if(x!=tp[x])
{
y=tp[x];
z=d1[y].a;
d1[y].a+=szall-getsize(y);
updd1(y);
if(y!=dwn[y])
{
t=S::getx(pl[y]+,pl[dwn[y]],,n,);
ans-=t.c;
}
if(x!=dwn[y])
{
t=S::getx(pl[x]+,pl[dwn[y]],,n,);
ans+=t.c;
}
t=S::getx(pl[y],pl[x]-,,n,);
ans+=t.f;
d1[y].a=z;
updd1(y);
x=y;
}
}
printf("%lld\n",ans);
}
}
return ;
}

码完一看题解,???好像画风不太对??

所以还是无视上面那个代码吧...

正常得多的做法:

 #include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
struct E
{
int to,nxt;
}e[];
int f1[],ne;
int n,m;
struct S
{
#define lowbit(x) ((x)&(-x))
ll d1[],d2[];
void _add(int p,ll x,ll *d)
{
for(;p<=n;p+=lowbit(p))
d[p]+=x;
}
ll _sum(int p,ll *d)
{
ll ans=;
for(;p>;p-=lowbit(p))
ans+=d[p];
return ans;
}
void add(int l,int r,ll x)
{
_add(l,x,d1);
_add(r+,-x,d1);
_add(l,x*l,d2);
_add(r+,-x*(r+),d2);
}
ll sum(int l,int r)
{
return (r+)*_sum(r,d1)-_sum(r,d2)
-l*_sum(l-,d1)+_sum(l-,d2);
}
}s1;
int b[],pl[];
ll a[],a2[];
int sz[],hson[],tp[];
ll dep[];
int ff[];
void dfs1(int u,int fa)
{
sz[u]=;
for(int k=f1[u];k;k=e[k].nxt)
if(e[k].to!=fa)
{
ff[e[k].to]=u;
dep[e[k].to]=dep[u]+;
dfs1(e[k].to,u);
sz[u]+=sz[e[k].to];
if(sz[e[k].to]>sz[hson[u]]) hson[u]=e[k].to;
}
}
void dfs2(int u,int fa)
{
b[++b[]]=u;pl[u]=b[];
tp[u]=u==hson[fa]?tp[fa]:u;
a2[u]=a[u];
if(hson[u])
{
dfs2(hson[u],u);
a2[u]+=a2[hson[u]];
}
for(int k=f1[u];k;k=e[k].nxt)
if(e[k].to!=fa&&e[k].to!=hson[u])
{
dfs2(e[k].to,u);
a2[u]+=a2[e[k].to];
}
}
inline ll gsum1(int x)//x到1的路径和
{
int y;ll an=;
for(;x;x=ff[y])
{
y=tp[x];
an+=s1.sum(pl[y],pl[x]);
}
return an;
}
inline void add1(int x,ll z)//x到1加上z
{
int y;
for(;x;x=ff[y])
{
y=tp[x];
s1.add(pl[y],pl[x],z);
}
}
ll anss;
int main()
{
ll ans,z,t;
int i,x,y,idx;
scanf("%d%d",&n,&m);
for(i=;i<n;++i)
{
scanf("%d%d",&x,&y);
e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
}
for(i=;i<=n;++i)
scanf("%lld",a+i);
dfs1(,);
dfs2(,);
for(i=;i<=n;++i)
{
s1.add(pl[i],pl[i],a2[i]);
anss+=a2[i]*a2[i];
}
while(m--)
{
scanf("%d",&idx);
if(idx==)
{
scanf("%d%lld",&x,&z);
z=z-a[x];a[x]+=z;
anss+=z*z*(dep[x]+);
anss+=*gsum1(x)*z;
add1(x,z);
}
else
{
scanf("%d",&x);
ans=anss;
t=gsum1();
ans+=dep[x]*t*t;
ans-=*t*(gsum1(x)-t);
printf("%lld\n",ans);
}
}
return ;
}
05-11 20:50