人生的第一道树分治,要是早点学我南京赛就不用那么挫了,树分治的思路其实很简单,就是对子树找到一个重心(Centroid),实现重心分解,然后递归的解决分开后的树的子问题,关键是合并,当要合并跨过重心的两棵子树的时候,需要有一个接近O(n)的方法,因为f(n)=kf(n/k)+O(n)解出来才是O(nlogn).在这个题目里其实就是将第一棵子树的集合里的每个元素,判下有没符合条件的,有就加上,然后将子树集合压进大集合,然后继续搞第二棵乃至第n棵.我的过程用了map,合并是nlogn的所以代码速度颇慢,大概6s,题目时限10s,可以改成hash应该会快许多,毕竟用map实在太慢,用vector也可以,具体可以参见挑战程序设计竞赛代码.下面的代码查找重心用了挑战的代码.

#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<cstring>
#include<string>
#include<cstdio>
#include<algorithm>
#include<map>
#include<vector>
#define maxv 50000
#define ll long long
using namespace std; int n,k;
vector<int> G[maxv+50];
ll val[maxv+50];
ll prime[maxv+50];
ll convert_three(ll v)
{
ll bas=1;ll res=0;
for(int i=0;i<k;++i){
int num=0;
while(v%prime[i]==0){
v/=prime[i];
num++;
}
num%=3;res+=num*bas;
bas*=3;
}
return res;
} ll xor(ll x,ll y)
{
ll res=0;ll bas=1;
for(int i=0;i<k;++i){
res+=((x%3)+(y%3))%3*bas;
x/=3;y/=3;
bas*=3;
}
return res;
} ll inv(ll x)
{
ll res=0;ll bas=1;
for(int i=0;i<k;++i){
res+=((3-(x%3))%3)*bas;
x/=3;
bas*=3;
}
return res;
} void print(ll x){
while(x){
cout<<x%3;
x/=3;
}
cout<<endl;
} bool centroid[maxv+50];
int ssize[maxv+50];
int ans; map<ll,int> sta;
map<ll,int>::iterator it;
int compute_ssize(int v,int p)
{
int c=1;
for(int i=0;i<G[v].size();++i){
int w=G[v][i];
if(w==p||centroid[w]) continue;
c+=compute_ssize(G[v][i],v);
}
ssize[v]=c;
return c;
} pair<int,int> search_centroid(int v,int p,int t)
{
pair<int,int> res=make_pair(INT_MAX,-1);
int s=1,m=0;
for(int i=0;i<G[v].size();++i){
int w=G[v][i];
if(w==p||centroid[w]) continue;
res=min(res,search_centroid(w,v,t));
m=max(m,ssize[w]);
s+=ssize[w];
}
m=max(m,t-s);
res=min(res,make_pair(m,v));
return res;
} void enumerate_mul(int v,int p,ll d,map<ll,int> &ds)
{
if(ds.count(d)) ds[d]++;
else ds[d]=1;
for(int i=0;i<G[v].size();++i){
int w=G[v][i];
if(w==p||centroid[w]) continue;
enumerate_mul(w,v,xor(d,val[w]),ds);
}
} void solve(int v)
{
compute_ssize(v,-1);
int s=search_centroid(v,-1,ssize[v]).second;
centroid[s]=true;
for(int i=0;i<G[s].size();++i){
if(centroid[G[s][i]]) continue;
solve(G[s][i]);
}
sta.clear();
sta[val[s]]=1;map<ll,int> tds;
for(int i=0;i<G[s].size();++i){
if(centroid[G[s][i]]) continue;
tds.clear();
enumerate_mul(G[s][i],s,val[G[s][i]],tds);
it=tds.begin();
while(it!=tds.end()){
ll rev=inv((*it).first);
if(sta.count(rev)){
ans+=sta[rev]*(*it).second;
}
++it;
}
it=tds.begin();
while(it!=tds.end()){
ll vv=xor((*it).first,val[s]);
if(sta.count(vv)){
sta[vv]+=(*it).second;
}
else{
sta[vv]=(*it).second;
}
++it;
}
}
centroid[s]=false;
} int main()
{
while(cin>>n>>k){
ans=0;
for(int i=0;i<k;++i){
scanf("%I64d",&prime[i]);
}
G[0].clear();
for(int i=1;i<=n;++i){
scanf("%I64d",&val[i]);
val[i]=convert_three(val[i]);
if(val[i]==0) ans++;
//print(val[i]);
G[i].clear();
}
int u,v;
for(int i=0;i<n-1;++i){
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
memset(centroid,0,sizeof(centroid));
solve(1);
printf("%d\n",ans);
}
return 0;
}
04-30 19:10