我有一个程序,大部分时间用于计算RGB值(无符号8位Word8的3元组)之间的欧几里得距离。我需要一个快速,无分支的无符号int绝对差函数,这样

unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b

特别是,
unsigned_difference a b == unsigned_difference b a
我使用GHC 7.8中的新primops提出了以下内容:
-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference (I# a) (I# b) =
    I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))]

哪个ghc -O2 -S编译成
.Lc42U:
    movq 7(%rbx),%rax
    movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
    movq 8(%rbp),%rbx
    movq %rbx,%rcx
    subq %rax,%rcx
    cmpq %rax,%rbx
    setg %dl
    movzbl %dl,%edx
    imulq %rcx,%rdx
    movq %rax,%rcx
    subq %rbx,%rcx
    cmpq %rax,%rbx
    setl %al
    movzbl %al,%eax
    imulq %rcx,%rax
    addq %rdx,%rax
    movq %rax,(%r12)
    leaq -7(%r12),%rbx
    addq $16,%rbp
    jmp *(%rbp)

使用ghc -O2 -fllvm -optlo -O3 -S编译会生成以下asm:
.LBB6_1:
    movq    7(%rbx), %rsi
    movq    $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
    movq    8(%rbp), %rcx
    movq    %rsi, %rdx
    subq    %rcx, %rdx
    xorl    %edi, %edi
    subq    %rsi, %rcx
    cmovleq %rdi, %rcx
    cmovgeq %rdi, %rdx
    addq    %rcx, %rdx
    movq    %rdx, 16(%rax)
    movq    16(%rbp), %rax
    addq    $16, %rbp
    leaq    -7(%r12), %rbx
    jmpq    *%rax  # TAILCALL

因此LLVM设法用(更有效率?)条件移动指令代替比较。不幸的是,使用-fllvm进行编译对我的程序的运行时间影响很小。

但是,此功能有两个问题。
  • 我想比较Word8,但是比较primops必须使用Int。这会导致不必要的分配,因为我不得不存储64位的Int而不是Word8

  • 我已经分析并确认,fromIntegral :: Word8 -> Int的使用占该程序总分配的42.4%。
  • 我的版本使用2个比较,2个乘法和2个减法。我想知道是否有更有效的方法,使用按位运算或SIMD指令并利用我正在比较Word8的事实。

  • 我之前曾标记问题C/C++来吸引那些更倾向于位操作的人们的注意。我的问题使用Haskell,但我接受使用任何语言实现正确方法的答案。

    结论:

    我决定使用
    w8_sad :: Word8 -> Word8 -> Int16
    w8_sad a b = xor (diff + mask) mask
        where diff = fromIntegral a - fromIntegral b
              mask = unsafeShiftR diff 15
    

    因为它比我原来的unsigned_difference函数要快,并且易于实现。 Haskell中的SIMD内在函数尚未成熟。因此,尽管SIMD版本更快,但我还是决定使用标量版本。

    最佳答案

    好吧,我尝试进行基准测试。我将Criterion用于基准测试,因为它会进行适当的重要性测试。我在这里也使用QuickCheck来确保所有方法都返回相同的结果。

    我使用GHC 7.6.3(不幸的是,我无法包含您的primops函数)和-O3进行编译:

    ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff
    

    基本上,我们可以看到天真的实现与一些烦恼之间的区别:
    absdiff1_w8 :: Word8 -> Word8 -> Word8
    absdiff1_w8 a b = max a b - min a b
    
    absdiff2_w8 :: Word8 -> Word8 -> Word8
    absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
      where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
            mask = unsafeShiftR v 63
    

    输出:
    benchmarking absdiff_Word8/1
    mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
    ....
    
    benchmarking absdiff_Word8/2
    mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
    ...
    

    我使用“这里的Bit Twiddling Hacks”中的absolute integer value技巧。不幸的是,我们需要强制转换,我认为仅靠Word8就不可能很好地解决问题,但是无论如何使用本机整数类型似乎是明智的选择(尽管绝对不需要创建堆对象)。

    看起来并没有很大的区别,但是我的测试设置也不是完美的:我将函数映射到大量随机值列表上,以排除分支预测,从而使分支版本看起来比实际效率更高。这会导致重音在内存中累积,这可能会严重影响时序。当我们减去用于维护列表的不变开销时,我们可以看到的加速远远超过20%。

    生成的程序集实际上非常好(这是该函数的内联版本):
    .Lc4BB:
        leaq 7(%rbx),%rax
        movq 8(%rbp),%rbx
        subq (%rax),%rbx
        movq %rbx,%rax
        sarq $63,%rax
        movq $base_GHCziInt_I64zh_con_info,-8(%r12)
        addq %rax,%rbx
        xorq %rax,%rbx
        movq %rbx,0(%r12)
        leaq -7(%r12),%rbx
        movq $s4z0_info,8(%rbp)
    

    1个减法,1个加法,1个右移,1个xor且没有分支,如预期的那样。使用LLVM后端并不能显着改善运行时间。

    希望这对您尝试更多的东西很有用。
    {-# LANGUAGE BangPatterns #-}
    {-# LANGUAGE ScopedTypeVariables #-}
    module Main where
    
    import Data.Word
    import Data.Int
    import Data.Bits
    import Control.Arrow ((***))
    import Control.DeepSeq (force)
    import Control.Exception (evaluate)
    import Control.Monad
    import System.Random
    import Unsafe.Coerce
    
    import Test.QuickCheck hiding ((.&.))
    import Criterion.Main
    
    absdiff1_w8 :: Word8 -> Word8 -> Word8
    absdiff1_w8 !a !b = max a b - min a b
    
    absdiff1_int16 :: Int16 -> Int16 -> Int16
    absdiff1_int16 a b = max a b - min a b
    
    absdiff1_int :: Int -> Int -> Int
    absdiff1_int a b = max a b - min a b
    
    absdiff2_int16 :: Int16 -> Int16 -> Int16
    absdiff2_int16 a b = xor (v + mask) mask
      where v = a - b
            mask = unsafeShiftR v 15
    
    absdiff2_w8 :: Word8 -> Word8 -> Word8
    absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
      where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
            !mask = unsafeShiftR v 63
    
    absdiff3_w8 :: Word8 -> Word8 -> Word8
    absdiff3_w8 a b = if a > b then a - b else b - a
    
    {-absdiff4_int :: Int -> Int -> Int-}
    {-absdiff4_int (I# a) (I# b) =-}
        {-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}
    
    e2e :: (Enum a, Enum b) => a -> b
    e2e = toEnum . fromEnum
    
    prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
    prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
        where x' = e2e x
              y' = e2e y
    
    check = quickCheck prop_same1
         >> quickCheck prop_same2
    
    instance (Random x, Random y) => Random (x, y) where
      random gen1 =
        let (x, gen2) = random gen1
            (y, gen3) = random gen2
        in ((x,y),gen3)
    
    main =
        do check
           !pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
           let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
           defaultMain
             [ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
                                      , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
                                      , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
                                      ]
             , bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
                                      , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
                                      ]
             {-, bgroup "absdiff_Int"   [ bench "1" $ whnf (absdiff1_int 13) 14-}
                                      {-, bench "2" $ whnf (absdiff3_int 13) 14-}
                                      {-]-}
             ]
    

    10-06 05:12