2017-02-11 90 views
19

这是我的previous question的后续行动,我问道为什么流合并不在某个程序中踢。事实证明,问题在于某些函数未被内联,并且一个标志将性能提高了约17x(它展示了内联的重要性!)。有什么方法可以内联递归函数吗?

现在,请注意,在原始问题上,我一次硬编码64调用incAll。现在,假设,相反,我创建一个nTimes功能,反复调用的函数:

module Main where 

import qualified Data.Vector.Unboxed as V 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a 
nTimes 0 f x = x 
nTimes n f x = f (nTimes (n-1) f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum (nTimes 64 incAll array) 

在这种情况下,只需添加一个INLINE编译到nTimes不会帮助,因为据我所知GHC不内联递归功能。在编译时强制GHC扩展nTimes是否有任何窍门,从而恢复预期的性能?

+2

您可以使用Template Haskell来引入语法来扩展重复的应用程序。 –

+1

@JoachimBreitner刚刚完成了这个。必须学习模板Haskell。仍在测试我的答案,但似乎要快得多(类似于其他问题)。 – Zeta

回答

26

不,但您可以使用更好的功能。我不是在谈论V.map (+64),这会让事情变得更快,但约nTimes。我们有三个候选人已经做nTimes做:

{-# INLINE nTimesFoldr #-} 
nTimesFoldr :: Int -> (a -> a) -> a -> a  
nTimesFoldr n f x = foldr (.) id (replicate n f) $ x 

{-# INLINE nTimesIterate #-} 
nTimesIterate :: Int -> (a -> a) -> a -> a  
nTimesIterate n f x = iterate f x !! n 

{-# INLINE nTimesTail #-} 
nTimesTail :: Int -> (a -> a) -> a -> a  
nTimesTail n f = go n 
    where 
    {-# INLINE go #-} 
    go n x | n <= 0 = x 
    go n x   = go (n - 1) (f x) 

所有版本大约需要8秒,相比40秒钟内你的版本需要。顺便说一下,Joachim的版本也需要8秒。请注意,iterate版本在我的系统上占用更多内存。尽管GHC有unroll plugin,但在过去的五年里它没有更新(它使用自定义的说明)。

根本没有展开?

但是,在我们绝望之前,GHC实际上试图将所有内容都嵌入其中?让我们用nTimesTailnTimes 1

module Main where 
import qualified Data.Vector.Unboxed as V 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a  
nTimes n f = go n 
    where 
    {-# INLINE go #-} 
    go n x | n <= 0 = x 
    go n x   = go (n - 1) (f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum (nTimes 1 incAll array) 
$ stack ghc --package vector -- -O2 -ddump-simpl -dsuppress-all SO.hs 
main2 = 
    case (runSTRep main3) `cast` ... 
    of _ { Vector ww1_s9vw ww2_s9vx ww3_s9vy -> 
    case $wgo 1 ww1_s9vw ww2_s9vx ww3_s9vy 
    of _ { (# ww5_s9w3, ww6_s9w4, ww7_s9w5 #) -> 

我们可以停在那儿。 $wgo是上面定义的go。即使使用1 GHC也不会展开循环。这是令人不安的,因为1是一个常数。

拯救模板

但是,唉,它并不是全部。如果C++程序员能够对编译时常量进行以下操作,那么我们应该如此,对吧?

template <int N> 
struct Call{ 
    template <class F, class T> 
    static T call(F f, T && t){ 
     return f(Call<N-1>::call(f,std::forward<T>(t))); 
    } 
}; 
template <> 
struct Call<0>{ 
    template <class F, class T> 
    static T call(F f, T && t){ 
     return t; 
    } 
}; 

果然,我们可以与TemplateHaskell*

-- Times.sh 
{-# LANGUAGE TemplateHaskell #-} 
module Times where 

import Control.Monad (when) 
import Language.Haskell.TH 

nTimesTH :: Int -> Q Exp 
nTimesTH n = do 
    f <- newName "f" 
    x <- newName "x" 

    when (n <= 0) (reportWarning "nTimesTH: argument non-positive") 

    let go k | k <= 0 = VarE x 
     go k   = AppE (VarE f) (go (k - 1)) 
    return $ LamE [VarP f,VarP x] (go n) 

是什么nTimesTH办?它会创建一个新函数,其中第一个名称f将应用于第二个名称x,总计n次。 n现在需要一个编译时间常数,它适合我们,因为循环展开,才可能与编译时间常数:

$(nTimesTH 0) = \f x -> x 
$(nTimesTH 1) = \f x -> f x 
$(nTimesTH 2) = \f x -> f (f x) 
$(nTimesTH 3) = \f x -> f (f (f x)) 
... 

是否行得通?它快吗?与nTimes相比有多快?让我们尝试另一个main为:

-- SO.hs 
{-# LANGUAGE TemplateHaskell #-} 
module Main where 
import Times 
import qualified Data.Vector.Unboxed as V 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a  
nTimes n f = go n 
    where 
    {-# INLINE go #-} 
    go n x | n <= 0 = x 
    go n x   = go (n - 1) (f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    let vTH = V.sum ($(nTimesTH 64) incAll array) 
    let vNorm = V.sum (nTimes 64 incAll array) 
    print $ vTH == vNorm 
stack ghc --package vector -- -O2 SO.hs && SO.exe +RTS -t 
True 
<<ghc: 52000056768 bytes, 66 GCs, 400034700/800026736 avg/max bytes residency (2 samples), 1527M in use, 0.000 INIT (0.000 elapsed), 8.875 MUT (9.119 elapsed), 0.000 GC (0.094 elapsed) :ghc>> 

它得到正确的结果。它有多快?让我们再次用另一个main

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum ($(nTimesTH 64) incAll array) 
 800,048,112 bytes allocated in the heap           
      4,352 bytes copied during GC            
      42,664 bytes maximum residency (1 sample(s))        
      18,776 bytes maximum slop             
      764 MB total memory in use (0 MB lost due to fragmentation)    

            Tot time (elapsed) Avg pause Max pause   
    Gen 0   1 colls,  0 par 0.000s 0.000s  0.0000s 0.0000s   
    Gen 1   1 colls,  0 par 0.000s 0.049s  0.0488s 0.0488s   

    INIT time 0.000s ( 0.000s elapsed)           
    MUT  time 0.172s ( 0.221s elapsed)           
    GC  time 0.000s ( 0.049s elapsed)           
    EXIT time 0.000s ( 0.049s elapsed)           
    Total time 0.188s ( 0.319s elapsed)           

    %GC  time  0.0% (15.3% elapsed)           

    Alloc rate 4,654,825,378 bytes per MUT second         

    Productivity 100.0% of total user, 58.7% of total elapsed   

好,比较,为8秒。因此,对于TL; DR:如果您有编译时常量,并且您想基于该常量创建和/或修改您的代码,请考虑模板Haskell。

*请注意,这是我写的第一个模板Haskell代码。小心使用。不要使用太大的n,否则你最终可能会遇到混乱的功能。

+2

注意:解决方案是[代码审查](https://codereview.stackexchange.com/questions/155144/execute-a-function-n-times-where-n-is-known-at-compile-time )。 – Zeta

+0

嘿刚回来让你知道这是在大多数方面的辉煌答案,谢谢。 – MaiaVictor

4

你可以写

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a 
nTimes n f x = go n 
    where go 0 = x 
     go n = f (go (n-1)) 

和GHC会内联nTimes,并有可能专门递归go您的特定参数incAllarray,但它不会展开循环。

+0

啊,很烂,谢谢。 – MaiaVictor

14

Andres已经告诉过我一个小知道的技巧,在那里你可以通过使用类型类实际获得GHC内联递归函数。

这个想法是,而不是写一个函数,通常你在一个值上执行结构递归。您可以使用类型类定义函数,并对类型参数执行结构递归。在这个例子中,类型级自然数。

由于每次递归调用的类型不同,GHC会高兴地嵌入每个递归调用并生成高效的代码。

我没有对此进行基准测试或看看核心,但它明显更快。

{-# LANGUAGE DataKinds #-} 
{-# LANGUAGE KindSignatures #-} 
{-# LANGUAGE PolyKinds #-} 
{-# LANGUAGE ScopedTypeVariables #-} 
module Main where 

import qualified Data.Vector.Unboxed as V 

data Proxy a = Proxy 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

oldNTimes :: Int -> (a -> a) -> a -> a 
oldNTimes 0 f x = x 
oldNTimes n f x = f (oldNTimes (n-1) f x) 

-- New definition 

data N = Z | S N 

class Unroll (n :: N) where 
    nTimes :: Proxy n -> (a -> a) -> a -> a 

instance Unroll Z where 
    nTimes _ f x = x 

instance Unroll n => Unroll (S n) where 
    nTimes p f x = 
     let Proxy :: Proxy (S n) = p 
     in f (nTimes (Proxy :: Proxy n) f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum (nTimes (Proxy :: Proxy (S (S (S (S (S (S (S (S (S (S (S Z)))))))))))) incAll array) 
    print $ V.sum (oldNTimes 11 incAll array) 
+0

不错,虽然如果你想使用'nTimes 64','Proxy :: Proxy(S(S(S(S ...(SZ)...)'这个词会比较有趣,我会用它来类型级别的算术,但。有些像'代理(十:*:六:+:四)'。 – Zeta

+0

我仍然无法得到这些类型类的编程恶作剧,任何明确表示我的人都是这样的人。 – MaiaVictor