{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
Accelerate interface to the native CUDA implementation
of the Fourier Transform provided by the CUFFT library.
-}
module Data.Array.Accelerate.CUFFT.Private where

import qualified Data.Array.Accelerate.CUFFT.RealClass as RC

import qualified Data.Array.Accelerate.Fourier.Preprocessed as Prep
import qualified Data.Array.Accelerate.Fourier.Planned as Fourier

import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import Data.Array.Accelerate.Utility.Lift.Exp (expr)

import qualified Data.Array.Accelerate.CUDA.Foreign as AF
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex (Complex((:+)), real, imag, conjugate)
import Data.Array.Accelerate
          (Acc, Array, Elt, Shape, Slice, (:.)((:.)),
           Exp, (!), (?), (==*), (<*),)

import qualified Foreign.CUDA.FFT as CUFFT
import qualified Foreign.CUDA.Driver as CUDA
import Foreign.CUDA.Ptr (DevicePtr, )

import qualified System.Mem.Weak as Weak

-- | An Accelerate array function that keeps the shape type @ix@
-- but may change the element type from @x@ to @y@.
type Transform ix x y = Acc (Array ix x) -> Acc (Array ix y)

-- | A transform direction in two encodings:
-- the integer direction flag that is handed to the CUFFT exec functions
-- (see 'modeC2C') and the sign value used by the accelerate-fourier fallback.
type Sign x = (Int, Fourier.Sign x)

-- | Forward (flag @-1@) and inverse (flag @1@) transform directions.
forwardSign, inverseSign :: Num a => Sign a
forwardSign = (negate 1, Fourier.forward)
inverseSign = (1, Fourier.inverse)

   -- | A prepared CUFFT transform: the accelerate-fourier fallback transform,
   -- the 'Mode' describing data layout and CUFFT data type,
   -- the transform width, and the CUFFT plan handle (see 'makeHandle').
data
   Handle sh e a b =
      Handle (Transform sh a b) (Mode sh e a b) Int CUFFT.Handle

{- |
Create a CUFFT plan for the given mode and bundle it into a 'Handle'.

* @width@ is stored in the handle; 'transformForeign' uses it
  to compute output shapes.
* @fallback@ builds the accelerate-fourier transform for the mode's sign.
* @planner@ creates the CUFFT plan for the mode's CUFFT data type.

A finalizer that destroys the plan is attached to the plan value.

NOTE(review): 'Weak.addFinalizer' on an ordinary Haskell value is fragile
(GHC may unbox or duplicate it), so the plan might be destroyed too early
or never. Consider tying the finalizer to a heap object such as an IORef.
-}
makeHandle ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e a b -> Int ->
   (Fourier.Sign e -> Transform sh a b) ->
   (CUFFT.Type -> IO CUFFT.Handle) ->
   IO (Handle sh e a b)
makeHandle mode width fallback planner =
   planner (types mode) >>= \cufftPlan -> do
      Weak.addFinalizer cufftPlan (CUFFT.destroy cufftPlan)
      return (Handle (fallback (fsign mode)) mode width cufftPlan)

-- | A batch shape @sh@ extended by zero to three least-significant
-- 'Int' dimensions, which are the dimensions being transformed.
type Batch0 sh = sh
type Batch1 sh = sh :. Int
type Batch2 sh = sh :. Int :. Int
type Batch3 sh = sh :. Int :. Int :. Int

-- | Fallback two-dimensional transform built from accelerate-fourier:
-- a 1D transform of length @width@ and one of length @height@
-- are combined by 'Prep.transform2d'.
-- Only the two innermost extents of the given shape are inspected.
--
-- NOTE(review): the body looks incomplete. 'Prep.transform2d $' is followed
-- by two parenthesised arguments with nothing applied to them; presumably
-- a sub-transform pair constructor line got lost. Confirm against the
-- accelerate-fourier API before relying on this definition.
transform2D ::
   (Shape sh, Slice sh, Elt a, RC.Real a) =>
   Batch2 sh -> Fourier.Sign a ->
   Fourier.Transform (Batch2 sh) (Complex a)
