博客链接

文中有一处下降幂应为上升幂。

另外，$b_k$ 的式子中漏写了 $k^3$。

CODE

蛮短的

#include <bits/stdc++.h>
using namespace std;
// Table bound and prime modulus shared by every helper below.
constexpr int MAXN = 5000005;
constexpr int mod = 998244353;

// fac[i] holds i! mod p.  Once PreWork has run, inv[i] holds (i!)^{-1} mod p,
// i.e. the inverse *factorial*, not the plain modular inverse of i.
int fac[MAXN], inv[MAXN];

// Fill fac[0..N] and inv[0..N]; indices 0 and 1 are always written.
// Pass 1 builds factorials and plain inverses using
//     i^{-1} = -(p / i) * (p mod i)^{-1}   (mod p),
// which reads inv[p mod i] while it is still a plain inverse.  That is why the
// prefix product turning inverses into inverse factorials is a separate pass.
inline void PreWork(int N) {
    fac[0] = 1;
    fac[1] = 1;
    inv[0] = 1;
    inv[1] = 1;
    for (int x = 2; x <= N; ++x) {
        fac[x] = static_cast<long long>(fac[x - 1]) * x % mod;
        inv[x] = static_cast<long long>(mod - mod / x) * inv[mod % x] % mod;
    }
    for (int x = 2; x <= N; ++x) {
        inv[x] = static_cast<long long>(inv[x]) * inv[x - 1] % mod;
    }
}

// (a * b * c) mod p, using 64-bit intermediates so no product overflows.
inline int mul(int a, int b, int c) {
    const long long ab = static_cast<long long>(a) * b % mod;
    return static_cast<int>(ab * c % mod);
}

// a^b mod p by square-and-multiply (expects b >= 0).  With b = p - 2 this is
// the Fermat inverse of a, since p is prime.
inline int qpow(int a, int b) {
    long long result = 1;
    long long base = a;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}

// Binomial coefficient C(n, m) mod p read from the tables; 0 when n < m.
// Requires PreWork to have filled indices up to n.
inline int C(int n, int m) {
    if (n < m) return 0;
    return mul(fac[n], inv[m], inv[n - m]);
}
int T, n, m, l, k, lim;
// a[i]   : (n*m*l - (n-i)(m-i)(l-i)) mod p, kept in [0, mod).
// pre[i] : forward pass -> a[1]*...*a[i]; after the backward pass -> its inverse.
// f[i]   : n!/(n-i)! * m!/(m-i)! * l!/(l-i)! * pre[i]  (mod p).
int a[MAXN], pre[MAXN], f[MAXN];
int main () {
    PreWork(MAXN-5);
    scanf("%d", &T);
    while(T--) {
        scanf("%d%d%d%d", &n, &m, &l, &k);
        lim = std::min(std::min(n, m), l);
        if(k > lim) { puts("0"); continue; }

        // Forward pass.  The difference of two residues can be negative and C++ '%'
        // keeps the sign, so add mod first: every later value then stays in [0, mod)
        // instead of relying on each step tolerating negative residues.
        pre[0] = 1;
        for(int i = 1; i <= lim; ++i) {
            a[i] = (mul(n, m, l) - mul(n-i, m-i, l-i) + mod) % mod;
            pre[i] = 1ll * pre[i-1] * a[i] % mod;
        }

        // Invert the full product once, then walk back multiplying a[i+1] in to
        // recover every prefix inverse with a single modular exponentiation.
        pre[lim] = qpow(pre[lim], mod-2);
        for(int i = lim-1; i >= 1; --i)
            pre[i] = 1ll * pre[i+1] * a[i+1] % mod;

        // n! * m! * l! does not depend on i; compute it once.
        const int facProd = mul(fac[n], fac[m], fac[l]);
        for(int i = 1; i <= lim; ++i)
            f[i] = 1ll * facProd * mul(inv[n-i], inv[m-i], inv[l-i]) % mod
                   * pre[i] % mod;

        // ans = sum_{i=k}^{lim} (-1)^{i-k} * C(i,k) * f[i]  (mod p)
        // (presumably binomial inversion of "at least i" counts into "exactly k").
        int ans = 0;
        for(int i = k; i <= lim; ++i) {
            const int term = 1ll * C(i, k) * f[i] % mod;
            ans = ((i - k) & 1) ? (ans - term + mod) % mod : (ans + term) % mod;
        }
        printf("%d\n", ans);
    }
}
05-25 14:11