ac自动机+dp

把原串，以及把原串的每一位分别翻转一次得到的 $n$ 个串，都插入 AC 自动机。

$dp[i][j][k]$ 表示走了 $j$ 步、当前位于自动机上 $i$ 号节点时的方案数，其中 $k \in \{0,1\}$ 表示是否已经包含了某个插入的串；按自动机边转移即可，答案为 $\sum_i dp[i][m][1]$。

#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool / string buffers.
const int maxn = 2005;

// Binary (alphabet {'0','1'}) Aho-Corasick automaton stored in a static node
// pool. The root is node 0: `root` is zero-initialised and get_fail() starts
// its BFS from t[0].
namespace ac
{
    // One automaton node.
    //   danger : nonzero if an inserted string ends here, or (after get_fail())
    //            at any node on this node's suffix-link chain.
    //   fail   : suffix link.
    //   c[0/1] : child for '0' / '1'; get_fail() fills missing entries with the
    //            goto transition, so afterwards every entry is a valid move.
    struct data {
        int danger, fail;
        int c[2];
    } t[maxn];
    int root, size;     // root index and index of the most recently allocated node
    int q[maxn * 10];   // BFS queue used by get_fail()

    // Insert the 1-indexed string s[1..strlen(s + 1)] (characters '0'/'1')
    // and mark the node where it ends.
    void ins(char *s)
    {
        const int len = strlen(s + 1);
        int cur = root;
        for (int pos = 1; pos <= len; ++pos) {
            int &child = t[cur].c[s[pos] - '0'];
            if (child == 0) {
                child = ++size;     // allocate a fresh node
            }
            cur = child;
        }
        t[cur].danger = 1;
    }

    // Breadth-first pass that sets suffix links, converts missing children into
    // goto transitions, and ORs each node's `danger` with its suffix link's.
    // BFS order guarantees a node's suffix link is finalised before the node.
    void get_fail()
    {
        int head = 1, tail = 0;
        for (int bit = 0; bit < 2; ++bit) {
            const int child = t[0].c[bit];
            // Depth-1 nodes rely on `fail` already being 0 (the root).
            if (child) {
                q[++tail] = child;
            }
        }
        while (head <= tail) {
            const int u = q[head++];
            const int f = t[u].fail;
            t[u].danger |= t[f].danger;
            for (int bit = 0; bit < 2; ++bit) {
                int &v = t[u].c[bit];
                if (v == 0) {
                    v = t[f].c[bit];            // borrow the suffix link's move
                } else {
                    t[v].fail = t[f].c[bit];
                    q[++tail] = v;
                }
            }
        }
    }
} using namespace ac;
int n;                      // pattern length read for the current test case
int m;                      // number of characters appended (DP steps)
// dp[node][step][k]: number of length-`step` binary prefixes that end in
// automaton node `node`; k = 1 once some inserted string has occurred.
long long dp[maxn][45][2];
char s[maxn];               // pattern buffer, stored 1-indexed starting at s + 1
int main() {
    int T; scanf("%d", &T);
    while(T--) {
        scanf("%d%d%s", &n, &m, s + 1);
        size = 0;
        ins(s);
        for(int i = 1; i <= n; ++i) {
            if(s[i] == '0') {
                s[i] = '1';
                ins(s);
                s[i] = '0';
            } else {
                s[i] = '0';
                ins(s);
                s[i] = '1';
            }
        }
        get_fail();
        memset(dp, 0, sizeof(dp));
        dp[root][0][0] = 1;
        for(int i = 0; i < m; ++i) {
            for(int j = 0; j <= size; ++j) {

                int x = t[j].c[0], y = t[j].c[1];

        //        printf("j = %d x = %d y = %d\n", j, x, y);
                if(!t[x].danger) dp[x][i + 1][0] += dp[j][i][0];
                else dp[x][i + 1][1] += dp[j][i][0];
                if(!t[y].danger) dp[y][i + 1][0] += dp[j][i][0];
                else dp[y][i + 1][1] += dp[j][i][0];
                dp[x][i + 1][1] += dp[j][i][1];
                dp[y][i + 1][1] += dp[j][i][1];
            }
        }
    //    printf("size = %d\n", size);
    //    for(int i = 0; i <= size; ++i)
    //        for(int j = 1; j <= m; ++j)
    //            printf("dp[%d][%d][0] = %lld dp[%d][%d][1] = %d\n", i, j, dp[i][j][0], i, j, dp[i][j][1]);
        long long ans = 0;
        for(int i = 0; i <= size; ++i)
            ans += dp[i][m][1];
        printf("%lld\n", ans);
        for(int i = 0; i <= size; ++i) {
            t[i].danger = t[i].fail = 0;
            t[i].c[0] = t[i].c[1] = 0;
        }
    }
    return 0;
}
View Code
01-15 04:05