就向书上说得那样,如果模式串P的第i行出现在文本串T的第r行第c列,则cnt[r-i][c]++;

还有个很棘手的问题就是模式串中可能会有相同的串,所以用repr[i]来记录第i个模式串P[i]第一次出现的位置。如果repr[i] == i,说明这个模式串之前没有重复过,可以加进自动机里去。有重复的话,把这些重复的模式串组织成一个链表,用next把它们连接起来。

所以在统计cnt的时候,匹配到的模式串可能会作为匹配的第i行,也可能是next[i]行,next[next[i]]行等等。

 #include <cstdio>
#include <cstring>
#include <queue>
using namespace std; int n, m, x, y, tr;
const int maxx = + ;
const int maxn = + ;
const int maxnode = + ;
const int sigma_size = ;
char T[maxn][maxn], P[maxx][maxx];
int cnt[maxn][maxn];
int repr[maxx];
int next[maxx]; struct AhoCorasickAutomata
{
int ch[maxnode][sigma_size];
int f[maxnode];
int last[maxnode];
int val[maxnode];
int sz; void init() { sz = ; memset(ch[], , sizeof(ch[])); } inline int idx(char c) { return c - 'a'; } void insert(char* s, int v)
{
int u = , n = strlen(s);
for(int i = ; i < n; i++)
{
int c = idx(s[i]);
if(!ch[u][c])
{
memset(ch[sz], , sizeof(ch[sz]));
val[sz] = ;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u] = v;
} void match(int i, int j)
{
int c = i - y + ;
int pr = repr[val[j] - ];
while(pr >= )
{
if(tr - pr >= ) cnt[tr-pr][c]++;
pr = next[pr];
}
} void print(int i, int j)
{//在文本串的第i列匹配到单词节点j
if(j)
{
match(i, j);
print(i, last[j]);
}
} void find(char* T)
{
int j = , n = strlen(T);
for(int i = ; i < n; i++)
{
int c = idx(T[i]);
while(j && !ch[j][c]) j = f[j];
j = ch[j][c];
if(val[j]) print(i, j);
else if(val[last[j]]) print(i, last[j]);
}
} void getFail()
{
queue<int> q;
f[] = ;
for(int c = ; c < sigma_size; c++)
{
int u = ch[][c];
if(u) { f[u] = ; last[u] = ; q.push(u); }
}
while(!q.empty())
{
int r = q.front(); q.pop();
for(int c = ; c < sigma_size; c++)
{
int u = ch[r][c];
if(!u) continue;
q.push(u);
int v = f[r];
while(v && !ch[v][c]) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}
}ac; int main()
{
//freopen("in.txt", "r", stdin); int test;
scanf("%d", &test);
while(test--)
{
scanf("%d%d", &n, &m);
for(int i = ; i < n; i++) scanf("%s", T[i]);
scanf("%d%d", &x, &y);
ac.init();
for(int i = ; i < x; i++)
{
repr[i] = i;
next[i] = -;
scanf("%s", P[i]);
for(int j = ; j < i; j++) if(strcmp(P[i], P[j]) == )
{
repr[i] = j;
next[i] = next[j];
next[j] = i;
break;
}
if(repr[i] == i) ac.insert(P[i], i+);
}
ac.getFail();
memset(cnt, , sizeof(cnt));
for(tr = ; tr < n; tr++) ac.find(T[tr]); int ans = ;
for(int i = ; i < n; i++)
for(int j = ; j < m; j++)
if(cnt[i][j] == x) ans++;
printf("%d\n", ans);
} return ;
}

代码君

04-21 00:11