问题:
求多项式\(g(x)\),满足\(f(x)*g(x)\equiv 1\pmod{x^n}\)

Q1:这是什么意思?

A1:即\(f(x)*g(x)\)最后只有常数项系数为\(1\),其余系数都为\(0\)

Q2:为什么要求模\(x^n\)

A2:这是为了将\(n\)次方以上的项全都除去,否则\(g(x)\)会有无穷多项。

用倍增的方法可以\(O(n\log n)\)求出。

假设我们求出了\(h(x)\),满足\(f(x)*h(x)\equiv 1\pmod{x^{\lceil\frac{n}{2}\rceil}}\)

我们要求\(g(x)\),满足\(f(x)*g(x)\equiv 1\pmod{x^n}\)

进行推导:
\(f(x)*h(x)\equiv 1\pmod{x^{\lceil\frac{n}{2}\rceil}}\)
\(f(x)*g(x)\equiv 1\pmod{x^n}\)
\(f(x)*g(x)\equiv f(x)*h(x)\pmod{x^{\lceil\frac{n}{2}\rceil}}\)(为什么正确需要考虑模的意义:忽略更高次项)
\(f(x)*(g(x)-h(x))\equiv 0\pmod{x^{\lceil\frac{n}{2}\rceil}}\)
\(g(x)-h(x)\equiv 0\pmod{x^{\lceil\frac{n}{2}\rceil}}\)
\((g(x)-h(x))^2\equiv 0\pmod{x^{\lceil\frac{n}{2}\rceil}}\)
\(g(x)^2-2*g(x)*h(x)+h(x)^2\equiv 0\pmod{x^{\lceil\frac{n}{2}\rceil}}\)
此时有个结论:
\(g(x)^2-2*g(x)*h(x)+h(x)^2\equiv 0\pmod{x^n}\)

因为平方前\((g(x)-h(x))\)\([0,\lceil\frac{n}{2}\rceil-1]\)的系数都为\(0\),考虑\(i\in[\lceil\frac{n}{2}\rceil,n-1]\)\(C_i=\sum\limits_{j<=i}A_j*B_{i-j}\),其中\(i,i-j\)中必定有一个属于\([0,\lceil\frac{n}{2}\rceil-1]\),因此\(C_i=0\)

两边同乘\(f(x)\)
\(g(x)-2*h(x)+f(x)h(x)^2\equiv 0\pmod{x^n}\)
\(g(x)\equiv2*h(x)-f(x)*h(x)^2\pmod{x^n}\)
\(g(x)\equiv h(x)*(2-f(x)*h(x))\pmod{x^n}\)

于是可以递归分治求了。

模板题

code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=4*1e5+10;
const ll mod=998244353;
const ll g=3;
const ll invg=332748118;
int n,lim,len;
int pos[maxn];
ll a[maxn],b[maxn],c[maxn];
inline ll read()
{
    char c=getchar();ll res=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9')res=res*10+c-'0',c=getchar();
    return res*f;
}
inline ll power(ll x,ll k)
{
    ll res=1;
    while(k)
    {
        if(k&1)res=res*x%mod;
        x=x*x%mod;k>>=1;
    }
    return res;
}
void NTT(ll* a,int op)
{
    for(int i=0;i<lim;i++)if(i<pos[i])swap(a[i],a[pos[i]]);
    for(int l=1;l<lim;l<<=1)
    {
        ll wn=power(op==1?g:invg,(mod-1)/(l<<1));
        for(int i=0;i<lim;i+=l<<1)
        {
            ll w=1;
            for(int j=0;j<l;j++,w=w*wn%mod)
            {
                ll x=a[i+j],y=w*a[i+l+j];
                a[i+j]=(x+y)%mod;a[i+l+j]=((x-y)%mod+mod)%mod;
            }
        }
    }
    if(op==1)return;
    ll inv=power(lim,mod-2);
    for(int i=0;i<lim;i++)a[i]=a[i]*inv%mod;
}
void solve(ll* a,ll* b,int n)
{
    if(n==1){b[0]=power(a[0],mod-2);return;}
    solve(a,b,(n+1)>>1);
    lim=1,len=0;
    while(lim<(n<<1))lim<<=1,len++;
    for(int i=0;i<lim;i++)pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
    for(int i=0;i<n;i++)c[i]=a[i];
    for(int i=n;i<lim;i++)c[i]=0;
    NTT(c,1);NTT(b,1);
    for(int i=0;i<lim;i++)b[i]=((2ll-b[i]*c[i]%mod)%mod+mod)%mod*b[i]%mod;
    NTT(b,-1);
    for(int i=n;i<lim;i++)b[i]=0;
}
int main()
{
    n=read();
    for(int i=0;i<n;i++)a[i]=(read()%mod+mod)%mod;
    solve(a,b,n);
    for(int i=0;i<n;i++)printf("%lld ",(b[i]%mod+mod)%mod);
    return 0;
}
12-26 12:47