{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
-- | Criterion benchmarks comparing forward complex FFTs of two-dimensional
-- arrays across three backends: pure Accelerate transforms executed with the
-- CUDA backend, NVIDIA's CUFFT (batched), and FFTW.
--
-- Two families of row lengths are measured: powers of two (where split-radix
-- applies) and every length from 64 to 128 (arbitrary sizes), each for both
-- 'Float' and 'Double' samples.
module Main where

import Criterion.Main (Benchmark, defaultMain, bgroup, bench, whnf, )

import qualified Data.Array.Accelerate.CUFFT.Batched as CUFFT
import qualified Data.Array.Accelerate.FFTW.Manifest as FFTW
import qualified Data.Array.Accelerate.Fourier.Planned as Planned
import qualified Data.Array.Accelerate.Fourier.Preprocessed as Prep
import qualified Data.Array.Accelerate.CUDA as CUDA
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Array, DIM2, Z(Z), (:.)((:.)), )

import Data.Complex (Complex, )
import System.IO.Unsafe (unsafePerformIO)


-- | Type-level proxy selecting the real sample type ('Float' or 'Double')
-- of a benchmark group; it carries no data.
data Elem a = Elem

-- | Everything a sample type must support so that all three backends
-- (Accelerate on CUDA, CUFFT, FFTW) can be benchmarked with it.
type FFTElem a =
   (CUFFT.Real a, A.IsFloating a, FFTW.Element a, RealFloat a)

-- | A transform specialised to an input shape: given the shape, return the
-- function that is actually timed on arrays of that shape.
type Transform a =
   DIM2 -> Array DIM2 (Complex a) -> Array DIM2 (Complex a)


-- | Length of the innermost dimension of a two-dimensional shape.
width :: DIM2 -> Int
width (Z :. _ :. n) = n


-- | Build one benchmark per row length. Benchmark @len@ applies the
-- transform, specialised to shape @rows × len@, to a zero-filled array of
-- that shape. Each benchmark is named by its row length.
sizedBenchs ::
   (A.Elt a, RealFloat a) =>
   Int           -- ^ number of rows (outer dimension)
   -> [Int]      -- ^ row lengths (inner dimension) to benchmark
   -> Elem a     -- ^ sample-type proxy
   -> Transform a
   -> [Benchmark]
sizedBenchs rows lens Elem f = map benchLen lens
   where
      benchLen len =
         let sh = Z :. rows :. len
         -- 'A.fromList' takes only as many elements as the shape needs,
         -- so an infinite list of zeros is fine here.
         in  bench (show len) $ whnf (f sh) $ A.fromList sh $ repeat 0

-- | Benchmarks over 16 rows whose lengths are the six powers of two
-- 1024, 2048, …, 32768.
powersOfTwo ::
   (A.Elt a, RealFloat a) =>
   Elem a -> Transform a -> [Benchmark]
powersOfTwo = sizedBenchs 16 (take 6 $ iterate (2*) 1024)

-- | Benchmarks over 2048 rows for every row length from 64 to 128
-- inclusive, i.e. mostly non-power-of-two sizes.
arbitrary ::
   (A.Elt a, RealFloat a) =>
   Elem a -> Transform a -> [Benchmark]
arbitrary = sizedBenchs 2048 [64 .. 128]


-- | CUFFT forward complex transform for a fixed shape.
--
-- The CUFFT plan is created via 'unsafePerformIO' so that the transform
-- fits the pure 'Transform' interface the benchmark builders expect. The
-- plan is a thunk inside the closure returned for a given shape, so it is
-- created at most once per closure and then shared.
-- NOTE(review): planning is presumably triggered the first time Criterion
-- forces that closure, so it may be counted in the first measured run —
-- confirm if planning cost matters for the results.
cufftForward :: FFTElem a => Transform a
cufftForward sh =
   CUDA.run1 $ CUFFT.transform $ unsafePerformIO $
   CUFFT.plan1D CUFFT.forwardComplex sh


-- | Power-of-two lengths: Accelerate split-radix on CUDA vs. CUFFT vs. FFTW.
powersOfTwos :: FFTElem a => Elem a -> [Benchmark]
powersOfTwos e =
   [ bgroup "CUDA split-radix"
        (powersOfTwo e (CUDA.run1 . Prep.ditSplitRadix Prep.forward . width))
   , bgroup "CUFFT" (powersOfTwo e cufftForward)
   , bgroup "FFTW" (powersOfTwo e (const FFTW.dft))
   ]

-- | Arbitrary lengths: generic planned Accelerate transform on CUDA vs.
-- CUFFT vs. FFTW.
arbitraryLengths :: FFTElem a => Elem a -> [Benchmark]
arbitraryLengths e =
   [ bgroup "CUDA generic"
        (arbitrary e (CUDA.run1 . Planned.transform Planned.forward . width))
   , bgroup "CUFFT" (arbitrary e cufftForward)
   , bgroup "FFTW" (arbitrary e (const FFTW.dft))
   ]

-- | Both size families for one sample type.
allBenchs :: FFTElem a => Elem a -> [Benchmark]
allBenchs e =
   [ bgroup "2^n" (powersOfTwos e)
   , bgroup "any" (arbitraryLengths e)
   ]

-- | Run every benchmark for single and double precision.
main :: IO ()
main =
   defaultMain
      [ bgroup "float" (allBenchs (Elem :: Elem Float))
      , bgroup "double" (allBenchs (Elem :: Elem Double))
      ]