transform2D (_shape:.height:.width) sign =
   Prep.transform2d $
         (Fourier.transform sign width)
         (Fourier.transform sign height)

-- | Fallback three-dimensional transform built from accelerate-fourier:
-- 1D transforms of length @width@, @height@ and @depth@
-- are combined by 'Prep.transform3d'.
-- Only the three innermost extents of the given shape are inspected.
--
-- NOTE(review): the body looks incomplete. 'Prep.transform3d $' is followed
-- by three parenthesised arguments with nothing applied to them; presumably
-- a sub-transform triple constructor line got lost. Confirm against the
-- accelerate-fourier API before relying on this definition.
transform3D ::
   (Shape sh, Slice sh, Elt a, RC.Real a) =>
   Batch3 sh -> Fourier.Sign a ->
   Fourier.Transform (Batch3 sh) (Complex a)
transform3D (_shape:.depth:.height:.width) sign =
   Prep.transform3d $
      (Fourier.transform sign width)
      (Fourier.transform sign height)
      (Fourier.transform sign depth)

{- |
Run the transform described by a 'Handle'.
The CUDA backend executes the CUFFT plan stored in the handle
('transformForeign'); other backends use the fallback transform
stored in the handle.

The implementation works on all arrays of rank less than or equal to 3.
The result is un-normalised.
-}
transform ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   Handle (sh:.Int) e a b ->
   Transform (sh:.Int) a b
transform hndl@(Handle fallback mode width _) =
   -- Unfortunately the fallback version of the function
   -- needs to be wrapped in 'interleave' and 'deinterleave'
   -- to match the data layout as expected by the foreign version.
   -- Fusion might remove redundant transformations.
   -- The optimal solution is to make the backend explicit in the type,
   -- which allows us to declare back-end specific functions
   -- without a fall-back implementation.
   wrap mode (A.constant width) $
   A.foreignAcc
      (AF.CUDAForeignAcc "transformForeign" $ const $ transformForeign hndl)
      (unwrap mode (A.constant width) fallback)

-- | Complex-to-complex modes for the forward and the inverse direction.
-- The precision is chosen by the element type via 'RC.switch'
-- (single precision 'modeC2CFloat', double precision 'modeC2CDouble').
forwardComplex, inverseComplex ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e (Complex e) (Complex e)
forwardComplex =
   getModeC2C
      (RC.switch (modeC2CFloat forwardSign) (modeC2CDouble forwardSign))
inverseComplex =
   getModeC2C
      (RC.switch (modeC2CFloat inverseSign) (modeC2CDouble inverseSign))

{- |
Real-to-complex forward mode.

In contrast to plain CUFFT functions, the redundant half of the spectrum is kept.
That is, an array of shape @sh@ is transformed to an array of shape @sh@.
This way, all dimensions of an array are handled the same way.
Chances are good
that the internal post-processing is fused with subsequent array operations
and thus the redundant data will not be stored in a manifest array.
-}
forwardReal ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode (sh:.Int) e e (Complex e)
forwardReal =
   getModeR2C $
   -- choose the CUFFT kind by precision: R2C for Float, D2Z for Double
   RC.switch
      (modeR2C CUFFT.R2C CUFFT.execR2C)
      (modeR2C CUFFT.D2Z CUFFT.execD2Z)

{- |
Complex-to-real inverse mode, the counterpart of 'forwardReal'.
An array of shape @sh@ is transformed to an array of shape @sh@.
On the CUFFT path only the first @div width 2 + 1@ coefficients
of the least-significant dimension are used (see 'modeC2R').
-}
inverseReal ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode (sh:.Int) e (Complex e) e
inverseReal =
   getModeC2R $
   -- choose the CUFFT kind by precision: C2R for Float, Z2D for Double
   RC.switch
      (modeC2R CUFFT.C2R CUFFT.execC2R)
      (modeC2R CUFFT.Z2D CUFFT.execZ2D)

