{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.Math.FFT.LLVM.PTX
-- Copyright   : [2017] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Math.FFT.LLVM.PTX (

  fft,
  fft1D,
  fft2D,
  fft3D,

) where

import Data.Array.Accelerate.Math.FFT.Mode
import Data.Array.Accelerate.Math.FFT.Type
import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base
import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans

import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Lifetime

import Data.Array.Accelerate.LLVM.PTX.Foreign

import Foreign.CUDA.Ptr                                             ( DevicePtr, castDevPtr )
import qualified Foreign.CUDA.FFT                                   as FFT

import Data.Hashable
import Data.Proxy
import Data.Typeable
import System.IO.Unsafe


-- | Foreign (cuFFT) implementation of the transform over an array of shape
-- @sh:.Int@, chosen by inspecting the outer shape @sh@ at runtime:
--
--   * @sh ~ DIM0@: delegates to 'fft1D'.
--
--   * @sh ~ DIM1@: uses the 'fft2DMany_plans' cache (@"cuda.fft2.many"@).
--
--   * @sh ~ DIM2@: uses the 'fft3DMany_plans' cache (@"cuda.fft3.many"@).
--
-- Any other shape raises an internal error.
--
fft :: forall sh e. (Shape sh, Numeric e)
    => Mode
    -> ForeignAcc (Array (sh:.Int) (Complex e) -> Array (sh:.Int) (Complex e))
fft mode =
  case matchShapeType outer (undefined :: DIM0) of
    Just Refl -> fft1D mode
    Nothing   ->
      case matchShapeType outer (undefined :: DIM1) of
        Just Refl -> ForeignAcc "cuda.fft2.many" (fft' fft2DMany_plans mode)
        Nothing   ->
          case matchShapeType outer (undefined :: DIM2) of
            Just Refl -> ForeignAcc "cuda.fft3.many" (fft' fft3DMany_plans mode)
            Nothing   -> $internalError "fft" "only for 1D..3D inner-dimension transforms"
  where
    -- Type-level witness only; never evaluated by the visible code here.
    outer :: sh
    outer = undefined

-- | Foreign (cuFFT) implementation of a transform over a 'Vector' of complex
-- numbers, registered under the name @"cuda.fft1d"@ and using the
-- 'fft1D_plans' plan cache.
--
fft1D :: Numeric e
      => Mode
      -> ForeignAcc (Vector (Complex e) -> Vector (Complex e))
fft1D mode = ForeignAcc "cuda.fft1d" worker
  where
    worker = fft' fft1D_plans mode

-- | Foreign (cuFFT) implementation of a transform over a 'DIM2' array of
-- complex numbers, registered under the name @"cuda.fft2d"@ and using the
-- 'fft2D_plans' plan cache.
--
fft2D :: Numeric e
      => Mode
      -> ForeignAcc (Array DIM2 (Complex e) -> Array DIM2 (Complex e))
fft2D mode = ForeignAcc "cuda.fft2d" worker
  where
    worker = fft' fft2D_plans mode

-- | Foreign (cuFFT) implementation of a transform over a 'DIM3' array of
-- complex numbers, registered under the name @"cuda.fft3d"@ and using the
-- 'fft3D_plans' plan cache.
--
fft3D :: Numeric e
      => Mode
      -> ForeignAcc (Array DIM3 (Complex e) -> Array DIM3 (Complex e))
fft3D mode = ForeignAcc "cuda.fft3d" worker
  where
    worker = fft' fft3D_plans mode


-- Internals
-- ---------

{-# INLINEABLE fft' #-}
-- | Shared driver behind every exported transform.
--
-- Given a plan cache, the transform 'Mode' and the execution 'Stream', the
-- resulting function:
--
--   1. allocates an output array with the same shape as the input;
--
--   2. brings both arrays into scope on the stream via 'withArray';
--
--   3. obtains a plan from the cache, keyed on the array shape together with
--      the cuFFT transform type for the element type @e@;
--
--   4. runs 'cuFFT' on the two device pointers and returns the output array.
--
-- The element representation is inspected once up front; both cases run the
-- same code.
--
fft' :: forall sh e. (Shape sh, Numeric e)
     => Plans (sh, FFT.Type)
     -> Mode
     -> Stream
     -> Array sh (Complex e)
     -> LLVM PTX (Array sh (Complex e))
fft' plans mode stream =
  case numericR :: NumericR e of
    NumericRfloat32 -> transform
    NumericRfloat64 -> transform
  where
    -- Fixes the element type for 'fftType' and 'cuFFT'.
    elt :: Proxy e
    elt = Proxy

    transform :: Numeric e => Array sh (Complex e) -> LLVM PTX (Array sh (Complex e))
    transform input = do
      let extent = shape input
          key    = (extent, fftType elt)
      --
      output <- allocateRemote extent
      withArray input stream $ \d_in ->
        withArray output stream $ \d_out ->
          withPlan plans key $ \handle -> do
            liftIO $ cuFFT elt handle mode stream (castDevPtr d_in) (castDevPtr d_out)
            return output


{-# INLINE cuFFT #-}
-- | Execute the FFT.
--
-- Associates the plan with the given stream, then dispatches to the
-- single-precision (@execC2C@) or double-precision (@execZ2Z@) entry point
-- according to the element type @e@. The 'Proxy' argument only fixes @e@.
--
cuFFT :: forall e. Numeric e
      => Proxy e
      -> FFT.Handle
      -> Mode
      -> Stream
      -> DevicePtr (Complex e)
      -> DevicePtr (Complex e)
      -> IO ()
cuFFT _ plan mode stream src dst =
  withLifetime stream $ \raw -> do
    FFT.setStream plan raw
    case numericR :: NumericR e of
      NumericRfloat32 -> FFT.execC2C plan direction src dst
      NumericRfloat64 -> FFT.execZ2Z plan direction src dst
  where
    -- cuFFT direction corresponding to the requested mode
    direction = fftMode mode

-- | The cuFFT transform type for element type @e@: 'FFT.C2C' for 32-bit
-- floats and 'FFT.Z2Z' for 64-bit floats. The 'Proxy' argument only fixes
-- @e@.
--
fftType :: forall e. Numeric e => Proxy e -> FFT.Type
fftType _ = select (numericR :: NumericR e)
  where
    select :: NumericR a -> FFT.Type
    select NumericRfloat32 = FFT.C2C
    select NumericRfloat64 = FFT.Z2Z

-- | Translate the library's transform 'Mode' into the cuFFT direction.
-- 'Forward' maps to 'FFT.Forward'; every other mode maps to 'FFT.Inverse'.
--
fftMode :: Mode -> FFT.Mode
fftMode mode =
  case mode of
    Forward -> FFT.Forward
    _       -> FFT.Inverse


-- Plan caches
-- -----------

{-# NOINLINE fft1D_plans #-}
-- | Global cache of one-dimensional plans, keyed on the array extent and the
-- transform type.
--
-- The cache is a top-level value built with 'unsafePerformIO'; the NOINLINE
-- pragma is there so that the definition is not duplicated at use sites.
--
fft1D_plans :: Plans (DIM1, FFT.Type)
fft1D_plans = unsafePerformIO (createPlan build key)
  where
    -- a single transform of length n (batch count 1)
    build (Z :. n, t) = FFT.plan1D n t 1

    -- hash of the cache key, salted with the transform type
    key (Z :. n, t) = hashWithSalt (fromEnum t) n

{-# NOINLINE fft2D_plans #-}
-- | Global cache of two-dimensional plans, keyed on the array extent and the
-- transform type.
--
-- The cache is a top-level value built with 'unsafePerformIO'; the NOINLINE
-- pragma is there so that the definition is not duplicated at use sites.
--
fft2D_plans :: Plans (DIM2, FFT.Type)
fft2D_plans = unsafePerformIO (createPlan build key)
  where
    build (Z :. h :. w, t) = FFT.plan2D h w t

    -- hash of the cache key, salted with the transform type
    key (Z :. h :. w, t) = hashWithSalt (hashWithSalt (fromEnum t) h) w

{-# NOINLINE fft3D_plans #-}
-- | Global cache of three-dimensional plans, keyed on the array extent and
-- the transform type.
--
-- The cache is a top-level value built with 'unsafePerformIO'; the NOINLINE
-- pragma is there so that the definition is not duplicated at use sites.
--
fft3D_plans :: Plans (DIM3, FFT.Type)
fft3D_plans = unsafePerformIO (createPlan build key)
  where
    build (Z :. d :. h :. w, t) = FFT.plan3D d h w t

    -- hash of the cache key, salted with the transform type
    key (Z :. d :. h :. w, t) =
      hashWithSalt (hashWithSalt (hashWithSalt (fromEnum t) d) h) w

{-# NOINLINE fft2DMany_plans #-}
-- | Global cache of plans created through 'FFT.planMany' for 'DIM2' arrays,
-- keyed on the array extent and the transform type. Used by 'fft'.
--
-- The cache is a top-level value built with 'unsafePerformIO'; the NOINLINE
-- pragma is there so that the definition is not duplicated at use sites.
--
-- NOTE(review): the plan is requested with dimensions @[h,w]@ and a final
-- argument of 1, which looks like a single rank-2 transform rather than a
-- batch of transforms along the innermost dimension -- confirm that this is
-- the intended behaviour for 'fft'.
--
fft2DMany_plans :: Plans (DIM2, FFT.Type)
fft2DMany_plans = unsafePerformIO (createPlan build key)
  where
    build (Z :. h :. w, t) = FFT.planMany [h, w] Nothing Nothing t 1

    -- hash of the cache key, salted with the transform type
    key (Z :. h :. w, t) = hashWithSalt (hashWithSalt (fromEnum t) h) w

{-# NOINLINE fft3DMany_plans #-}
-- | Global cache of plans created through 'FFT.planMany' for 'DIM3' arrays,
-- keyed on the array extent and the transform type. Used by 'fft'.
--
-- The cache is a top-level value built with 'unsafePerformIO'; the NOINLINE
-- pragma is there so that the definition is not duplicated at use sites.
--
-- NOTE(review): the plan is requested with dimensions @[d,h,w]@ and a final
-- argument of 1, which looks like a single rank-3 transform rather than a
-- batch of transforms along the innermost dimension -- confirm that this is
-- the intended behaviour for 'fft'.
--
fft3DMany_plans :: Plans (DIM3, FFT.Type)
fft3DMany_plans = unsafePerformIO (createPlan build key)
  where
    build (Z :. d :. h :. w, t) = FFT.planMany [d, h, w] Nothing Nothing t 1

    -- hash of the cache key, salted with the transform type
    key (Z :. d :. h :. w, t) =
      hashWithSalt (hashWithSalt (hashWithSalt (fromEnum t) d) h) w