我想使用Vector
在Haskell中编写Floyd-Warshall所有对最短路径算法的有效实现,以期获得良好的性能。
该实现非常简单,但是不使用3维| V |×| V |×| V |由于我们只读取了先前的k
值,因此使用了二维向量。
因此,该算法实际上只是传递2D向量并生成新2D向量的一系列步骤。最终的2D向量包含所有节点(i,j)之间的最短路径。
我的直觉告诉我,确保在每个步骤之前都已对先前的2D向量进行了评估非常重要,因此我在BangPatterns
函数的prev
参数和严格的fw
上使用了foldl'
:
{-# Language BangPatterns #-}
import Control.DeepSeq
import Control.Monad (forM_)
import Data.List (foldl')
import qualified Data.Map.Strict as M
import Data.Vector (Vector, (!), (//))
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as V hiding (length, replicate, take)
type Graph = Vector (M.Map Int Double)
type TwoDVector = Vector (Vector Double)
infinity :: Double
infinity = 1/0
-- calculate shortest path between all pairs in the given graph, if there are
-- negative cycles, return Nothing
allPairsShortestPaths :: Graph -> Int -> Maybe TwoDVector
allPairsShortestPaths g v =
let initial = fw g v V.empty 0
results = foldl' (fw g v) initial [1..v]
in if negCycle results
then Nothing
else Just results
where -- check for negative elements along the diagonal
negCycle a = any not $ map (\i -> a ! i ! i >= 0) [0..(V.length a-1)]
-- one step of the Floyd-Warshall algorithm
fw :: Graph -> Int -> TwoDVector -> Int -> TwoDVector
fw g v !prev k = V.create $ do -- ← bang
curr <- V.new v
forM_ [0..(v-1)] $ \i ->
V.write curr i $ V.create $ do
ivec <- V.new v
forM_ [0..(v-1)] $ \j -> do
let d = distance g prev i j k
V.write ivec j d
return ivec
return curr
distance :: Graph -> TwoDVector -> Int -> Int -> Int -> Double
distance g _ i j 0 -- base case; 0 if same vertex, edge weight if neighbours
| i == j = 0.0
| otherwise = M.findWithDefault infinity j (g ! i)
distance _ a i j k = let c1 = a ! i ! j
c2 = (a ! i ! (k-1))+(a ! (k-1) ! j)
in min c1 c2
但是,当使用带有47978条边的1000节点图运行该程序时,情况看起来一点也不好。内存使用率很高,程序花费的时间太长,无法运行。该程序是使用
ghc -O2
编译的。我重建了用于分析的程序,并将迭代次数限制为50:
results = foldl' (fw g v) initial [1..50]
然后,我使用
+RTS -p -hc
和+RTS -p -hd
运行该程序:这很有趣……但是我想这表明它正在堆积大量的暴徒。不好。
好的,所以在黑暗中拍摄了几张照片之后,我在
deepseq
中添加了fw
以确保prev
确实得到了评估:let d = prev `deepseq` distance g prev i j k
现在情况看起来更好了,实际上我可以在不断使用内存的情况下运行该程序以使其完成。显然,
prev
参数上的爆炸是不够的。为了与之前的图表进行比较,以下是添加
deepseq
后进行50次迭代的内存使用情况:好的,所以情况会好一些,但是我仍然有一些问题:
deepseq
有点难看吗? Vector
的用法在这里是惯用的/正确的吗?我正在为每次迭代构建一个全新的向量,并希望垃圾收集器将删除旧的Vector
。 供引用,这是
graph.txt
:http://sebsauvage.net/paste/?45147f7caf8c5f29#7tiCiPovPHWRm1XNvrSb/zNl3ujF3xB3yehrxhEdVWw=这是
main
:main = do
ls <- fmap lines $ readFile "graph.txt"
let numVerts = head . map read . words . head $ ls
let edges = map (map read . words) (tail ls)
let g = V.create $ do
g' <- V.new numVerts
forM_ [0..(numVerts-1)] (\idx -> V.write g' idx M.empty)
forM_ edges $ \[f,t,w] -> do
-- subtract one from vertex IDs so we can index directly
curr <- V.read g' (f-1)
V.write g' (f-1) $ M.insert (t-1) (fromIntegral w) curr
return g'
let a = allPairsShortestPaths g numVerts
case a of
Nothing -> putStrLn "Negative cycle detected."
Just a' -> do
putStrLn $ "The shortest, shortest path has length "
++ show ((V.minimum . V.map V.minimum) a')
最佳答案
首先,一些常规代码清除:
在fw
函数中,您显式分配和填充可变向量。但是,为此目的有一个预制函数generate
。 fw
因此可以重写为
V.generate v (\i -> V.generate v (\j -> distance g prev i j k))
类似地,图生成代码可以替换为
replicate
和accum
:let parsedEdges = map (\[f,t,w] -> (f - 1, (t - 1, fromIntegral w))) edges
let g = V.accum (flip (uncurry M.insert)) (V.replicate numVerts M.empty) parsedEdges
请注意,这完全消除了所有对突变的需求,而不会损失任何性能。
现在,到实际的问题:
deepseq
非常有用,但只能像这样解决空间泄漏。根本的问题不是在生成结果后就需要强制执行结果。相反,使用deepseq
意味着您应该首先更严格地构建该结构。实际上,如果在矢量创建代码中添加爆炸样式,如下所示:let !d = distance g prev i j k
然后,无需使用
deepseq
即可解决此问题。请注意,这不适用于generate
代码,因为出于某些原因(我可能为此创建功能请求),vector
并未为盒装矢量提供严格的功能。但是,当我回答严格的问题3的未装箱矢量时,两种方法都可以在没有严格注释的情况下使用。 Map Int
替换为IntMap
。因为这并不是功能的真正慢点,所以没有太大关系,但是对于繁重的工作量,IntMap
可以更快。 repa
。这具有自动并行化代码的巨大优势。请注意,由于repa
展平了其数组,并且显然不能正确地消除填充得很好的划分(可以与嵌套循环一起使用,但我认为它使用了一个循环和一个划分),因此它具有相同的性能正如我上面提到的那样,运行时间从1.3秒增加到1.8秒。但是,如果启用并行性并使用多核计算机,则会开始看到一些好处。不幸的是,您当前的测试用例太小了,看不到太多好处,因此,在我的6核计算机上,我看到它回落到1.2秒。如果我将大小恢复为[1..v]
而不是[1..50]
,则并行度将其从32秒增加到13秒。大概,如果您为该程序提供更大的输入,则可能会看到更多的好处。如果您有兴趣,我已经发布了
repa
-ified版本here。 -fllvm
。在我的计算机上使用repa
进行测试,在没有并行性的情况下我得到了14.7秒,这几乎与没有-fllvm
并具有并行性时一样好。通常,LLVM可以很好地处理基于数组的代码。