题意:给定一棵树,每个节点有一个颜色,问树上有多少种子串(定义子串为某两个点上的路径),保证叶子节点数<=20。n<=10^5
题解:
叶子节点小于等于20,考虑将每个叶子节点作为根把树给提起来形成一棵trie,然后定义这棵树的子串为从上到下的一个串(深度从浅到深)。
这样做我们可以发现每个子串必定是某棵trie上的一段直线。统计20棵树的不同子串只需要把它们建到一个自动机上就行了,相当于把20棵trie合并成一棵大的。
对于每个节点x,它贡献的子串数量是max[x]-min[x],又因为min[x]=max[fa]+1,则=max[x]-max[fa],就是step[x]-step[fa];
学会了怎样在sam上插入一颗trie,就直接记录一下父亲在sam上的节点作为p。注意每次都要新开一个点,不然会导致无意义的子串出现。
例如一棵树 (括号内为i颜色)
1(0)
2(1)
3(2) 4(3)
2是1的孩子,3和4都是2的孩子。在以1为根节点的时候插入了这棵trie,在以3为根节点的时候son[root][2]已经存在,如果用它来当现在的点的话就会让一棵trie接在另一棵的末位,导致无意义的子串出现,答案偏大。
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std; typedef long long LL;
const int N=*;
int n,c,tot,len,last;
int w[N],son[N][],step[N],pre[N],first[N],cnt[N],id[N];
struct node{
int x,y,next;
}a[*N]; void ins(int x,int y)
{
a[++len].x=x;a[len].y=y;
a[len].next=first[x];first[x]=len;
} int add_node(int x)
{
step[++tot]=x;
return tot;
} int extend(int p,int ch)
{
// int np;
// if(son[p][ch]) return son[p][ch];
// else np=add_node(step[p]+1);
int np=add_node(step[p]+);//debug 每次都要新开一个点 while(p && !son[p][ch]) son[p][ch]=np,p=pre[p];
if(p==) pre[np]=;
else
{
int q=son[p][ch];
if(step[q]==step[p]+) pre[np]=q;
else
{
int nq=add_node(step[p]+);
memcpy(son[nq],son[q],sizeof(son[q]));
pre[nq]=pre[q];
pre[np]=pre[q]=nq;
while(son[p][ch]==q) son[p][ch]=nq,p=pre[p];
}
}
last=np;
return np;
} void dfs(int x,int fa,int now)
{
int nt=extend(now,w[x]);
// printf("%d\n",nt);
for(int i=first[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=fa) dfs(y,x,nt);
}
} int main()
{
freopen("a.in","r",stdin);
scanf("%d%d",&n,&c);
for(int i=;i<=n;i++) scanf("%d",&w[i]);
tot=;len=;
memset(son,,sizeof(son));
memset(pre,,sizeof(pre));
memset(cnt,,sizeof(cnt));
memset(first,,sizeof(first));
step[++tot]=;last=;
for(int i=;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
cnt[x]++;cnt[y]++;
}
// for(int i=1;i<=len;i++) printf("%d -- > %d\n",a[i].x,a[i].y);
for(int i=;i<=n;i++)
{
if(cnt[i]==) dfs(i,,);
}
// for(int i=1;i<=tot;i++) printf("%d ",id[i]);printf("\n");
LL ans=;
for(int i=;i<=tot;i++) ans+=(LL)(step[i]-step[pre[i]]);
printf("%lld\n",ans);
return ;
}