-- | Precision-independent kind of a transform:
-- real-to-complex, complex-to-real or complex-to-complex.
data Types
   = R2C
   | C2R
   | C2C
   deriving (Eq, Ord, Enum, Show)

{- |
Everything that distinguishes the three kinds of CUFFT transforms
(real-to-complex, complex-to-real, complex-to-complex).

@e@ is the real element type the device pointers refer to,
@a@ and @b@ are the element types of the user-facing input and output arrays.
-}
data Mode sh e a b =
   Mode {
      -- | CUFFT data type the plan is created for (see 'makeHandle').
      types :: CUFFT.Type,
      -- | Precision-independent transform kind;
      -- 'transformForeign' uses it to compute the output shape.
      plainTypes :: Types,
      -- | Run a CUFFT plan from a device input pointer
      -- to a device output pointer.
      execute :: CUFFT.Handle -> CUDA.DevicePtr e -> CUDA.DevicePtr e -> IO (),
      -- | Turn a transform on CUFFT's data layout
      -- (an extra least-significant dimension holding the components)
      -- into the user-facing transform.
      -- The 'Exp' 'Int' argument is the transform width.
      wrap :: Exp Int -> Fourier.Transform (sh:.Int) e -> Transform sh a b,
      -- | Counterpart of 'wrap': adapt a user-facing transform
      -- to CUFFT's data layout.
      unwrap :: Exp Int -> Transform sh a b -> Fourier.Transform (sh:.Int) e,
      -- | Adapt a complex-to-complex transform
      -- to the element types @a@ and @b@.
      wrapFallback :: Fourier.Transform sh (Complex e) -> Transform sh a b,
      -- | Transform direction for the accelerate-fourier fallback.
      fsign :: Fourier.Sign e
   }

-- | Complex-to-complex 'Mode' with the element type last,
-- so that it can be selected per precision with 'RC.switch'.
newtype
   ModeC2C sh e =
      ModeC2C {getModeC2C :: Mode sh e (Complex e) (Complex e)}

-- | Real-to-complex 'Mode' with the element type last,
-- so that it can be selected per precision with 'RC.switch'.
newtype
   ModeR2C sh e =
      ModeR2C {getModeR2C :: Mode sh e e (Complex e)}

-- | Complex-to-real 'Mode' with the element type last,
-- so that it can be selected per precision with 'RC.switch'.
newtype
   ModeC2R sh e =
      ModeC2R {getModeC2R :: Mode sh e (Complex e) e}

-- | Shape of the CUFFT exec functions without a direction argument,
-- as used for the real-to-complex and complex-to-real kinds:
-- plan, device input pointer, device output pointer.
type Execute x = CUFFT.Handle -> DevicePtr x -> DevicePtr x -> IO ()

-- | Like 'Execute', but with an additional integer direction flag,
-- as taken by the complex-to-complex exec functions.
type ExecuteSign x = CUFFT.Handle -> DevicePtr x -> DevicePtr x -> Int -> IO ()

{- |
Build a complex-to-complex mode from a CUFFT data type,
the matching exec function and a transform direction.

Complex arrays are interleaved into an extra least-significant dimension
of extent 2 for CUFFT and split again afterwards.
The width argument of 'wrap' and 'unwrap' is not needed here.
-}
modeC2C ::
   (Shape sh, Slice sh, RC.Real e, Elt e) =>
   CUFFT.Type -> ExecuteSign e -> Sign e -> ModeC2C sh e
modeC2C typ exec (isign,fsign0) =
   ModeC2C $
   Mode {
      types = typ,
      -- the C2C exec functions take the direction flag as an extra argument
      execute = \hndl iptr optr -> exec hndl iptr optr isign,
      plainTypes = C2C,
      wrap = \ _width f -> deinterleave . f . interleave,
      unwrap = \ _width f -> interleave . f . deinterleave,
      wrapFallback = id,
      fsign = fsign0
   }

