{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Math.FFT (
Mode(..),
FFTElt,
fft1D, fft1D',
fft2D, fft2D',
fft3D, fft3D',
fft
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar ( showShape, shapeToList )
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Math.FFT.Mode
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
import qualified Data.Array.Accelerate.Math.FFT.LLVM.Native as Native
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
import qualified Data.Array.Accelerate.Math.FFT.LLVM.PTX as PTX
#endif
import Data.Bits
import Text.Printf
import Prelude as P
-- | Constraints required of the scalar component type @e@ of the
-- @Complex e@ elements accepted by the transforms in this module.
type FFTElt e =
  ( A.RealFloat e
  , A.IsFloating e
  , A.FromIntegral Int e
  , P.Num e
  )
-- | One-dimensional transform of a vector held on the host.
--
-- The vector is embedded into the Accelerate program with 'use', and its
-- shape is taken with 'arrayShape' and passed on to @fft1D'@.
fft1D :: FFTElt e
      => Mode
      -> Array DIM1 (Complex e)
      -> Acc (Array DIM1 (Complex e))
fft1D mode vec = fft1D' mode staticShape (use vec)
  where
    staticShape = arrayShape vec
-- | One-dimensional transform of an embedded vector.
--
-- The 'DIM1' argument supplies the vector length as a Haskell value. The
-- pure-Accelerate fallback unrolls its recursion over this length while the
-- program is built. The length is assumed to match the actual extent of the
-- array (TODO confirm callers guarantee this).
--
-- In 'Inverse' mode every output element is divided by the number of
-- elements. All other modes return the transform unscaled.
--
-- When compiled with an LLVM backend, the transform is wrapped in
-- 'foreignAcc' so that the backend-specific implementation can be used.
fft1D' :: forall e. FFTElt e
       => Mode
       -> DIM1
       -> Acc (Array DIM1 (Complex e))
       -> Acc (Array DIM1 (Complex e))
fft1D' mode (Z :. len) arr = normalise (go arr)
  where
    sign  = signOfMode mode :: e
    scale = A.fromIntegral (A.length arr)

    -- Only the inverse transform is normalised by the element count.
    normalise = case mode of
      Inverse -> A.map (/scale)
      _       -> P.id

    go =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
         foreignAcc (Native.fft1D mode) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
         foreignAcc (PTX.fft1D mode) $
#endif
         fft sign Z len
-- | Two-dimensional transform of a matrix held on the host.
--
-- The matrix is embedded with 'use', and its shape is taken with
-- 'arrayShape' and passed on to @fft2D'@.
fft2D :: FFTElt e
      => Mode
      -> Array DIM2 (Complex e)
      -> Acc (Array DIM2 (Complex e))
fft2D mode arr = fft2D' mode staticShape (use arr)
  where
    staticShape = arrayShape arr
-- | Two-dimensional transform of an embedded matrix.
--
-- The 'DIM2' argument supplies the height and width as Haskell values.
-- They are assumed to match the actual extents of the array (TODO confirm).
--
-- In 'Inverse' mode every output element is divided by the total number of
-- elements. All other modes return the transform unscaled.
--
-- When compiled with an LLVM backend, the transform is wrapped in
-- 'foreignAcc' so that the backend-specific implementation can be used.
fft2D' :: forall e. FFTElt e
       => Mode
       -> DIM2
       -> Acc (Array DIM2 (Complex e))
       -> Acc (Array DIM2 (Complex e))
fft2D' mode (Z :. height :. width) arr = normalise (go arr)
  where
    sign  = signOfMode mode :: e
    scale = A.fromIntegral (A.size arr)

    -- Only the inverse transform is normalised by the element count.
    normalise = case mode of
      Inverse -> A.map (/scale)
      _       -> P.id

    -- Transform each row (innermost extent 'width') and transpose. Then
    -- transform along the former columns (extent 'height') and transpose
    -- back, restoring the original height-by-width orientation.
    portable = A.transpose . fft sign (Z:.height) width
           >-> A.transpose . fft sign (Z:.width) height

    go =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
         foreignAcc (Native.fft2D mode) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
         foreignAcc (PTX.fft2D mode) $
#endif
         portable
-- | Three-dimensional transform of an array held on the host.
--
-- The array is embedded with 'use', and its shape is taken with
-- 'arrayShape' and passed on to @fft3D'@.
fft3D :: FFTElt e
      => Mode
      -> Array DIM3 (Complex e)
      -> Acc (Array DIM3 (Complex e))
fft3D mode arr = fft3D' mode staticShape (use arr)
  where
    staticShape = arrayShape arr
-- | Three-dimensional transform of an embedded array.
--
-- The 'DIM3' argument supplies depth, height and width as Haskell values.
-- They are assumed to match the actual extents of the array (TODO confirm).
--
-- In 'Inverse' mode every output element is divided by the total number of
-- elements. All other modes return the transform unscaled.
--
-- When compiled with an LLVM backend, the transform is wrapped in
-- 'foreignAcc' so that the backend-specific implementation can be used.
fft3D' :: forall e. FFTElt e
       => Mode
       -> DIM3
       -> Acc (Array DIM3 (Complex e))
       -> Acc (Array DIM3 (Complex e))
fft3D' mode (Z :. depth :. height :. width) arr = normalise (go arr)
  where
    sign  = signOfMode mode :: e
    scale = A.fromIntegral (A.size arr)

    -- Only the inverse transform is normalised by the element count.
    normalise = case mode of
      Inverse -> A.map (/scale)
      _       -> P.id

    -- Each stage transforms along the current innermost axis, then
    -- 'rotate3D' moves the outermost axis innermost. After three stages
    -- every axis has been transformed and the original
    -- depth-height-width orientation is restored.
    portable = rotate3D . fft sign (Z:.depth :.height) width
           >-> rotate3D . fft sign (Z:.height:.width) depth
           >-> rotate3D . fft sign (Z:.width :.depth) height

    go =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
         foreignAcc (Native.fft3D mode) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
         foreignAcc (PTX.fft3D mode) $
#endif
         portable
-- | Cyclically rotate the axes of a 3D array so that the outermost axis
-- becomes the innermost one.
--
-- For an input of shape @Z :. d :. h :. w@, the result has shape
-- @Z :. h :. w :. d@, and result element @(i, j, k)@ is input element
-- @(k, i, j)@.
rotate3D :: Elt e => Acc (Array DIM3 e) -> Acc (Array DIM3 e)
rotate3D arr = backpermute rotatedShape fromRotated arr
  where
    rotatedShape :: Exp DIM3
    rotatedShape =
      let Z :. d :. h :. w = unlift (shape arr) :: Z :. Exp Int :. Exp Int :. Exp Int
      in  index3 h w d

    -- Map an index of the result back to the input index it reads from.
    fromRotated :: Exp DIM3 -> Exp DIM3
    fromRotated ix =
      let Z :. i :. j :. k = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int
      in  index3 k i j
-- | Radix-2 decimation-in-time Cooley-Tukey transform along the innermost
-- dimension of an array, written in pure Accelerate.
--
-- The extents are passed as Haskell values because the recursion over the
-- transform length is unrolled in Haskell while the Accelerate program is
-- being constructed.
--
-- Arguments:
--
-- * @sign@: sign of the exponent of the twiddle factors.
-- * @sh@: the outer extents of the array.
-- * @sz@: the innermost extent, i.e. the transform length.
--
-- Every extent must be a power of two. An extent of 1 (two to the zero) is
-- accepted and transforms to itself. Any other extent makes the function
-- call 'error' with a message listing the shape.
fft :: forall sh e. (Slice sh, Shape sh, A.RealFloat e, A.FromIntegral Int e)
    => e
    -> sh
    -> Int
    -> Acc (Array (sh:.Int) (Complex e))
    -> Acc (Array (sh:.Int) (Complex e))
fft sign sh sz arr
  | P.any (P.not . supported) (shapeToList (sh:.sz))
  = error $ printf "fft: array dimensions must be powers-of-two, but are: %s" (showShape (sh:.sz))
  | otherwise
  = go sz 0 1
  where
    -- An extent of 1 is 2^0 and is handled by the base case of 'go', so
    -- accept it explicitly, independent of how 'isPow2' classifies it.
    supported :: Int -> Bool
    supported n = n P.== 1 P.|| isPow2 n

    -- go len offset stride: transform the len elements found at positions
    -- offset, offset+stride, offset+2*stride, ... along the innermost axis.
    go :: Int -> Int -> Int -> Acc (Array (sh:.Int) (Complex e))
    go len offset stride
      -- Length 1: the DFT is the identity. Length 0: the result is empty and
      -- 'pick' is never evaluated. Without this case a length below 2 would
      -- recurse forever, because the halving never reaches 2.
      | len P.< 2
      = A.generate (constant (sh :. len)) pick
      | len P.== 2
      = A.generate (constant (sh :. len)) swivel
      | otherwise
      -- Split into the even-position and odd-position subsequences
      -- (doubling the stride) and merge them with a butterfly.
      = combine
          (go (len `div` 2) offset            (stride * 2))
          (go (len `div` 2) (offset + stride) (stride * 2))
      where
        -- NOTE(review): wrapping these constants with the/unit presumably
        -- keeps them out of the generated scalar code as literals - confirm.
        len'    = the (unit (constant len))
        offset' = the (unit (constant offset))
        stride' = the (unit (constant stride))

        -- Single-element base case: copy the element at 'offset'.
        pick ix =
          let sh' :. _ = unlift ix :: Exp sh :. Exp Int
          in  arr ! lift (sh' :. offset')

        -- Two-element base case: x0 + x1 at index 0, x0 - x1 at index 1.
        swivel ix =
          let sh' :. sz' = unlift ix :: Exp sh :. Exp Int
          in
          sz' A.== 0 ? ( (arr ! lift (sh' :. offset')) + (arr ! lift (sh' :. offset' + stride'))
                       , (arr ! lift (sh' :. offset')) - (arr ! lift (sh' :. offset' + stride')) )

        -- Butterfly: the first half is evens + w*odds and the second half
        -- is evens - w*odds, where w is the twiddle factor for each index.
        combine evens odds =
          let odds' = A.generate (A.shape odds) (\ix -> twiddle len' (indexHead ix) * odds!ix)
          in
          append (A.zipWith (+) evens odds') (A.zipWith (-) evens odds')

        -- Twiddle factor cos(2*pi*i/n) + sign * sin(2*pi*i/n) * i.
        twiddle n' i' =
          let n = A.fromIntegral n'
              i = A.fromIntegral i'
              k = 2*pi*i/n
          in
          lift ( cos k :+ A.constant sign * sin k )
-- | Concatenate two arrays along their innermost dimension.
--
-- The outer extents of the result are taken from the first array. The
-- second array is assumed to have the same outer extents (not checked here).
append
    :: forall sh e. (Slice sh, Shape sh, Elt e)
    => Acc (Array (sh:.Int) e)
    -> Acc (Array (sh:.Int) e)
    -> Acc (Array (sh:.Int) e)
append xs ys = generate (lift (outer :. (nx + ny))) pick
  where
    outer :. nx = unlift (A.shape xs) :: Exp sh :. Exp Int
    _     :. ny = unlift (A.shape ys) :: Exp sh :. Exp Int

    -- Innermost indices below nx read from xs; the rest read from ys,
    -- shifted down by nx.
    pick :: Exp (sh:.Int) -> Exp e
    pick ix =
      let rest :. j = unlift ix :: Exp sh :. Exp Int
      in  j A.< nx ? (xs ! lift (rest :. j), ys ! lift (rest :. (j - nx)))
-- | Extent predicate used by 'fft'.
--
-- For arguments of 2 or more, it is True exactly for powers of two. Note
-- that, unlike the mathematical notion, it returns True for 0 and False
-- for 1. It assumes a non-negative argument.
isPow2 :: Int -> Bool
isPow2 n
  | n P.== 0  = True
  | n P.== 1  = False
  | otherwise = (n .&. (n - 1)) P.== 0