我正在尝试为 Prim 算法编写一个并行版本,但不太清楚如何用 Spark GraphX 来实现。我一直在寻找相关资料,但在 GraphX 中实现最短路径这类图算法的示例并不多。我认为需要用分治法将图拆分为若干子图,分别求出各子图的最小生成树(MST),再将它们合并。
GraphX 资源:
http://ampcamp.berkeley.edu/big-data-mini-course/graph-analytics-with-graphx.html#the-property-graph
并行 Prim 算法资源:
https://www8.cs.umu.se/kurser/5DV050/VT10/handouts/F10.pdf
代码:
import org.apache.spark._
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.graphx.util._
/** Builds a GraphX graph from a simple edge-list file and prints some of its triplets.
  *
  * Input format (comma separated; whitespace around fields is ignored):
  *   - first line: `numNodes,numEdges`
  *   - each following non-blank line: `srcId,dstId,weight`
  *
  * Vertices are created with ids `1..numNodes` and labels `v1..vN`.
  */
object ParallelPrims {
  // Turn off Spark ("org") and Akka logging so the printed triplets are readable.
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  /** Parses one `srcId,dstId,weight` line into a GraphX edge.
    *
    * @throws IllegalArgumentException if the line does not have exactly three fields
    * @throws NumberFormatException if a field is not a valid number
    */
  private def parseEdge(line: String): Edge[Int] =
    line.split(",").map(_.trim) match {
      case Array(src, dst, weight) => Edge(src.toLong, dst.toLong, weight.toInt)
      case _ => throw new IllegalArgumentException(s"Malformed edge line: '$line'")
    }

  /** Entry point.
    *
    * @param args optional; `args(0)` overrides the input path (default `NodeData.txt`)
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Parallel Prims").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      val logFile = args.headOption.getOrElse("NodeData.txt")
      val lines = sc.textFile(logFile, 2).cache()

      // Header line: number of nodes and number of edges.
      val header = lines.first().split(",").map(_.trim)
      val numNodes = header(0).toInt
      val numEdges = header(1).toInt

      // Vertex ids 1..numNodes, labelled "v<id>".
      val vertexRDD: RDD[(VertexId, String)] =
        sc.parallelize((1L to numNodes.toLong).map(id => (id, s"v$id")))

      // Identify the header by position (line 0) rather than by content: filtering
      // rows whose first field equals the header's would also drop every edge whose
      // source id happens to equal numNodes. Blank lines are skipped.
      val edgeRDD: RDD[Edge[Int]] = lines
        .zipWithIndex()
        .filter { case (line, idx) => idx > 0 && line.trim.nonEmpty }
        .map { case (line, _) => parseEdge(line) }
        .cache()

      val edgesRead = edgeRDD.count()
      if (edgesRead != numEdges)
        Console.err.println(
          s"Warning: header declares $numEdges edges but $edgesRead were read from $logFile")

      val graph: Graph[String, Int] = Graph(vertexRDD, edgeRDD)
      graph.triplets.take(6).foreach(println)
    } finally {
      // Release the Spark context even if parsing fails.
      sc.stop()
    }
  }
}
NodeData.txt:
4,6
1,2,5
1,3,8
1,4,4
2,3,8
2,4,7
3,4,1
输出:
((1,v1),(2,v2),5)
((1,v1),(3,v3),8)
((1,v1),(4,v4),4)
((2,v2),(3,v3),8)
((2,v2),(4,v4),7)
((3,v3),(4,v4),1)
最佳答案
这是我实现的 Prim 算法版本。
// Prim's algorithm on a GraphX graph (assumes a SparkContext `sc` is in scope).
// Each iteration runs one distributed scan over all triplets to find the cheapest
// edge with exactly one endpoint in the tree, so total work is O(|V| * |E|).
// Only the tree-vertex set (at most |V| ids) and the MST edges live on the driver,
// which keeps RDD lineage from growing across iterations.
val graph: Graph[String, Int] = ... // build the graph (elided in this snippet)

// Orders candidate edges by weight so the cheapest one can be selected.
val byWeight: Ordering[EdgeTriplet[String, Int]] =
  Ordering.by[EdgeTriplet[String, Int], Int](_.attr)

// Start the tree from an arbitrary vertex.
val root: VertexId = graph.pickRandomVertex()
val vcount = graph.vertices.count()

// Vertices already in the tree.
var treeVertices: Set[VertexId] = Set(root)
// MST edges found so far, in the order they were added.
var mstEdges = Vector.empty[EdgeTriplet[String, Int]]
var disconnected = false

while (!disconnected && treeVertices.size < vcount) {
  val inTree = sc.broadcast(treeVertices)
  // Crossing edges have exactly one endpoint in the tree; keep the cheapest.
  val cheapest = graph.triplets
    .filter(t => inTree.value.contains(t.srcId) != inTree.value.contains(t.dstId))
    .takeOrdered(1)(byWeight)
    .headOption
  inTree.destroy()

  cheapest match {
    case Some(edge) =>
      mstEdges :+= edge
      // Add whichever endpoint is not yet in the tree.
      treeVertices += (if (treeVertices.contains(edge.srcId)) edge.dstId else edge.srcId)
    case None =>
      // No crossing edge is left but vertices remain: the graph is disconnected,
      // so only the component containing `root` can be spanned.
      disconnected = true
  }
}

if (disconnected)
  println(s"Graph is disconnected: spanning tree covers ${treeVertices.size} of $vcount vertices")

// final minimum spanning tree
val MST: RDD[EdgeTriplet[String, Int]] = sc.parallelize(mstEdges)
MST.collect().foreach(p => println(s"${p.srcId} ${p.attr} ${p.dstId}"))