{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Math.FFT.LLVM.PTX (
fft,
fft1D,
fft2D,
fft3D,
) where
import Data.Array.Accelerate.Math.FFT.Mode
import Data.Array.Accelerate.Math.FFT.Type
import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Foreign.CUDA.Ptr ( DevicePtr, castDevPtr )
import qualified Foreign.CUDA.FFT as FFT
import Data.Hashable
import Data.Proxy
import Data.Typeable
import System.IO.Unsafe
fft :: forall sh e. (Shape sh, Numeric e)
=> Mode
-> ForeignAcc (Array (sh:.Int) (Complex e) -> Array (sh:.Int) (Complex e))
fft mode
| Just Refl <- matchShapeType (undefined::sh) (undefined::DIM0) = fft1D mode
| Just Refl <- matchShapeType (undefined::sh) (undefined::DIM1) = ForeignAcc "cuda.fft2.many" $ fft' fft2DMany_plans mode
| Just Refl <- matchShapeType (undefined::sh) (undefined::DIM2) = ForeignAcc "cuda.fft3.many" $ fft' fft3DMany_plans mode
| otherwise = $internalError "fft" "only for 1D..3D inner-dimension transforms"
fft1D :: Numeric e
=> Mode
-> ForeignAcc (Vector (Complex e) -> Vector (Complex e))
fft1D mode = ForeignAcc "cuda.fft1d" $ fft' fft1D_plans mode
fft2D :: Numeric e
=> Mode
-> ForeignAcc (Array DIM2 (Complex e) -> Array DIM2 (Complex e))
fft2D mode = ForeignAcc "cuda.fft2d" $ fft' fft2D_plans mode
fft3D :: Numeric e
=> Mode
-> ForeignAcc (Array DIM3 (Complex e) -> Array DIM3 (Complex e))
fft3D mode = ForeignAcc "cuda.fft3d" $ fft' fft3D_plans mode
{-# INLINEABLE fft' #-}
fft' :: forall sh e. (Shape sh, Numeric e)
=> Plans (sh, FFT.Type)
-> Mode
-> Stream
-> Array sh (Complex e)
-> LLVM PTX (Array sh (Complex e))
fft' plans mode stream =
let
go :: Numeric e => Array sh (Complex e) -> LLVM PTX (Array sh (Complex e))
go ain = do
let
sh = shape ain
t = fftType (Proxy::Proxy e)
aout <- allocateRemote sh
withArray ain stream $ \d_in -> do
withArray aout stream $ \d_out -> do
withPlan plans (sh,t) $ \h -> do
liftIO $ cuFFT (Proxy::Proxy e) h mode stream (castDevPtr d_in) (castDevPtr d_out)
return aout
in
case numericR::NumericR e of
NumericRfloat32 -> go
NumericRfloat64 -> go
{-# INLINE cuFFT #-}
cuFFT :: forall e. Numeric e
=> Proxy e
-> FFT.Handle
-> Mode
-> Stream
-> DevicePtr (Complex e)
-> DevicePtr (Complex e)
-> IO ()
cuFFT _ p mode stream d_in d_out =
withLifetime stream $ \s -> do
FFT.setStream p s
case numericR::NumericR e of
NumericRfloat32 -> FFT.execC2C p (fftMode mode) d_in d_out
NumericRfloat64 -> FFT.execZ2Z p (fftMode mode) d_in d_out
fftType :: forall e. Numeric e => Proxy e -> FFT.Type
fftType _ =
case numericR::NumericR e of
NumericRfloat32 -> FFT.C2C
NumericRfloat64 -> FFT.Z2Z
fftMode :: Mode -> FFT.Mode
fftMode Forward = FFT.Forward
fftMode _ = FFT.Inverse
{-# NOINLINE fft1D_plans #-}
fft1D_plans :: Plans (DIM1, FFT.Type)
fft1D_plans
= unsafePerformIO
$ createPlan (\(Z:.n, t) -> FFT.plan1D n t 1)
(\(Z:.n, t) -> fromEnum t `hashWithSalt` n)
{-# NOINLINE fft2D_plans #-}
fft2D_plans :: Plans (DIM2, FFT.Type)
fft2D_plans
= unsafePerformIO
$ createPlan (\(Z:.h:.w, t) -> FFT.plan2D h w t)
(\(Z:.h:.w, t) -> fromEnum t `hashWithSalt` h `hashWithSalt` w)
{-# NOINLINE fft3D_plans #-}
fft3D_plans :: Plans (DIM3, FFT.Type)
fft3D_plans
= unsafePerformIO
$ createPlan (\(Z:.d:.h:.w, t) -> FFT.plan3D d h w t)
(\(Z:.d:.h:.w, t) -> fromEnum t `hashWithSalt` d `hashWithSalt` h `hashWithSalt` w)
{-# NOINLINE fft2DMany_plans #-}
fft2DMany_plans :: Plans (DIM2, FFT.Type)
fft2DMany_plans
= unsafePerformIO
$ createPlan (\(Z:.h:.w, t) -> FFT.planMany [h,w] Nothing Nothing t 1)
(\(Z:.h:.w, t) -> fromEnum t `hashWithSalt` h `hashWithSalt` w)
{-# NOINLINE fft3DMany_plans #-}
fft3DMany_plans :: Plans (DIM3, FFT.Type)
fft3DMany_plans
= unsafePerformIO
$ createPlan (\(Z:.d:.h:.w, t) -> FFT.planMany [d,h,w] Nothing Nothing t 1)
(\(Z:.d:.h:.w, t) -> fromEnum t `hashWithSalt` d `hashWithSalt` h `hashWithSalt` w)