-- | Complex-to-complex mode in single precision ('CUFFT.C2C').
modeC2CFloat :: (Shape sh, Slice sh) => Sign Float -> ModeC2C sh Float
modeC2CFloat sign = modeC2C CUFFT.C2C CUFFT.execC2C sign

-- | Complex-to-complex mode in double precision ('CUFFT.Z2Z').
modeC2CDouble :: (Shape sh, Slice sh) => Sign Double -> ModeC2C sh Double
modeC2CDouble sign = modeC2C CUFFT.Z2Z CUFFT.execZ2Z sign

{- |
Build a real-to-complex (forward) mode from a CUFFT data type
and the matching exec function.

The real input gets an extra least-significant dimension of extent 1
for CUFFT. The half spectrum returned by CUFFT is deinterleaved and
mirrored back to the full width.

The fallback implementation is inefficient
because it does not exploit the symmetries of real input.
However, it works generally for all dimensions
and also for odd data set sizes.
-}
modeR2C ::
   (Shape sh, Slice sh, RC.Real e, Elt e) =>
   CUFFT.Type -> Execute e -> ModeR2C (sh:.Int) e
modeR2C typ exec =
   ModeR2C $
   Mode {
      types = typ,
      execute = exec,
      plainTypes = R2C,
      wrap = \width f -> mirror width . deinterleave . f . addDim,
      unwrap = \width f -> interleave . takeHalf width . f . removeDim,
      -- the fallback is a complex transform: lift real input to complex
      wrapFallback = (. A.map (Exp.modify expr (:+0))),
      fsign = Fourier.forward
   }

{- |
Build a complex-to-real (inverse) mode from a CUFFT data type
and the matching exec function.

Only the first @div width 2 + 1@ coefficients of the least-significant
dimension are interleaved and handed to CUFFT. The real result has an
extra least-significant dimension of extent 1, which is removed again.
-}
modeC2R ::
   (Shape sh, Slice sh, RC.Real e, Elt e) =>
   CUFFT.Type -> Execute e -> ModeC2R (sh:.Int) e
modeC2R typ exec =
   ModeC2R $
   Mode {
      types = typ,
      execute = exec,
      plainTypes = C2R,
      wrap = \width f -> removeDim . f . interleave . takeHalf width,
      unwrap = \width f -> addDim . f . mirror width . deinterleave,
      -- the fallback is a complex transform: keep only the real part
      wrapFallback = (A.map real .),
      fsign = Fourier.inverse
   }

{- |
CUDA implementation of 'transform': allocate the output array and run
the CUFFT plan stored in the handle on the device pointers of input
and output.

Arrays are in CUFFT's layout: the least-significant dimension has extent 2
(real and imaginary part) for complex data and extent 1 for real data.
The logical width is taken from the handle, not from the input shape.
-}
transformForeign ::
   (Shape sh, Elt e, RC.Real e) =>
   Handle (sh:.Int) e a b ->
   Array (sh:.Int:.Int) e -> AF.CIO (Array (sh:.Int:.Int) e)
transformForeign (Handle _ mode width hndl) input = do
   let (shape :. _width :. _tupleSize) = A.arrayShape input
       outputSh =
          case plainTypes mode of
             -- real-to-complex output only holds the non-redundant coefficients
             R2C -> shape :. div width 2 + 1 :. 2
             C2R -> shape :. width :. 1
             C2C -> shape :. width :. 2
   output <- AF.allocateArray outputSh
   iptr   <- getDevicePtr input
   optr   <- getDevicePtr output
   AF.liftIO $ execute mode hndl iptr optr
   return output

-- | Device pointer extraction with the element type last,
-- so that it can be selected per precision with 'RC.switch'.
newtype
   GetDevicePtr sh e =
      GetDevicePtr {
         runGetDevicePtr :: Array sh e -> AF.CIO ((), CUDA.DevicePtr e)
      }

-- | Fetch the device pointer of a single- or double-precision array.
-- For these element types 'AF.devicePtrsOfArray' yields a pair
-- @((), pointer)@, hence the 'snd'.
getDevicePtr ::
   (RC.Real e) =>
   Array sh e -> AF.CIO (CUDA.DevicePtr e)
