Directed Roads

扫码查看

Directed Roads

题目链接:http://codeforces.com/contest/711/problem/D

dfs

刚开始的时候想歪了,以为同一个连通区域会有多个环,实际上每个点的出度为1,也就是说每个连通区域最多就只有一个环。

那么每一个连通区域的方法数就 = (2^环内边数-2)*(2^环外边数) [因为环内有两种情况形成圈,不可取],

总方法数 = 不同连通区域的方法数的乘积;

于是我把整个有向图先存储成无向图,用dfs判断该连通区域有没有环,再cls掉环外的边,之后再继续dfs...

代码如下:

 #include<cstdio>
#include<cstring>
#include<vector>
#include<iostream>
#define N 200005
#define M (int)(1e9+7)
#define special 9
using namespace std;
typedef long long LL;
struct nod{
LL edge;
LL to;
nod(LL a,LL b){
edge=a;
to=b;
}
};
vector<nod>node[N];
LL n;
LL vis[N];
LL dfs(LL index,LL num){
for(LL i=;i<node[index].size();++i){
LL e=node[index][i].edge,to=node[index][i].to;
if(vis[e]==-){
vis[index]=to;
LL temp=dfs(e,num+);
if(temp)return temp;
vis[index]=-;
}else if(vis[e]==to){
vis[index]=to;
vis[e]=special;
return num;
}
}
return ;
}
LL cls(LL index,LL num){
for(LL i=;i<node[index].size();++i){
vis[index]=-;
LL e=node[index][i].edge;
if(vis[e]==special)return num;
if(vis[e]!=-)
return cls(e,num+);
}
return ;
}
LL pow(LL a,LL b){
LL base=a,temp=;
while(b){
if(b&)temp=(temp*base)%M;
base=(base*base)%M;
b>>=;
}
return temp;
}
LL mod(LL a,LL b){
LL base=a,temp=;
while(b){
if(b&)temp=(temp+base)%M;
base=(base+base)%M;
b>>=;
}
return temp;
}
int main(void){
memset(vis,-,sizeof(vis));
LL res=;
scanf("%I64d",&n);
for(LL i=;i<=n;++i){
LL vertice;
//cin>>vertice;
scanf("%I64d",&vertice);
node[i].push_back(nod(vertice,));
node[vertice].push_back(nod(i,));
}
for(LL i=;i<=n;++i){
if(vis[i]==-){
LL cyc_temp=dfs(i,);
if(vis[i]!=special&&vis[i]!=-){
LL un_temp=cls(i,);
cyc_temp-=un_temp;
}
if(res==&&cyc_temp)res=pow(,cyc_temp)-;
else if(cyc_temp)res=mod(res,(pow(,cyc_temp)-));
}
}
LL un_sum=;
for(LL i=;i<=n;++i)
if(vis[i]==-)un_sum++;
if(res)res=mod(res,pow(,un_sum));
else res=pow(,un_sum);
//cout<<res<<endl;
printf("%I64d\n",res);
}

然而这样会T(想象一种坏的情况:只有一个连通区域,且环在末尾,这样差不多是O(n^2)的复杂度)

仔细想过后,其实不需要将有向图转化为无向图,因为每个点的出度为1,如果有环,那么有向图也必然成环,改进后复杂度就成了O(n)

代码如下:

 #include<cstdio>
#include<cstring>
#include<iostream>
#define N 200005
#define M (int)(1e9+7)
using namespace std;
typedef long long LL;
LL n,sum=;
LL a[N];
LL vis[N];
LL pow(LL a,LL b){
LL base=a,temp=;
while(b){
if(b&)temp=(temp*base)%M;
base=(base*base)%M;
b>>=;
}
return temp;
}
int main(void){
cin>>n;
LL res=n;
for(LL i=;i<=n;++i)cin>>a[i];
for(LL i=;i<=n;++i){
if(!vis[i]){
LL index=i;
while(){
vis[index]=i;
index=a[index];
if(vis[index])break;
}
if(vis[index]!=i)continue;
LL node=,temp=index;
while(){
node++;
temp=a[temp];
if(temp==index)break;
}
res-=node;
sum=(sum*(pow(,node)-))%M;
}
}
sum=(sum*pow(,res))%M;
cout<<sum<<endl;
}
05-11 15:24
查看更多