一看好像会做的样子,就去做了一下,结果
猝不及防地T掉了
赶紧查了一下,没有死循环,复杂度也是对的,无果,于是翻了题解
题解没看懂,但是找到了标程,然后发现我被卡常了。。。
而且好像当时还过了前10个点啊。。这要真的是比赛稳稳的FST啊
小技巧:
逆元只需要求 inv[i]（i 的逆元）和 inv[i!]（i! 的逆元），两者都可以线性预处理出来
令md=1e9+7
则inv[1]=1
对 i≥2，inv[i]=(md-md/i)*inv[md%i]%md
令inv2[i]=inv[i!]
则inv2[n]=pow(n!,md-2)（费马小定理，只需一次快速幂）
对 i<n，inv2[i]=inv2[i+1]*(i+1)%md;
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<map>
#define md 1000000007
using namespace std;
typedef long long LL;

// Binary exponentiation: returns a^b mod md.
// Callers in this program pass a factorial already reduced mod md and b = md-2
// (Fermat inverse); presumably 0 <= a < md and b >= 0 — confirm for new callers.
LL poww(LL a,LL b)
{
    LL base=a,ans=1;
    while(b)
    {
        if(b&1) ans=(ans*base)%md;
        b>>=1;
        base=(base*base)%md;
    }
    return ans;
}

// Lookup tables (filled in main) and the running state shared by addx/delx/main.
LL inv[1000100],inv2[1000100],jc[1000100],sum,sumx,num[30],ans,n;
char s1[1000100],s2[1000100];
// Adds one occurrence of letter x to the multiset, keeping
// sum == sumx! / prod_c(num[c]!) (mod md) up to date incrementally.
void addx(LL x)
{
    LL &cnt=num[x];
    sum=sum*jc[cnt]%md;      // cancel the old 1/num[x]! factor
    ++cnt;
    ++sumx;
    sum=sum*sumx%md;         // extend (sumx-1)! to sumx!
    sum=sum*inv2[cnt]%md;    // apply the new 1/num[x]! factor
}
void delx(LL x)
{
sum=sum*jc[num[x]]%md;
sum=sum*inv[sumx]%md;#include
#include
#include
#include
#define md 1000000007
using namespace std;
typedef long long LL;

// Binary exponentiation: a^b mod md, consuming the bits of b from low to high.
LL poww(LL a,LL b)
{
    LL result=1;
    for(LL factor=a; b; b>>=1, factor=factor*factor%md)
    {
        if(b&1) result=result*factor%md;
    }
    return result;
}

// Lookup tables, filled in main():
LL inv[1000100];    // inv[i]  = modular inverse of i
LL inv2[1000100];   // inv2[i] = modular inverse of i!
LL jc[1000100];     // jc[i]   = i! mod md
// Letter counters and multinomial maintained by addx()/delx(), the answer,
// the length, and the two 1-indexed input strings.
LL sum,sumx,num[30],ans,n;
char s1[1000100],s2[1000100];
// Adds one occurrence of letter x to the multiset, keeping
// sum == sumx! / prod_c(num[c]!) (mod md) up to date incrementally.
void addx(LL x)
{
    LL &cnt=num[x];
    sum=sum*jc[cnt]%md;      // cancel the old 1/num[x]! factor
    ++cnt;
    ++sumx;
    sum=sum*sumx%md;         // extend (sumx-1)! to sumx!
    sum=sum*inv2[cnt]%md;    // apply the new 1/num[x]! factor
}
// Removes one occurrence of letter x from the multiset, keeping
// sum == sumx! / prod_c(num[c]!) (mod md) up to date incrementally.
// main() only calls this when num[x] > 0.
void delx(LL x)
{
    LL &cnt=num[x];
    sum=sum*jc[cnt]%md;      // cancel the old 1/num[x]! factor
    sum=sum*inv[sumx]%md;    // shrink sumx! to (sumx-1)!
    --sumx;
    --cnt;
    sum=sum*inv2[cnt]%md;    // apply the new 1/num[x]! factor
}
//s1的所有排列中小于s2的个数-s1的所有排列中小于s1的个数+1
int main()
{
LL i,j;
scanf("%s",s1+1);
scanf("%s",s2+1);n=strlen(s2+1);
jc[0]=1;
for(i=1;i<=1000000;i++) jc[i]=jc[i-1]*i%md;
inv[1]=1;
for(i=2;i<=1000000;i++) inv[i]=(md-md/i)*inv[md%i]%md;
for(i=0;i<=1000000;i++) inv2[i]=poww(jc[i],md-2);
for(i=1;i<=n;i++) num[s1[i]-'a']++;
sum=jc[n];sumx=n;
for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md;
for(i=1;i<=n;i++)
{
for(j=0;j<s2[i]-'a';j++)
if(num[j])
{
delx(j);
ans=(ans+sum)%md;
addx(j);
}
if(num[s2[i]-'a']) delx(s2[i]-'a');
else break;
}
for(i=0;i<26;i++) num[i]=0;
for(i=1;i<=n;i++) num[s1[i]-'a']++;
sum=jc[n];sumx=n;
for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md;
for(i=1;i<=n;i++)
{
for(j=0;j<s1[i]-'a';j++)
if(num[j])
{
delx(j);
ans=(ans-sum+md)%md;
addx(j);
}
if(num[s1[i]-'a']) delx(s1[i]-'a');
else break;
}
ans=(ans-1+md)%md;
printf("%lld",ans);
return 0;
} sumx--;num[x]--;
sum=sum*inv2[num[x]]%md;
}
//s1的所有排列中小于s2的个数-s1的所有排列中小于s1的个数+1
int main()
{
LL i,j;
scanf("%s",s1+);
scanf("%s",s2+);n=strlen(s2+);
jc[]=;
for(i=;i<=;i++) jc[i]=jc[i-]*i%md;
inv[]=;
for(i=;i<=;i++) inv[i]=(md-md/i)*inv[md%i]%md;
for(i=;i<=;i++) inv2[i]=poww(jc[i],md-);
for(i=;i<=n;i++) num[s1[i]-'a']++;
sum=jc[n];sumx=n;
for(i=;i<;i++) sum=sum*inv2[num[i]]%md;
for(i=;i<=n;i++)
{
for(j=;j<s2[i]-'a';j++)
if(num[j])
{
delx(j);
ans=(ans+sum)%md;
addx(j);
}
if(num[s2[i]-'a']) delx(s2[i]-'a');
else break;
}
for(i=;i<;i++) num[i]=;
for(i=;i<=n;i++) num[s1[i]-'a']++;
sum=jc[n];sumx=n;
for(i=;i<;i++) sum=sum*inv2[num[i]]%md;
for(i=;i<=n;i++)
{
for(j=;j<s1[i]-'a';j++)
if(num[j])
{
delx(j);
ans=(ans-sum+md)%md;
addx(j);
}
if(num[s1[i]-'a']) delx(s1[i]-'a');
else break;
}
ans=(ans-+md)%md;
printf("%lld",ans);
return ;
}
原来的代码（被卡常、超时的版本）
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<map>
#define md 1000000007
using namespace std;
typedef long long LL;

