prim算法是用来求最小生成树的算法,要两个数组来维护需要的信息分别是vis数组和dis数组
vis[i]用来表示结点i是否已经加入最小生成树,0为还没加入最小生成树,1为已经加入最小生成树
dis[i]用来表示连接结点i的边的最小权值是多少,初始化为一个很大的数字0x3f3f3f3f
朴素版 时间复杂度 n²
prim算法思路如下:
1.随机选取一个起始点i,并让dis[i] = 0
2.选取一个dis最小,且没有加入最小生成树的点u(第一次选取的是起始点)
3.将选取到的点u加入最小生成树里,sum += dis[u],并让vis[u] = 1,表示该点已经加入最小生成树
4.对于和u点连接的每个点v,判断是否已经加入最小生成树里,如果是就跳过该点,如果不是就判断dis[v]是否大于u和v之间的权值,如果是说明找到一个把v点加入最小生成树的更短的边,更新dis[v],让dis[v] = u和v之间的权值
5.重复2到4的步骤,直到所有的点都加入最小生成树,如果在所有的点还没被选完,且发现u = 0
说明该图不是连通图,不存在最小生成树
优先队列优化版 时间复杂度nlogn
优化的是第2步,第2步朴素的做法是把所有的点都遍历一遍来选择一个dis最小的点,可以使用一个优先队列来储存dis,每次就取优先队列的队头就可以了,时间复杂度由n变为logn
来一个例子
1.选择1号点为起始点,dis[1] = 0
2.把1号点加入最小生成树,sum += dis[1]
3.对于和1号点相连的 2号 3号 6号点
由于vis[2] = 0,dis[2] > 3 所以更新dis[2] dis[2] = 3
由于vis[3] = 0,dis[3] > 8 所以更新dis[3] dis[3] = 8
由于vis[6] = 0,dis[6] > 10 所以更新dis[6] dis[6] = 10
4.选择一个当前dis最小的一个结点,即2号点,把2号点加入最小生成树里,vis[2] = 1, sum += dis[2]
5.对于和2号点相连的 1号 3号
由于vis[1] = 1,该点已经加入最小生成树,所以跳过该点
由于vis[3] = 0,dis[3] < 14 所以不用更新dis[3]
6.继续选择dis最小的点3号加入最小生成树里
重复选择直到所有点都选完,或者选点的时候已经没有点可以选择了但所有的点还没被选完说明该图是不连通的
代码
洛谷P3366 【模板】最小生成树
使用vector来存图
朴素版
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 50010;
const int MAX = 0x3f3f3f3f;
int n,m,vis[MAXN],dis[MAXN];
struct edge{
int v,w;
bool operator < (const edge & a) const{
return w > a.w;
}
};
vector<edge> v[MAXN];
void prim(){
ll sum = 0; // 最小生成树的和
//int cnt = 0; // 加入生成树的点的个数
memset(dis, MAX, sizeof(dis)); //初始化为无穷大
dis[1] = 0; // 随机选取一个点
for(int k = 0; k < n; k++){
int min = MAX;
int u = 0;
for(int i = 1; i <= n; i++){ //选取一个dis最小点
if(vis[i]) continue;
if(dis[i] < min){
min = dis[i];
u = i;
}
}
if(u == 0){ //该图不是一个连通图
cout << "orz";
return;
}
sum += dis[u]; //加入到最小生成树里
vis[u] = 1; //标记
for(edge x : v[u]){
if(!vis[x.v] && dis[x.v] > x.w){
dis[x.v] = x.w;
}
}
}
cout << sum;
return;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin >> n >> m;
int t1,t2,w;
for(int i = 0; i < m; i++){
cin >> t1 >> t2 >> w;
v[t1].push_back({t2,w});
v[t2].push_back({t1,w});
}
prim();
}
优先队列版本
链式前向星存图
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 50010;const int MAX = 0x3f3f3f3f;
int n,m,cnt,pre[MAXN],vis[MAXN],dis[MAXN];
struct edge{
int v,w;
bool operator < (const edge & a) const{
return w > a.w;
}
};
struct info{ //链式前向星
int w,to,nt;
}node[500000];
void add(int x,int y,int z)
{
cnt++;
node[cnt].to = y;
node[cnt].w = z;
node[cnt].nt = pre[x];
pre[x] = cnt;
}
void prim()
{
ll sum = 0;
int flag = 0;
memset(dis,MAX,sizeof(dis));
dis[1] = 0;
priority_queue<edge> pq;
pq.push({1,0});
while(!pq.empty()){
int u = pq.top().v;
pq.pop();
if(vis[u]) continue;
flag++;sum += dis[u];
vis[u] = 1;
for(int i = pre[u]; i ; i = node[i].nt){
if(!vis[node[i].to] && dis[node[i].to] > node[i].w){
dis[node[i].to] = node[i].w;
pq.push({node[i].to,node[i].w});
}
}
}
if(flag < n) cout << "orz";
else cout << sum;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
int t1,t2,w;
cin >> n >> m;
for(int i = 1;i <= m; i++){
cin >> t1 >> t2 >> w;
if(t1 == t2) continue;
add(t1,t2,w);
add(t2,t1,w);
}
prim();
return 0;
}