题面
题解
和 [SDOI2015] 序列统计 比较像。
区别只是把乘法改成了加法（下标在模 P 意义下相加）；又因为模数 20170408 不是 NTT 模数，需要把 NTT 换成 MTT（拆系数 FFT）。
再加上一个简单的容斥：用所有方案数减去不合法方案数（即一个质数都不含的方案数）即可。
Code
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
// Capacity of every polynomial / FFT buffer. mul() touches indices up to
// lim + P - 1, and lim (the smallest power of two > 2P) is at most 4P,
// so this size is sufficient for P <= 201.
const int N = 1005;
// Answer modulus. It is even (composite), so NTT cannot be used directly;
// mul() uses a split-coefficient double FFT instead.
const int mod = 20170408;
const double pi = std::acos(-1.0);
typedef long long ll;
using namespace std;
int n, m, lim, cnt, a[N], b[N], res1[N], res2[N], stk[5000005], top, P, r[N];
// Sieve "already marked composite" flags for values up to 2e7. Only 0/1 is
// ever stored, so bool suffices and shrinks this table from ~80 MB (int)
// to ~20 MB.
bool vis[20000005];
// Complex number used by the FFT: `a` is the real part, `b` the imaginary part.
// The global arrays declared alongside it are mul()'s FFT scratch buffers:
// c/d hold the high/low 15-bit halves of the first operand, e/f those of the
// second operand, and A..D the four partial products.
struct Complex
{
    double a, b;

    Complex(double x = 0, double y = 0) : a(x), b(y) {}

    Complex operator + (const Complex &rhs) const
    {
        return Complex(a + rhs.a, b + rhs.b);
    }

    Complex operator - (const Complex &rhs) const
    {
        return Complex(a - rhs.a, b - rhs.b);
    }

    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator * (const Complex &rhs) const
    {
        return Complex(a * rhs.a - b * rhs.b, a * rhs.b + b * rhs.a);
    }
} c[N], d[N], e[N], f[N], A[N], B[N], C[N], D[N];
// Reads one (optionally negative) integer of type T from stdin via getchar().
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. If end of input is reached before a digit,
// returns 0 instead of looping forever.
template < typename T >
inline T read()
{
    T x = 0, w = 1;
    // int, not char: getchar() reports end of input as EOF (a negative int),
    // which a char cannot hold reliably (and never can if char is unsigned).
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-') w = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        // Convert the digit first so the intermediate x * 10 + c cannot
        // overflow before '0' is subtracted.
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * w;
}
// Linear (Euler) sieve over [2, m]. For every prime q <= m it decrements
// p[q % P]. If p[k] initially counts the integers in [1, m] congruent to
// k (mod P), afterwards it counts only the non-prime ones.
// Reads the globals m and P; appends the primes to stk[1..top] and marks
// composites in vis.
void del(int *p)
{
    for(int i = 2; i <= m; i++)
    {
        if(!vis[i])
        {
            stk[++top] = i;
            p[i % P]--;   // use the caller's array rather than a global
        }
        // Mark i * q for every prime q up to the smallest prime factor of i,
        // so each composite is marked exactly once. The j <= top bound and
        // the 64-bit product keep the loop safe regardless of m.
        for(int j = 1; j <= top && 1ll * i * stk[j] <= m; j++)
        {
            vis[i * stk[j]] = 1;
            if(i % stk[j] == 0) break;
        }
    }
}
// In-place iterative radix-2 FFT over the first `lim` entries of p.
// opt = 1 is the forward transform, opt = -1 the inverse. The inverse also
// divides each real part by lim, rounds it (adding 0.5 and truncating) and
// reduces it modulo `mod`; imaginary parts are left untouched.
// Requires r[] to hold the bit-reversal permutation for lim.
void fft(Complex *p, int opt)
{
    // Reorder into bit-reversed index order.
    for(int i = 0; i < lim; i++)
    {
        if(i < r[i]) swap(p[i], p[r[i]]);
    }

    // half = length of the sub-transforms being merged at this level.
    for(int half = 1; half < lim; half <<= 1)
    {
        const Complex step(cos(pi / half), opt * sin(pi / half));
        for(int start = 0; start < lim; start += (half << 1))
        {
            Complex w(1, 0);
            for(int k = start; k < start + half; k++)
            {
                const Complex u = p[k];
                const Complex v = w * p[k + half];
                p[k] = u + v;
                p[k + half] = u - v;
                w = w * step;
            }
        }
    }

    if(opt != -1) return;

    for(int i = 0; i < lim; i++)
    {
        p[i].a = (ll) (p[i].a / lim + 0.5) % mod;
    }
}
// Despite its name, returns (x * 2^y) mod `mod`, i.e. x shifted left by y
// bits in 64-bit arithmetic and then reduced. mul() uses it to re-weight
// its 15-bit partial products. x is expected to be non-negative.
int sum(int x, int y)
{
    const long long shifted = static_cast<long long>(x) << y;
    return static_cast<int>(shifted % mod);
}
void mul(int *a, int *b, int *ans)
{
for(int i = 0; i < lim; i++)
{
c[i].a = a[i] >> 15, c[i].b = 0, d[i].a = a[i] & 32767, d[i].b = 0;
e[i].a = b[i] >> 15, e[i].b = 0, f[i].a = b[i] & 32767, f[i].b = 0;
}
fft(c, 1), fft(d, 1), fft(e, 1), fft(f, 1);
for(int i = 0; i < lim; i++)
{
A[i] = c[i] * e[i], B[i] = c[i] * f[i];
C[i] = e[i] * d[i], D[i] = d[i] * f[i];
}
fft(A, -1), fft(B, -1), fft(C, -1), fft(D, -1);
for(int i = 0; i < lim; i++)
ans[i] = (((1ll * sum((ll) A[i].a % mod, 30) + sum((ll) B[i].a % mod, 15)) % mod
+ sum((ll) C[i].a % mod, 15)) % mod + sum((ll) D[i].a % mod, 0)) % mod;
for(int i = 0; i < lim; i++)
ans[i] = (1ll * ans[i] + ans[i + P]) % mod, ans[i + P] = 0;
}
// Reads n, m, P and prints, modulo `mod`, the number of length-n sequences
// over [1, m] whose sum is divisible by P and that contain at least one prime:
// (all such sequences) - (those built only from non-primes).
int main()
{
    n = read <int> ();
    m = read <int> ();
    P = read <int> ();

    // lim = smallest power of two greater than 2P; cnt ends as log2(lim) - 1.
    lim = 1;
    while(lim <= 2 * P)
    {
        lim <<= 1;
        cnt++;
    }
    cnt--;

    // Bit-reversal permutation consumed by fft().
    for(int i = 0; i < lim; i++)
    {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << cnt);
    }

    // a[k] = how many values in [1, m] are congruent to k (mod P).
    // b starts as a copy; del() then removes the primes from it.
    for(int k = 0; k < P; k++)
    {
        const bool gets_extra = k != 0 && k <= m % P;
        a[k] = b[k] = m / P + (gets_extra ? 1 : 0);
    }
    del(b);

    // Binary exponentiation of both cyclic polynomials (consumes global n).
    res1[0] = 1;
    res2[0] = 1;
    for(; n != 0; n >>= 1)
    {
        if(n & 1)
        {
            mul(res1, a, res1);
            mul(res2, b, res2);
        }
        mul(a, a, a);
        mul(b, b, b);
    }

    const int answer = (res1[0] - res2[0] + mod) % mod;
    printf("%d\n", answer);
    return 0;
}