我试图编写Prim算法的并行版本,但不太清楚如何用Spark GraphX来实现。我一直在寻找相关资料,但在GraphX中实现最短路径等图算法的示例并不多。我认为需要使用分治法将图拆分为若干子图,分别求出各子图的MST,然后再将这些MST合并。

GraphX资源:
http://ampcamp.berkeley.edu/big-data-mini-course/graph-analytics-with-graphx.html#the-property-graph

并行Prim算法资源:
https://www8.cs.umu.se/kurser/5DV050/VT10/handouts/F10.pdf

代码:

import org.apache.spark._
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.graphx.util._

object ParallelPrims {
  // Silence Spark's and Akka's internal logging so only this program's output is shown.
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  /** Loads a weighted graph from `NodeData.txt` and prints its first six triplets.
    *
    * Expected input: the first non-blank line is a header `numNodes,numEdges`; each
    * following line is an edge `srcId,dstId,weight`. Vertices are numbered
    * `1..numNodes` and named `v1`, `v2`, ...; only the first `numEdges` edge lines
    * are used.
    *
    * @param args unused
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Parallel Prims").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      val logFile = "NodeData.txt"

      // Split every non-blank line into trimmed comma-separated columns.
      val rows: RDD[Array[String]] = sc.textFile(logFile, 2)
        .map(_.trim)
        .filter(_.nonEmpty)
        .map(line => line.split(",").map(_.trim))
        .cache()

      // Parse number of Nodes and Edges from the header (first line).
      val header = rows.first
      val numNodes = header(0).toInt
      val numEdges = header(1).toInt

      // Drop the header by position rather than by value: comparing the first
      // column with header(0) would also discard every edge whose source id
      // happens to equal numNodes. Parsing stays distributed instead of
      // collecting all rows to the driver.
      val edgeRDD: RDD[Edge[Int]] = rows.zipWithIndex.collect {
        case (cols, idx) if idx >= 1 && idx <= numEdges =>
          Edge(cols(0).toLong, cols(1).toLong, cols(2).toInt)
      }

      // Vertices 1..numNodes, each named "v<id>".
      val vertexRDD: RDD[(Long, String)] =
        sc.parallelize((1 to numNodes).map(id => (id.toLong, "v" + id)))

      // Creating graphx graph
      val graph: Graph[String, Int] = Graph(vertexRDD, edgeRDD)

      graph.triplets.take(6).foreach(println)
    } finally {
      // Release the local Spark context even if loading or parsing fails.
      sc.stop()
    }
  }
}

NodeData.txt
4,6
1,2,5
1,3,8
1,4,4
2,3,8
2,4,7
3,4,1

输出:
((1,v1),(2,v2),5)
((1,v1),(3,v3),8)
((1,v1),(4,v4),4)
((2,v2),(3,v3),8)
((2,v2),(4,v4),7)
((3,v3),(4,v4),1)

最佳答案

这是我实现的Prim算法版本。

// placeholder: supply the input graph here
var graph : Graph [String, Int] = ...

// Edges chosen for the minimum spanning tree so far; starts empty.
// The element type must match graph.triplets (EdgeTriplet[String, Int]):
// RDDs are invariant, so an RDD[EdgeTriplet[Int, Int]] cannot be unioned with them.
var MST = sc.parallelize(Array[EdgeTriplet[String, Int]]())

// pick a random vertex from the graph as the starting point of the tree
var Vt: RDD[VertexId] = sc.parallelize(Array(graph.pickRandomVertex))

// The keyed triplet RDDs do not change between iterations, so build and cache them once.
val bySrc = graph.triplets.map(triplet => (triplet.srcId, triplet)).cache()
val byDst = graph.triplets.map(triplet => (triplet.dstId, triplet)).cache()

// Lightest edge first; ties are broken by endpoint ids so the chosen edge is deterministic.
val byWeight: Ordering[EdgeTriplet[String, Int]] =
  Ordering.by((t: EdgeTriplet[String, Int]) => (t.attr, t.srcId, t.dstId))

// Number of vertices in Vt. Each step adds exactly one vertex that was not yet
// in Vt, so this is tracked locally instead of re-counting Vt on every pass.
var treeSize = 1L
val vcount = graph.vertices.count

// Set when no edge leaves Vt, i.e. the remaining vertices are unreachable.
var stuck = false

// grow the tree one vertex at a time until Vt holds every vertex
while (treeSize < vcount && !stuck) {

  // keyed copy of Vt so it can be inner-joined with the triplets
  val hVt = Vt.map(x => (x, x))

  // all triplets whose source vertex is in Vt
  val bySrcJoined = bySrc.join(hVt).map(_._2._1)

  // all triplets whose destination vertex is in Vt
  val byDstJoined = byDst.join(hVt).map(_._2._1)

  // edges with exactly one endpoint in Vt: the union of both sets minus the
  // edges whose source and destination are both already in Vt
  val candidates = bySrcJoined.union(byDstJoined).subtract(byDstJoined.intersection(bySrcJoined))

  if (candidates.isEmpty) {
    // No edge crosses the cut: the graph is disconnected, so MST only spans
    // the connected component of the start vertex.
    stuck = true
  } else {
    // find the triplet with the least weight (no full sort needed)
    val triplet = candidates.min()(byWeight)

    // add triplet to MST
    MST = MST.union(sc.parallelize(Array(triplet)))

    // add whichever endpoint is not yet in Vt; capture the id locally so the
    // closure does not serialize the enclosing scope
    val srcId = triplet.srcId
    val newVertex = if (Vt.filter(_ == srcId).isEmpty) srcId else triplet.dstId
    Vt = Vt.union(sc.parallelize(Array(newVertex)))
    treeSize += 1
  }
}

// final minimum spanning tree
MST.collect.foreach(p => println(p.srcId + " " + p.attr + " " + p.dstId))

07-23 14:59