题意
考虑将\(s1\)和\(s2\)接在一起求出相同子串个数,再求出\(s1\)自己匹配的相同子串个数和\(s2\)自己匹配的相同子串个数减去即可。
如何求相同子串个数:
我们知道子串的集合即所有后缀的前缀集合,于是实际上答案就是:
\(\sum\limits_{i=1}^n\sum\limits_{j=i+1}^nlcp(sa_i,sa_j)\)
接下来就和这题相同了。
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=400010;
int n,m,num,top;
int sa[maxn],rk[maxn],oldrk[maxn],id[maxn],tmpid[maxn],cnt[maxn],height[maxn],L[maxn],R[maxn],sta[maxn];
ll ans;
char s1[maxn],s2[maxn],s[maxn];
inline bool check(int x,int y,int k){return oldrk[x]==oldrk[y]&&oldrk[x+k]==oldrk[y+k];}
inline void get_sa(char* s,int len)
{
memset(cnt,0,sizeof(cnt));
num=300;
for(int i=1;i<=len;i++)cnt[rk[i]=s[i]]++;
for(int i=1;i<=num;i++)cnt[i]+=cnt[i-1];
for(int i=len;i;i--)sa[cnt[rk[i]]--]=i;
for(int t=1;t<=len;t<<=1)
{
int tot=0;
for(int i=len-t+1;i<=len;i++)id[++tot]=i;
for(int i=1;i<=len;i++)if(sa[i]>t)id[++tot]=sa[i]-t;
tot=0;
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=len;i++)cnt[tmpid[i]=rk[id[i]]]++;
for(int i=1;i<=num;i++)cnt[i]+=cnt[i-1];
for(int i=len;i;i--)sa[cnt[tmpid[i]]--]=id[i];
memcpy(oldrk,rk,sizeof(rk));
for(int i=1;i<=len;i++)rk[sa[i]]=check(sa[i-1],sa[i],t)?tot:++tot;
num=tot;
if(num==len)break;
}
for(int i=1,j=0;i<=len;i++)
{
if(j)j--;
while(s[i+j]==s[sa[rk[i]-1]+j])j++;
height[rk[i]]=j;
}
}
inline ll calc(int len)
{
ll res=0;
sta[top=1]=1;
for(int i=2;i<=len;i++)
{
while(top&&height[sta[top]]>=height[i])R[sta[top--]]=i;
L[i]=sta[top];sta[++top]=i;
}
while(top)R[sta[top--]]=len+1;
for(int i=2;i<=len;i++)res+=1ll*height[i]*(R[i]-i)*(i-L[i]);
return res;
}
int main()
{
scanf("%s%s",s1+1,s2+1);
n=strlen(s1+1),m=strlen(s2+1);
get_sa(s1,n);ans-=calc(n);
get_sa(s2,m);ans-=calc(m);
for(int i=1;i<=n;i++)s[i]=s1[i];
s[n+1]='#';
for(int i=1;i<=m;i++)s[n+i+1]=s2[i];
get_sa(s,n+m+1);ans+=calc(n+m+1);
printf("%lld\n",ans);
return 0;
}