description

小w这学期选了门图论课,他在学习点着色的知识。他现在得到了一张无向图,并希望在这张图上使用最多n种颜色给每个节点染色,使得任意一条边关联的两个节点颜色不同。

小w获得一张n个节点m条边的基图,并得到了一份神秘代码。他会根据这份代码的内容构建完整的无向图。

while(1){

int modify_tag=0;

for(int x=1;x<=n;x++)

for(int y=x+1;y<=n;y++)

for(int z=y+1;z<=n;z++)

if(edge(x,y)∈G && edge(x,z)∈G){

add edge(y,z) to G

modify_tag=1;

}

if(modify_tag==0) break;

}

即对于图上的任意三元组x<y<z,若(x,y),(x,z)在图中则在图上加上一条(y,z)的边,直至无法加边为止。

小w想要知道使用n种颜色给这张基图生成的完整无向图的染色方案数。小w太菜了,他无力解决这个难题,于是只好把它交给了你。


analysis

  • 首先有一个结论,\(ans=\prod n-g[i]\),\(g[i]\)表示与\(i\)相连、编号比\(i\)大的节点数量

  • 如果从大到小染色染到第\(i\)位,\(g[i]\)已经染过色了且\(i\)点和这\(g[i]\)个点构成完全图

  • 那么这\(g[i]\)个点颜色都不相同,\(i\)位染色的方案数就是\(n-g[i]\),以此类推

  • 暴力做法是\(O(n^2)\)把这个图建出来,搞出每一个\(g[i]\),全部乘起来

  • 其实可以考虑维护\(n\)棵线段树,存储第\(i\)位向后的连边情况

  • 编号从小到大合并,假设合并到第\(k\)位,区间查询比\(k\)大有多少已经连边的点,乘进答案

  • 然后再找出比\(k\)大且最小的存在的编号,把\(k\)的线段树合并到该编号的线段树就可以了

  • 也可以用\(set\)维护连点的集合,合并两集合就启发式合并,但这个我还不是很懂


code

线段树合并

#pragma GCC optimize("O3")
#pragma G++ optimize("O3")
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define MAXN 100005
#define MAX MAXN*50
#define ha 998244353
#define ll long long
#define reg register ll
#define fo(i,a,b) for (reg i=a;i<=b;++i)
#define fd(i,a,b) for (reg i=a;i>=b;--i) using namespace std; ll root[MAXN],tr[MAX],mn[MAX],lson[MAX],rson[MAX];
ll n,m,tot,ans=1; inline ll read()
{
ll x=0,f=1;char ch=getchar();
while (ch<'0' || '9'<ch){if (ch=='-')f=-1;ch=getchar();}
while ('0'<=ch && ch<='9')x=x*10+ch-'0',ch=getchar();
return x*f;
}
inline ll max(ll x,ll y){return x>y?x:y;}
inline ll min(ll x,ll y){return x<y?x:y;}
inline ll newnode()
{
++tot,mn[tot]=ha;
return tot;
}
inline void modify(ll &t,ll l,ll r,ll x)
{
if (!t)t=newnode();
++tr[t],mn[t]=min(mn[t],x);
if (l==r)return;
ll mid=(l+r)>>1;
if (x<=mid)modify(lson[t],l,mid,x);
else modify(rson[t],mid+1,r,x);
}
inline ll query_sum(ll t,ll l,ll r,ll x,ll y)
{
if (!t)return 0;
if (l==x && y==r)return tr[t];
ll mid=(l+r)>>1;
if (y<=mid)return query_sum(lson[t],l,mid,x,y);
else if (x>mid)return query_sum(rson[t],mid+1,r,x,y);
else return query_sum(lson[t],l,mid,x,mid)+query_sum(rson[t],mid+1,r,mid+1,y);
}
inline ll query_min(ll t,ll l,ll r,ll x,ll y)
{
if (!t)return ha;
if (l==x && y==r)return mn[t];
ll mid=(l+r)>>1;
if (y<=mid)return query_min(lson[t],l,mid,x,y);
else if (x>mid)return query_min(rson[t],mid+1,r,x,y);
else return min(query_min(lson[t],l,mid,x,mid),query_min(rson[t],mid+1,r,mid+1,y));
}
inline void merge(ll x,ll y,ll l,ll r)
{
if (l==r)return;ll mid=(l+r)>>1;
if (lson[x] && lson[y])merge(lson[x],lson[y],l,mid);
else if (lson[y])lson[x]=lson[y];
if (rson[x] && rson[y])merge(rson[x],rson[y],mid+1,r);
else if (rson[y])rson[x]=rson[y];
tr[x]=tr[lson[x]]+tr[rson[x]];
mn[x]=min(mn[lson[x]],mn[rson[x]]);
}
int main()
{
//freopen("T3.in","r",stdin);
freopen("graph.in","r",stdin);
freopen("graph.out","w",stdout);
n=read(),m=read();
fo(i,1,n)root[i]=newnode();
fo(i,1,m)
{
ll x=read(),y=read();
if (x>y)swap(x,y);
modify(root[x],1,n,y);
}
mn[0]=ha;
fo(i,1,n-1)
{
(ans*=n-query_sum(root[i],1,n,i+1,n))%=ha;
ll tmp=query_min(root[i],1,n,i+1,n);
if (tmp<=n)merge(root[tmp],root[i],1,n);
}
printf("%lld\n",ans*n%ha);
return 0;
}

set+启发式合并

#pragma GCC optimize("O3")
#pragma G++ optimize("O3")
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<set>
#define MAXN 100005
#define ha 998244353
#define ll long long
#define reg register ll
#define fo(i,a,b) for (reg i=a;i<=b;++i)
#define fd(i,a,b) for (reg i=a;i>=b;--i) using namespace std; set<ll>a[MAXN];
ll fa[MAXN];
ll n,m,ans=1,x,y; inline ll read()
{
ll x=0,f=1;char ch=getchar();
while (ch<'0' || '9'<ch){if (ch=='-')f=-1;ch=getchar();}
while ('0'<=ch && ch<='9')x=x*10+ch-'0',ch=getchar();
return x*f;
}
int main()
{
//freopen("T3.in","r",stdin);
freopen("graph.in","r",stdin);
freopen("graph.out","w",stdout);
n=read(),m=read();
fo(i,1,n)fa[i]=i;
fo(i,1,m)x=read(),y=read(),a[min(x,y)].insert(max(x,y));
/*
fo(i,1,n)if (a[i].size())
{
printf("%lld\n",i);
for (set<ll>::iterator j=a[i].begin();j!=a[i].end();++j)printf("%lld ",*j);
printf("\n\n");
}
*/
fo(i,1,n)
{
//printf("!!!%lld:%lld\n",i,fa[i]);
//printf("%lld\n",*a[fa[i]].begin());
if (*a[fa[i]].begin()==i)a[fa[i]].erase(a[fa[i]].begin())/*,printf("#@!!#@!!$\n")*/;
(ans*=n-a[fa[i]].size())%=ha;
//printf("%lld %lld\n",i,n-a[fa[i]].size());
if (a[fa[i]].size()!=1)
{
ll &x=fa[i],&y=fa[*a[fa[i]].begin()];
if (a[x].size()>a[y].size())swap(x,y);
for (set<ll>::iterator tmp=a[x].begin();tmp!=a[x].end();++tmp)a[y].insert(*tmp);
a[x].clear();
}
//printf("\n");
//printf("!!!%lld:%lld\n",i,fa[i]);
}
//fo(i,1,n)printf("%lld:%lld ",i,fa[i]);
printf("%lld\n",ans);
return 0;
}
05-18 18:56