题目链接
洛谷:https://www.luogu.org/problemnew/show/P4075
LOJ:https://loj.ac/problem/2065
Solution
这种题看起来就很点分治啊...
我们可以发现,我们需要一个支持询问字符串相等,并且支持在一个串前面加一个串的数据结构,显然我们用哈希就行了。
那么我们直接开桶然后拿哈希维护,总复杂度\(O(Tn\log n)\)。
#include<bits/stdc++.h>
using namespace std;
template<typename T> void read(T &x) {
x=0;T ff=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') ff=-ff;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=ff;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
#define pii pair<int,int >
#define vec vector<int >
#define pb push_back
#define mp make_pair
#define fr first
#define sc second
#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)
const int maxn = 1e6+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int bs = 29;
const int mod = 1e6+3;
ll ans;
char s[maxn];
int sz[maxn],rs[maxn],rp[maxn],rt,f[maxn],SZ,mxd,rrs[maxn],rrp[maxn];
int n,m,r[maxn],tag[maxn],head[maxn],tot,len,suf[maxn],pre[maxn],pw[maxn],vis[maxn];
struct edge{int to,nxt;}e[maxn<<1];
void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}
void prepare() {
for(len=m;len<n;len+=m)
for(int i=len+1;i<=len+m;i++) tag[i]=tag[i-len];
for(int i=1;i<=len;i++) pre[i]=(pre[i-1]*bs+tag[i])%mod;
suf[len+1]=0;for(int i=len;i;i--) suf[i]=(suf[i+1]*bs+tag[i])%mod;
reverse(suf+1,suf+len+1);
pw[0]=1;for(int i=1;i<=len;i++) pw[i]=pw[i-1]*bs%mod;
}
void get_rt(int x,int fa) {
sz[x]=1,f[x]=0;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa&&!vis[e[i].to]) get_rt(e[i].to,x),sz[x]+=sz[e[i].to],f[x]=max(f[x],sz[e[i].to]);
f[x]=max(f[x],SZ-sz[x]);if(f[rt]>f[x]) rt=x;
}
void calc(int x,int fa,int dep,int hs,const int &c) {
hs=(hs+1ll*r[x]*pw[dep])%mod;dep++;mxd=max(mxd,dep);
if(hs==pre[dep]) {
rrp[dep%m]++;
if(c==tag[dep%m+1]) ans+=rs[m-dep%m-1];
}
if(hs==suf[dep]) {
rrs[dep%m]++;
if(c==tag[m-dep%m]) ans+=rp[m-dep%m-1];
}
for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=fa&&!vis[e[i].to]) calc(e[i].to,x,dep,hs,c);
}
void dfs(int x) {
f[rt=0]=maxn;get_rt(x,0);x=rt;vis[x]=1;mxd=0;
rp[0]=rs[0]=1;int mmxd=0;
for(int i=head[x];i;i=e[i].nxt)
if(!vis[e[i].to]) {
mxd=0;calc(e[i].to,x,0,0,r[x]);mmxd=max(mmxd,mxd);
for(int j=0;j<=min(mxd,m);j++) rp[j]+=rrp[j],rs[j]+=rrs[j],rrp[j]=rrs[j]=0;
}
for(int i=0;i<=mmxd;i++) rs[i]=rp[i]=0;
for(int i=head[x];i;i=e[i].nxt)
if(!vis[e[i].to]) SZ=sz[e[i].to],dfs(e[i].to);
}
void solve() {
read(n),read(m);scanf("%s",s+1);ans=0;
for(int i=1;i<=n;i++) r[i]=s[i]-'A'+1;
for(int x,y,i=1;i<n;i++) read(x),read(y),ins(x,y);
scanf("%s",s+1);
for(int i=1;i<=m;i++) tag[i]=s[i]-'A'+1;
prepare();SZ=n;dfs(1);write(ans);
}
#define clr(x) memset(x,0,(n+3)*4)
void clear() {clr(head),clr(vis);tot=0;}
int main() {
int st=clock();
int t;read(t);while(t--) solve(),clear();
cerr << (double)(clock()-st)/1e3 << endl;
return 0;
}