题目描述
有一个只包含小写字母,长度为 $n$ 的字符串 $S$ 。有一些字母是好的,剩下的是坏的。
定义一个子串 $S_{l\ldots r}$是好的,当且仅当这个子串包含不超过 $k$ 个坏的字母。
求有多少个不同的满足以下要求的字符串 $T$ :
- $T$ 作为 $S$ 的子串出现过。
- 存在一个 $T$ 出现的位置 $[l,r]$ ,满足 $S_{l\ldots r}$ 是好的。
输入格式
第一行有一个字符串 $S$ 。
第二行有一个字符串 $B$ 。若 $B_i=‘1’$ 则表示 $S_i$ 是好的,否则表示 $S_i$ 是坏的。
第三行有一个整数 $k$ 。
输出格式
一个整数:答案。
数据范围与提示
子任务 $1$($10$ 分):$n\leq 10$。
子任务 $2$($10$ 分):$n\leq 100$。
子任务 $3$($10$ 分):$n\leq 1000$。
子任务 $4$($10$ 分):$n\leq 100000,k=n$。
子任务 $5$($10$ 分):$n\leq 100000,k=0$。
子任务 $6$($20$ 分):$n\leq 100000$,若$S_i=S_j$,则$B_i=B_j$。
子任务 $7$($30$ 分):$n\leq 100000$。
对于 $100\%$ 的数据:$1\leq n\leq {10}^5,0\leq k\leq {10}^5$, $S$ 只包含小写字母。
题目来源:全是水题的GDOI模拟赛 by yww
题目分析
定位:比较模板的SAM题(然而我一个月前并不会SAM)
题意即求满足一定条件的若干个字符串里本质不同的子串个数。当然这里的“一定条件”比较特殊,是连续的一段子串。(如果这里的“一定条件”字符串是若干个互不相干的串,似乎就需要“广义后缀自动机”来处理了)
既然对于每一个新增的节点,其合法的子串都有一个左边界,那么在SAM里处理的时候,就可以对每一节点加一个权值$mx[u]$表示该节点的最长合法扩展长度。这个权值的作用就在于建完自动机后的$calc()$,原先数本质不同的子串个数是这样的: ans += len[p]-len[fa[p]]; 现在就是 ans += std::max(std::min(mx[p], len[p])-len[fa[p]], ); 。记得注意一下extend里对mx的转移。
#include<bits/stdc++.h>
const int maxn = ; int n,lim,sum[maxn],size[maxn],cnt[maxn],pos[maxn];
long long ans;
struct SAM
{
int ch[maxn][],fa[maxn],len[maxn],mx[maxn],lst,tot;
void init()
{
lst = tot = ;
}
void extend(int c, int v)
{
int p = lst, np = ++tot;
lst = np, len[np] = len[p]+, mx[np] = v; //len[p+1] Here 居然打成标红的这个
for (; p&&!ch[p][c]; p=fa[p]) ch[p][c] = np;
if (!p) fa[np] = ;
else{
int q = ch[p][c];
if (len[p]+==len[q]) fa[np] = q;
else{
int nq = ++tot;
len[nq] = len[p]+, mx[nq] = mx[q];
memcpy(ch[nq], ch[q], sizeof ch[q]);
fa[nq] = fa[q], fa[q] = fa[np] = nq;
for (; p&&ch[p][c]==q; p=fa[p])
ch[p][c] = nq;
}
}
}
void calc()
{
for (int i=; i<=tot; i++) ++cnt[len[i]];
for (int i=; i<=tot; i++) cnt[i] += cnt[i-];
for (int i=; i<=tot; i++) pos[cnt[len[i]]] = i, --cnt[len[i]];
for (int i=tot; i; i--)
{
int p = pos[i];
mx[fa[p]] = std::max(mx[fa[p]], mx[p]);
ans += std::max(std::min(mx[p], len[p])-len[fa[p]], );
}
}
}f;
char s[maxn],t[maxn]; int main()
{
scanf("%s%s%d",s+,t+,&lim);
n = strlen(s+);
f.init();
for (int i=, j=; i<=n; i++)
{
sum[i] = (t[i]=='')+sum[i-];
while (sum[i]-sum[j] > lim) ++j;
f.extend(s[i]-'a'+, i-j);
}
f.calc();
printf("%lld\n",ans);
return ;
}
END