getDevicePtr =
   fmap snd .
   runGetDevicePtr
      (RC.switch
         (GetDevicePtr AF.devicePtrsOfArray)
         (GetDevicePtr AF.devicePtrsOfArray))

{-
The rule "deinterleave/interleave", i.e. interleave (deinterleave x) = x,
may change the result (e.g. turn a bottom into the identity)
if the input array does not have extent 2 in the least-significant dimension.
Likewise, "removeDim/addDim", i.e. addDim (removeDim x) = x,
assumes extent 1 in the least-significant dimension,
since the reshape in 'removeDim' requires matching sizes.
The other two rules always hold.
The rules are only safe for the usage in this module.
-}
{-# RULES
  "interleave/deinterleave" forall x. deinterleave (interleave x) = x;
  "deinterleave/interleave" forall x. interleave (deinterleave x) = x;
  "addDim/removeDim" forall x. removeDim (addDim x) = x;
  "removeDim/addDim" forall x. addDim (removeDim x) = x;
  #-}

-- kept from inlining before phase 1 so the interleave/deinterleave rules can match
{-# NOINLINE[1] interleave #-}
{- |
Imitate cuComplex types by interleaving real and imaginary components.
Adds a least-significant dimension of extent 2:
index 0 holds the real part, index 1 the imaginary part.
-}
interleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (Complex a)) -> Acc (Array (sh:.Int) a)
interleave arr =
   A.generate
      (A.lift $ A.shape arr :. (2::Int))
      (\ix ->
         let x = arr ! A.indexTail ix
         in  A.indexHead ix ==* 0 ? (real x, imag x))

{-# NOINLINE[1] deinterleave #-}
-- | Inverse of 'interleave': entries 0 and 1 of the least-significant
-- dimension become the real and imaginary part of a complex number,
-- and that dimension is removed.
deinterleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh (Complex a))
deinterleave arr =
   A.generate (A.indexTail (A.shape arr)) $ \ix ->
      A.lift (component ix 0 :+ component ix 1)
   where
      component ix n = arr ! A.lift (ix :. (n::Int))

{-# NOINLINE[1] addDim #-}
-- | Append a least-significant dimension of extent 1.
-- This is a reshape; the elements are not touched.
addDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh a) -> Acc (Array (sh:.Int) a)
addDim arr = A.reshape withUnitDim arr
   where withUnitDim = A.lift (A.shape arr :. (1::Int))

{-# NOINLINE[1] removeDim #-}
-- | Drop the least-significant dimension by reshaping to the remaining
-- shape; only meaningful when that dimension has extent 1.
removeDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh a)
removeDim arr = A.reshape innerShape arr
   where innerShape = A.indexTail (A.shape arr)

-- | Keep the first @width \`div\` 2 + 1@ entries, the number of coefficients
-- CUFFT's real transforms work with (presumably along the
-- least-significant dimension, as done by 'Sliced.take').
takeHalf ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Fourier.Transform (sh:.Int) a
takeHalf width = Sliced.take (width `div` 2 + 1)

{- |
Extend the least-significant dimension to @newWidth@ entries.
Positions inside the input are copied; a position @k@ beyond the input
width takes the complex conjugate of entry @newWidth - k@.
This restores the redundant half of a real signal's spectrum.
-}
mirror ::
   (Shape sh, Slice sh, Elt a, A.IsNum a) =>
   Exp Int -> Fourier.Transform (sh:.Int) (Complex a)
mirror newWidth arr =
   A.generate (A.lift $ outer :. newWidth) $
   Exp.modify (expr:.expr) $ \(pos:.k) ->
      let entry = (arr !) . Exp.indexCons pos
      in  k <* oldWidth ? (entry k, conjugate (entry (newWidth - k)))
   where
      (outer :. oldWidth) = Exp.unlift (expr:.expr) $ A.shape arr