贴一下大佬在计蒜客上的题解：
The Preliminary Contest for ICPC China Nanchang National Invitational and International Silk-Road Programming Contest
思路：离线处理。
用树链剖分 + 线段树来做树上路径查询（树剖同时完成了“跳到 LCA”的过程，比逐次求 LCA 更快）。
关键是把 n-1 条树边和 m 条询问都存起来，
然后分别按权值 w 升序排序；
这样回答询问时，就按询问的阈值从小到大依次处理；
每条边只会被加入一次，而且按权值从小到大加入，所以之前加入的边对后面（阈值更大的）询问依然有效，当前计数不会受到之前更新的干扰；
每次回答询问前，先把所有权值小于等于当前询问阈值、且尚未加入的边加入（在该边较深端点的 dfn 位置上 +1）；
查询就是统计询问路径上有多少条边已经被加入。
#include<bits/stdc++.h>
using namespace std;
// Capacity of the vertex- and query-indexed arrays (1-indexed), with slack.
const int M = 100000 + 10;
inline int read(){
int sum=0,x=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-')
x=0;
ch=getchar();
}
while(ch>='0'&&ch<='9')
sum=(sum<<1)+(sum<<3)+(ch^48),ch=getchar();
return x?sum:-sum;
}
// Heavy-light decomposition data, indexed by vertex:
//   f    - parent (main roots the tree at 1 and passes 1 as its own parent)
//   sz   - subtree size
//   deep - depth
//   son  - heavy child (0 when the vertex has no children)
//   dfn  - DFS order position; an inserted edge is counted at dfn[deeper end]
//   top  - head of the heavy chain containing the vertex
int f[M],sz[M],deep[M],son[M],dfn[M],top[M];
// Answer for each query, indexed by its position in the input.
int ans[M];
// Segment tree of sums over dfn positions 1..n.
int t[M<<2];
// n: number of vertices; cnt: running dfn counter used by dfs2.
int n,cnt;
// Adjacency lists of the tree.
vector<int >graph[M];

// A weighted tree edge (u, v, w) or a query (u, v, threshold w).
// `index` is only filled in for queries (their input position).
struct node{
    int u,v,w,index;
    // Order by weight so edges and queries can be swept in increasing w.
    bool operator<(const node &other)const{
        return w<other.w;
    }
}q[M],e[M];
// First heavy-light pass.
// For every vertex in u's subtree, fills f (parent), deep (depth),
// sz (subtree size) and son (heavy child). `from` is u's parent. When called
// as dfs1(1,1), the root is its own parent, so its depth is deep[1]+1 = 1.
void dfs1(int u,int from){
    f[u]=from;
    deep[u]=deep[from]+1;
    sz[u]=1;
    for(std::size_t k=0;k<graph[u].size();++k){
        const int child=graph[u][k];
        if(child==from)
            continue;  // never walk back up to the parent
        dfs1(child,u);
        sz[u]+=sz[child];
        // Strict comparison keeps the first child among equally large ones.
        // son[u] stays 0 until set, and sz[0] is never written (assuming
        // 1-indexed vertices), so the first child always qualifies.
        if(sz[son[u]]<sz[child])
            son[u]=child;
    }
}
// Second heavy-light pass.
// Assigns DFS positions (dfn) and chain heads (top). The heavy child is
// visited first, so each heavy chain gets a contiguous dfn range that starts
// at its head. `head` is the head of the chain u belongs to.
void dfs2(int u,int head){
    dfn[u]=++cnt;
    top[u]=head;
    if(son[u]==0)
        return;  // no children
    dfs2(son[u],head);  // heavy child continues the current chain
    for(std::size_t k=0;k<graph[u].size();++k){
        const int child=graph[u][k];
        if(child==f[u]||child==son[u])
            continue;
        dfs2(child,child);  // each light child starts a new chain
    }
}
// Segment-tree point update: add c at position `sign`. Node `root` covers
// [l, r]; parent sums are recomputed on the way back up.
void update(int sign,int c,int root,int l,int r){
    if(l!=r){
        const int mid=(l+r)/2;
        const int lc=root*2;
        const int rc=root*2+1;
        if(sign>mid)
            update(sign,c,rc,mid+1,r);
        else
            update(sign,c,lc,l,mid);
        t[root]=t[lc]+t[rc];
        return;
    }
    t[root]+=c;  // leaf
}
// Segment-tree range sum over [L, R]. Node `root` covers [l, r].
// Returns 0 for an empty range (L > R), since no node can then be fully covered.
int find(int L,int R,int root,int l,int r){
    if(l>=L&&r<=R)
        return t[root];  // node lies entirely inside the query range
    const int mid=(l+r)/2;
    int total=0;
    if(mid>=L)
        total+=find(L,R,root*2,l,mid);
    if(mid<R)
        total+=find(L,R,root*2+1,mid+1,r);
    return total;
}
int solve(int u,int v){
int c=0;
int fu=top[u],fv=top[v];
while(fu!=fv){
if(deep[fu]>=deep[fv]){
c+=find(dfn[fu],dfn[u],1,1,n);
u=f[fu],fu=top[u];
}
else{
c+=find(dfn[fv],dfn[v],1,1,n);
v=f[fv],fv=top[v];
} }
if(dfn[u]<dfn[v])
c+=find(dfn[u]+1,dfn[v],1,1,n);
else if(dfn[u]>dfn[v])
c+=find(dfn[v]+1,dfn[u],1,1,n);
return c; }
// Offline driver.
// Reads the tree and all queries, builds the heavy-light decomposition, then
// sweeps edges and queries together in increasing weight. Each query
// (u, v, w) is answered with the number of path edges whose weight is <= w.
// Answers are printed in input order.
int main(){
    n=read();
    const int m=read();

    // Tree edges go to e[1..n-1]; the adjacency lists are built alongside.
    for(int i=1;i<=n-1;++i){
        e[i].u=read();
        e[i].v=read();
        e[i].w=read();
        graph[e[i].u].push_back(e[i].v);
        graph[e[i].v].push_back(e[i].u);
    }

    dfs1(1,1);  // vertex 1 is the root and is passed as its own parent
    dfs2(1,1);

    for(int i=1;i<=m;++i){
        q[i].u=read();
        q[i].v=read();
        q[i].w=read();
        q[i].index=i;  // remember the input position before sorting
    }

    sort(e+1,e+n);
    sort(q+1,q+m+1);

    int added=1;  // e[1..added-1] are already on the segment tree
    for(int i=1;i<=m;++i){
        for(;added<n&&e[added].w<=q[i].w;++added){
            // A tree edge is represented by its deeper endpoint.
            const int child=deep[e[added].u]<deep[e[added].v]?e[added].v:e[added].u;
            update(dfn[child],1,1,1,n);
        }
        ans[q[i].index]+=solve(q[i].u,q[i].v);
    }

    for(int i=1;i<=m;++i)
        printf("%d\n",ans[i]);
    return 0;
}