// Binary exponentiation: returns a^b mod md.
// Callers in this program pass a factorial already reduced mod md and b = md-2
// (Fermat inverse); presumably 0 <= a < md and b >= 0 — confirm for new callers.
LL poww(LL a,LL b)
{
    LL base=a,ans=1;
    while(b)
    {
        if(b&1) ans=(ans*base)%md;
        b>>=1;
        base=(base*base)%md;
    }
    return ans;
}

// Lookup tables (filled in main) and the running state shared by addx/delx/main.
LL inv[1000100],inv2[1000100],jc[1000100],sum,sumx,num[30],ans,n;
char s1[1000100],s2[1000100];
// Adds one occurrence of letter x to the multiset, keeping
// sum == sumx! / prod_c(num[c]!) (mod md) up to date incrementally.
void addx(LL x)
{
    LL &cnt=num[x];
    sum=sum*jc[cnt]%md;      // cancel the old 1/num[x]! factor
    ++cnt;
    ++sumx;
    sum=sum*sumx%md;         // extend (sumx-1)! to sumx!
    sum=sum*inv2[cnt]%md;    // apply the new 1/num[x]! factor
}
// Removes one occurrence of letter x from the multiset, keeping
// sum == sumx! / prod_c(num[c]!) (mod md) up to date incrementally.
// main() only calls this when num[x] > 0.
void delx(LL x)
{
    LL &cnt=num[x];
    sum=sum*jc[cnt]%md;      // cancel the old 1/num[x]! factor
    sum=sum*inv[sumx]%md;    // shrink sumx! to (sumx-1)!
    --sumx;
    --cnt;
    sum=sum*inv2[cnt]%md;    // apply the new 1/num[x]! factor
}
// Answer = #(permutations of s1 that are < s2) - #(permutations of s1 that are < s1) - 1,
// i.e. the permutations c of s1 with s1 < c < s2 (the -1 discards c == s1 itself).
// NOTE: this variant computes every inverse factorial with a separate poww call
// and tries each smaller letter through a delx/addx round trip.
int main()
{
    LL i,j;
    scanf("%s",s1+1);
    scanf("%s",s2+1);n=strlen(s2+1);
    // jc[i] = i!, inv[i] = 1/i, inv2[i] = 1/i!   (all mod md)
    jc[0]=1;
    for(i=1;i<=1000000;i++) jc[i]=jc[i-1]*i%md;
    inv[1]=1;
    for(i=2;i<=1000000;i++) inv[i]=(md-md/i)*inv[md%i]%md;
    for(i=0;i<=1000000;i++) inv2[i]=poww(jc[i],md-2);
    // Pass 1: add the permutations of s1 that are smaller than s2.
    for(i=1;i<=n;i++) num[s1[i]-'a']++;
    sum=jc[n];sumx=n;
    for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md;
    for(i=1;i<=n;i++)
    {
        for(j=0;j<s2[i]-'a';j++)
            if(num[j])
            {
                delx(j);            // tentatively place letter j at position i
                ans=(ans+sum)%md;   // sum = arrangements of the remaining letters
                addx(j);            // undo
            }
        if(num[s2[i]-'a']) delx(s2[i]-'a');
        else break;                 // s2's prefix cannot be matched any further
    }
    // Pass 2: subtract the permutations of s1 that are smaller than s1.
    for(i=0;i<26;i++) num[i]=0;
    for(i=1;i<=n;i++) num[s1[i]-'a']++;
    sum=jc[n];sumx=n;
    for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md;
    for(i=1;i<=n;i++)
    {
        for(j=0;j<s1[i]-'a';j++)
            if(num[j])
            {
                delx(j);
                ans=(ans-sum+md)%md;
                addx(j);
            }
        if(num[s1[i]-'a']) delx(s1[i]-'a');
        else break;
    }
    ans=(ans-1+md)%md;              // exclude s1 itself
    printf("%lld",ans);
    return 0;
}