// 题意:给出多个经过Base64编码的模式串和多个待匹配的串,问每个待匹配串中出现了多少种模式串。编码方法是把每3个字节(24位)按每6位一组拆成4个字符,字符对应表题中已给出(即标准Base64:A-Z、a-z、0-9、'+'、'/',末尾不足3字节时用'='补齐)。
// 分析:除了解码之外,基本就是一个裸的AC自动机。注意:有多个待匹配串要在同一个自动机上依次查询,而vis数组用来保证每个节点(每种模式串)在一个串中只被计数一次,所以每次查询前都要把vis数组清零。
#include <cstdio>
#include <queue>
#include <cstring>
#include <cctype>
using namespace std;

// Debug hook: D(stmt) compiles to nothing. Redefine as `#define D(x) x` to
// re-enable trace output. It must stay on its own line; otherwise the
// following declarations become part of the macro body.
#define D(x)

// Decoded symbols are raw bytes (4 Base64 chars -> 3 bytes), so every trie
// node needs one child slot per possible byte value.
const int MAX_CHILD_NUM = 256;

// Upper bound on trie nodes: presumably <= 512 patterns of <= 64 symbols each,
// plus slack for the root. TODO: confirm against the problem's stated limits.
const int MAX_NODE_NUM = 512 * 64 + 10;

// Capacity of one encoded input line (and of its decoded form, which is never
// longer). Assumed limit plus slack. TODO: confirm against the problem statement.
const int MAX_LEN = 3000 + 10;

// Shared scratch buffers: st receives the encoded line read by scanf,
// st2 receives the decoded byte values followed by a -1 terminator.
char st[MAX_LEN];
int st2[MAX_LEN];
bool vis[MAX_NODE_NUM]; struct Trie
{
int next[MAX_NODE_NUM][MAX_CHILD_NUM];
int fail[MAX_NODE_NUM];
int count[MAX_NODE_NUM];
int node_cnt;
int root; void init()
{
node_cnt = ;
root = newnode();
} int newnode()
{
for (int i = ; i < MAX_CHILD_NUM; i++)
next[node_cnt][i] = -;
count[node_cnt++] = ;
return node_cnt - ;
} int get_id(int a)
{
return a;
} void insert(int buf[], int id)
{
int now = root;
for (int i = ; buf[i] != -; i++)
{
int id = get_id(buf[i]);
if (next[now][id] == -)
next[now][id] = newnode();
now = next[now][id];
}
count[now]++;
} void build()
{
queue<int>Q;
fail[root] = root;
for (int i = ; i < MAX_CHILD_NUM; i++)
if (next[root][i] == -)
next[root][i] = root;
else
{
fail[next[root][i]] = root;
Q.push(next[root][i]);
}
while (!Q.empty())
{
int now = Q.front();
Q.pop();
for (int i = ; i < MAX_CHILD_NUM; i++)
if (next[now][i] == -)
next[now][i] = next[fail[now]][i];
else
{
fail[next[now][i]]=next[fail[now]][i];
Q.push(next[now][i]);
}
}
} int query(int buf[])
{
int now = root;
int res = ;
for (int i = ; buf[i] != -; i++)
{
now = next[now][get_id(buf[i])];
int temp = now;
while (temp != root && !vis[temp])
{
res += count[temp];
// optimization: prevent from searching this fail chain again.
//also prevent matching again.
vis[temp] = true;
temp = fail[temp];
}
}
return res;
} void debug()
{
for(int i = ;i < node_cnt;i++)
{
printf("id = %3d,fail = %3d,end = %3d,chi = [",i,fail[i],count[i]);
for(int j = ;j < MAX_CHILD_NUM;j++)
printf("%2d",next[i][j]);
printf("]\n");
}
}
}ac; int n, m; int get_value(char ch)
{
if (isupper(ch))
return ch - 'A';
if (islower(ch))
return ch - 'a' + ;
if (isdigit(ch))
return ch - '' + ;
if (ch == '+')
return ;
return ;
} void transform(char *st, int *st2)
{
int len = strlen(st);
int len2 = len * / ;
for (int i = ; i < len; i += )
{
int a = ;
for (int j = ; j < ; j++)
{
a = (a << ) + get_value(st[i + j]);
D(printf("**%d\n", a));
} for (int j = ; j >= ; j--)
{
st2[i * / + j] = a % ( << );
a >>= ;
D(printf("**%d\n", st2[i * / + j]));
}
}
while (st[len - ] == '=')
{
len--;
len2--;
}
st2[len2] = -;
D(puts("#"));
for (int i = ; i < len2; i++)
{
D(printf("%d ", st2[i]));
}
D(puts(""));
} void input()
{
for (int i = ; i <= n; i++)
{
scanf("%s", st);
transform(st, st2);
ac.insert(st2, i);
}
ac.build();
scanf("%d", &m);
for (int i = ; i < m; i++)
{
scanf("%s", st);
transform(st, st2);
memset(vis, , sizeof(vis));
printf("%d\n", ac.query(st2));
}
puts("");
} int main()
{ while (scanf("%d", &n) != EOF)
{
ac.init();
input();
}
return ;
}