题面

题解

这题和 SDOI2015「序列统计」比较像。

区别无非是把「乘积」改成了「和」(要求和是 P 的倍数); 另外模数 20170408 不是 NTT 友好的模数, 所以把 NTT 换成了 MTT (拆系数 FFT)。

再加上一个小小的容斥: 用所有方案数减去「一个质数都不含」的方案数即可。

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
const int N = 1005;                 // capacity of every polynomial / FFT buffer
const int mod = 20170408;           // answer modulus
const double pi = acos(-1);
typedef long long ll;
using namespace std;

// n: exponent applied to the count polynomials (presumably the sequence length),
// m: values are drawn from [1, m], P: polynomials are cyclic with period P,
// lim: FFT length (power of two > 2P), cnt: log2(lim) - 1 for bit reversal.
// stk/top: primes found by the sieve (pi(2e7) is about 1.27e6, well under capacity).
int n, m, lim, cnt, a[N], b[N], res1[N], res2[N], stk[5000005], top, P, r[N];
// Composite marks for the sieve over [2, m]. Only ever 0/1, so a bool table
// (1 byte per entry) replaces the original int table: ~20 MB instead of ~80 MB.
bool vis[20000005];
// Minimal complex number for the FFT: `a` is the real part, `b` the imaginary part.
// The global transform buffers used by mul() are declared right after it.
struct Complex
{
    double a, b;
    Complex(double re = 0, double im = 0) : a(re), b(im) {}
    Complex operator + (const Complex &o) const { return Complex(a + o.a, b + o.b); }
    Complex operator - (const Complex &o) const { return Complex(a - o.a, b - o.b); }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator * (const Complex &o) const
    {
        const double re = a * o.a - b * o.b;
        const double im = a * o.b + b * o.a;
        return Complex(re, im);
    }
} c[N], d[N], e[N], f[N], A[N], B[N], C[N], D[N];

// Reads the next integer from stdin. Every non-digit character before the first
// digit is skipped; any '-' among the skipped characters makes the result negative.
// Digits are then accumulated until the first non-digit, which is consumed.
template < typename T >
inline T read()
{
    char ch = getchar();
    T sign = 1;
    for(; ch < '0' || ch > '9'; ch = getchar())
        if(ch == '-')
            sign = -1;
    T value = 0;
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + ch - '0';
    return value * sign;
}

// Linear (Euler) sieve over [2, m] that removes primes from a residue-count table.
// For every prime q found, p[q % P] is decremented. If p starts as the number of
// values in [1, m] congruent to each residue mod P, it ends as the number of
// NON-prime values per residue (1 is not prime, so it stays counted).
// Uses the globals m, P, vis, stk, top; intended to be called once.
void del(int *p)
{
    for(int i = 2; i <= m; i++)
    {
        // Decrement the table passed in (previously the global b was hard-wired).
        if(!vis[i]) stk[++top] = i, p[i % P]--;
        // Mark i * q composite for each prime q up to the smallest prime factor
        // of i, so every composite is marked exactly once. The product is taken
        // in 64 bits and j is bounded by top purely as a safety guard.
        for(int j = 1; j <= top && 1ll * i * stk[j] <= m; j++)
        {
            vis[i * stk[j]] = 1;
            if(i % stk[j] == 0) break;
        }
    }
}

// In-place iterative radix-2 FFT of length lim on p, using the bit-reversal
// table r. opt = 1 is the forward transform, opt = -1 the inverse. The inverse
// also divides by lim, rounds each real part to the nearest integer and reduces
// it modulo `mod`; imaginary parts are left unscaled.
void fft(Complex *p, int opt)
{
    for(int idx = 0; idx < lim; ++idx)
        if(idx < r[idx])
            swap(p[idx], p[r[idx]]);
    for(int half = 1; half < lim; half <<= 1)
    {
        // Principal root for butterflies spanning 2 * half elements.
        const Complex step(cos(pi / half), opt * sin(pi / half));
        const int span = half << 1;
        for(int base = 0; base < lim; base += span)
        {
            Complex w(1, 0);
            for(int k = 0; k < half; ++k)
            {
                Complex &lo = p[base + k], &hi = p[base + k + half];
                const Complex u = lo, v = w * hi;
                lo = u + v;
                hi = u - v;
                w = w * step;
            }
        }
    }
    if(opt != -1)
        return;
    for(int idx = 0; idx < lim; ++idx)
        p[idx].a = (ll) (p[idx].a / lim + 0.5) % mod;
}

int sum(int x, int y) { return (1ll * x << y) % mod; }

void mul(int *a, int *b, int *ans)
{
    for(int i = 0; i < lim; i++)
    {
        c[i].a = a[i] >> 15, c[i].b = 0, d[i].a = a[i] & 32767, d[i].b = 0;
        e[i].a = b[i] >> 15, e[i].b = 0, f[i].a = b[i] & 32767, f[i].b = 0;
    }
    fft(c, 1), fft(d, 1), fft(e, 1), fft(f, 1);
    for(int i = 0; i < lim; i++)
    {
        A[i] = c[i] * e[i], B[i] = c[i] * f[i];
        C[i] = e[i] * d[i], D[i] = d[i] * f[i];
    }
    fft(A, -1), fft(B, -1), fft(C, -1), fft(D, -1);
    for(int i = 0; i < lim; i++)
        ans[i] = (((1ll * sum((ll) A[i].a % mod, 30) + sum((ll) B[i].a % mod, 15)) % mod
                   + sum((ll) C[i].a % mod, 15)) % mod + sum((ll) D[i].a % mod, 0)) % mod;
    for(int i = 0; i < lim; i++)
        ans[i] = (1ll * ans[i] + ans[i + P]) % mod, ans[i + P] = 0;
}

// Reads n, m, P and prints (all sequences) - (sequences with no prime),
// where both counts are the constant term of a count polynomial raised to
// the n-th power under cyclic (mod x^P) multiplication.
int main()
{
    n = read <int> (), m = read <int> (), P = read <int> ();
    // FFT length: smallest power of two strictly greater than 2P, so the linear
    // product of two polynomials of degree < P fits without wrap-around.
    for(lim = 1; lim <= 2 * P; lim <<= 1, cnt++);
    cnt--;
    for(int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << cnt);
    // a[i] = b[i] = number of values in [1, m] congruent to i (mod P).
    for(int i = 0; i < P; i++)
        b[i] = a[i] = m / P + (i && i <= m % P);
    del(b); // b now counts only the non-prime values
    res1[0] = 1, res2[0] = 1;
    // Binary exponentiation: res1 = a^n, res2 = b^n.
    while(n)
    {
        if(n & 1)
            mul(res1, a, res1), mul(res2, b, res2);
        n >>= 1;
        // Square the bases only if another bit remains; the original squared
        // once more after the last bit, wasting two multiplications.
        if(n)
            mul(a, a, a), mul(b, b, b);
    }
    printf("%d\n", (res1[0] - res2[0] + mod) % mod);
    return 0;
}
12-18 07:14