{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
Transformations of collections of datasets.
-}
module Data.Array.Accelerate.CUFFT.Batched (
   Priv.Transform,
   Priv.transform,

   Handle,
   plan1D,
   plan2D,
   plan3D,

   RC.Real,
   Mode,
   Priv.forwardComplex, Priv.inverseComplex,
   Priv.forwardReal, Priv.inverseReal,
   Batch0, Batch1, Batch2, Batch3,
   ) where

import qualified Data.Array.Accelerate.CUFFT.Private as Priv
import Data.Array.Accelerate.CUFFT.Private
          (Batch0, Batch1, Batch2, Batch3,
           Mode, wrapFallback, Handle, makeHandle, )

import qualified Data.Array.Accelerate.CUFFT.RealClass as RC

import qualified Data.Array.Accelerate.Fourier.Planned as Fourier

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Elt, Shape, Slice, (:.)((:.)), )

import qualified Foreign.CUDA.FFT as CUFFT


{- |
Create a handle for transforming a collection of one-dimensional datasets.
The innermost dimension of the shape is the length of a single dataset;
the total number of elements of the remaining outer shape
is handed to 'CUFFT.planMany' as its last argument.

The plan must be created in the 'Data.Array.Accelerate.CUDA.Context'
where 'Priv.transform' is executed.
E.g. if you run 'Priv.transform' in 'Data.Array.Accelerate.CUDA.run1',
then you must call 'plan1D'
within 'Data.Array.Accelerate.CUDA.Foreign.inDefaultContext'.
-}
plan1D ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   Mode (Batch1 sh) e a b -> Batch1 sh -> IO (Handle (Batch1 sh) e a b)
plan1D mode (outer:.n) =
   makeHandle mode n
      -- transform from "Data.Array.Accelerate.Fourier.Planned",
      -- adapted to the requested mode by 'wrapFallback'
      (wrapFallback mode . flip Fourier.transform n)
      cufftPlan
   where
      -- CUFFT plan over the single dimension of length n
      cufftPlan typ =
         CUFFT.planMany [n] Nothing Nothing typ (A.arraySize outer)

{- |
Create a handle for transforming a collection of two-dimensional datasets.
The two innermost dimensions of the shape are the extents of a single dataset;
the total number of elements of the remaining outer shape
is handed to 'CUFFT.planMany' as its last argument.
-}
plan2D ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   Mode (Batch2 sh) e a b -> Batch2 sh -> IO (Handle (Batch2 sh) e a b)
plan2D mode shape@(outer:.ny:.nx) =
   makeHandle mode nx
      -- 'Priv.transform2D' on the full shape,
      -- adapted to the requested mode by 'wrapFallback'
      (\sign -> wrapFallback mode (Priv.transform2D shape sign))
      cufftPlan
   where
      -- CUFFT plan over the two innermost dimensions
      cufftPlan typ =
         CUFFT.planMany [ny,nx] Nothing Nothing typ (A.arraySize outer)

{- |
Create a handle for transforming a collection of three-dimensional datasets.
The three innermost dimensions of the shape
are the extents of a single dataset;
the total number of elements of the remaining outer shape
is handed to 'CUFFT.planMany' as its last argument.
-}
plan3D ::
   (Shape sh, Slice sh, Elt e, RC.Real e) =>
   Mode (Batch3 sh) e a b ->
   Batch3 sh -> IO (Handle (Batch3 sh) e a b)
plan3D mode shape@(outer:.nz:.ny:.nx) =
   makeHandle mode nx
      -- 'Priv.transform3D' on the full shape,
      -- adapted to the requested mode by 'wrapFallback'
      (\sign -> wrapFallback mode (Priv.transform3D shape sign))
      cufftPlan
   where
      -- CUFFT plan over the three innermost dimensions
      cufftPlan typ =
         CUFFT.planMany [nz,ny,nx] Nothing Nothing typ (A.arraySize outer)