题目大意:在字符集大小为$m$的情况下,有多少种构造长度为$n$的字符串$s$的方案,使得$C(s)=k$。其中$C(s)$表示字符串$s$中出现次数最多的字符的出现次数。
对$998244353$取模,$n,m≤5\times 10^4$
如果你考虑去DP,你就lose了。
令$F(x)$表示满足$C(s)≤x$的方案数。
那么最终的答案显然为$F(k)-F(k-1)$。
这一题有一个非常优美的性质:对于每一种字符,允许的最多出现次数都是$k$。
那么,令$G_k(x)=\sum\limits_{i=0}^{k} \frac{1}{i!}x^i$
则有$F(k)=n![x^n]G_k^m(x)$
证明是显然的
写一个多项式快速幂的板子就过了。
#include<bits/stdc++.h>
#define M (1<<17)
#define L long long
#define MOD 998244353
#define G 3
using namespace std; L pow_mod(L x,L k){
L ans=;
while(k){
if(k&) ans=ans*x%MOD;
x=x*x%MOD; k>>=;
}
return ans;
} void change(L a[],int n){
for(int i=,j=;i<n-;i++){
if(i<j) swap(a[i],a[j]);
int k=n>>;
while(j>=k) j-=k,k>>=;
j+=k;
}
}
void NTT(L a[],int n,int on){
change(a,n);
for(int h=;h<=n;h<<=){
L wn=pow_mod(G,(MOD-)/h);
for(int j=;j<n;j+=h){
L w=;
for(int k=j;k<j+(h>>);k++){
L u=a[k],t=w*a[k+(h>>)]%MOD;
a[k]=(u+t)%MOD;
a[k+(h>>)]=(u-t+MOD)%MOD;
w=w*wn%MOD;
}
}
}
if(on==-){
L inv=pow_mod(n,MOD-);
for(int i=;i<n;i++) a[i]=a[i]*inv%MOD;
reverse(a+,a+n);
}
} void getinv(L a[],L b[],int n){
if(n==){b[]=pow_mod(a[],MOD-); return;}
static L c[M],d[M];
memset(c,,n<<); memset(d,,n<<);
getinv(a,c,n>>);
for(int i=;i<n;i++) d[i]=a[i];
NTT(d,n<<,); NTT(c,n<<,);
for(int i=;i<(n<<);i++) b[i]=(*c[i]-d[i]*c[i]%MOD*c[i]%MOD+MOD)%MOD;
NTT(b,n<<,-);
for(int i=;i<n;i++) b[n+i]=;
} void qiudao(L a[],L b[],int n){
memset(b,,sizeof(b));
for(int i=;i<n;i++) b[i-]=i*a[i]%MOD;
}
void jifen(L a[],L b[],int n){
memset(b,,sizeof(b));
for(int i=;i<n;i++) b[i+]=a[i]*pow_mod(i+,MOD-)%MOD;
} void getln(L a[],L b[],int n){
static L c[M],d[M];
memset(c,,n<<); memset(d,,n<<);
qiudao(a,c,n); getinv(a,d,n);
NTT(c,n<<,); NTT(d,n<<,);
for(int i=;i<(n<<);i++) c[i]=c[i]*d[i]%MOD;
NTT(c,n<<,-);
jifen(c,b,n);
} void getexp(L a[],L b[],int n){
if(n==){b[]=; return;}
static L lnb[M]; memset(lnb,,n<<);
getexp(a,b,n>>); getln(b,lnb,n);
for(int i=;i<n;i++) lnb[i]=(a[i]-lnb[i]+MOD)%MOD,b[i+n]=;
lnb[n]=;
lnb[]=(lnb[]+)%MOD;
NTT(lnb,n<<,); NTT(b,n<<,);
for(int i=;i<(n<<);i++) b[i]=b[i]*lnb[i]%MOD;
NTT(b,n<<,-);
for(int i=;i<n;i++) b[i+n]=;
} L a[M]={},b[M]={};
L fac[M]={},invfac[M]={};
int n,k,m; L solve(){
memset(a,,sizeof(a));
memset(b,,sizeof(b));
int nn=; while(nn<=n) nn<<=;
for(int i=;i<=m;i++) a[i]=invfac[i];
L hh=a[],invhh=pow_mod(hh,MOD-);
for(int i=;i<nn;i++) a[i]=a[i]*invhh%MOD;
getln(a,b,nn);
for(int i=;i<nn;i++) b[i]=b[i]*k%MOD;
getexp(b,a,nn);
hh=pow_mod(hh,k);
for(int i=;i<nn;i++) a[i]=a[i]*hh%MOD;
return a[n];
} int main(){
scanf("%d%d%d",&n,&k,&m);
fac[]=; for(int i=;i<M;i++) fac[i]=fac[i-]*i%MOD;
invfac[M-]=pow_mod(fac[M-],MOD-);
for(int i=M-;~i;i--) invfac[i]=invfac[i+]*(i+)%MOD;
L res1=solve();
m--;
L res2=solve();
cout<<(res1-res2+MOD)*fac[n]%MOD<<endl;
}