{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Accelerate interface to the native CUDA implementation
of the Fourier Transform provided by the CUFFT library.

This module has no export list, so every definition below is exported.
-}
module Data.Array.Accelerate.CUFFT.Private where

import qualified Data.Array.Accelerate.CUFFT.RealClass as RC

import qualified Data.Array.Accelerate.Fourier.Preprocessed as Prep
import qualified Data.Array.Accelerate.Fourier.Planned as Fourier
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import Data.Array.Accelerate.Utility.Lift.Exp (expr)

import qualified Data.Array.Accelerate.LLVM.PTX.Foreign as AF
import qualified Data.Array.Accelerate.LLVM.PTX as PTX
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex (Complex((:+)), real, imag, conjugate)
import Data.Array.Accelerate.Lifetime (withLifetime)
import Data.Array.Accelerate (Acc, Array, Elt, Shape, Slice, (:.)((:.)), Exp, (!), (?))

import qualified Foreign.CUDA.FFT as CUFFT
import qualified Foreign.CUDA.Driver as CUDA
import Foreign.CUDA.Ptr (DevicePtr)

import qualified System.Mem.Weak as Weak
import Control.Exception (bracket_)


-- | An Accelerate array transformation that preserves the shape type @sh@.
type Transform sh a b = Acc (Array sh a) -> Acc (Array sh b)

{- |
Direction of a transform in two encodings:
the integer direction passed to the CUFFT complex-to-complex executor
(see 'modeC2C') and the corresponding 'Fourier.Sign'
used by the pure Accelerate fallback.
-}
type Sign a = (Int, Fourier.Sign a)

-- | Forward (@-1@) and inverse (@+1@) directions, see 'Sign'.
forwardSign, inverseSign :: Num a => Sign a
forwardSign = (-1, Fourier.forward)
inverseSign = ( 1, Fourier.inverse)


{- |
A prepared transform: everything 'transform' needs to run
either the CUFFT plan or the pure Accelerate fallback.
-}
data Handle sh e a b =
   Handle
      (Transform sh a b) -- ^ pure Accelerate fallback transform
      (Mode sh e a b)    -- ^ transform kind (C2C, R2C or C2R) and its adaptors
      Int                -- ^ length of the transformed, least-significant dimension
      CUFFT.Handle       -- ^ the CUFFT plan

{- |
Create a CUFFT plan via @planner@ (given the mode's 'CUFFT.Type')
and bundle it with the fallback transform for the mode's 'fsign'.
-}
makeHandle ::
   (Shape sh, Slice sh, RC.Real e) =>
   Mode sh e a b -> Int ->
   (Fourier.Sign e -> Transform sh a b) ->
   (CUFFT.Type -> IO CUFFT.Handle) ->
   IO (Handle sh e a b)
makeHandle mode width fallback planner = do
   plan <- planner $ types mode
   -- The plan is released by a GC finalizer rather than explicitly.
   -- NOTE(review): the System.Mem.Weak documentation warns that finalizers
   -- attached to ordinary (non-primitive) Haskell values are fragile,
   -- because the compiler may duplicate or unbox them. If 'CUFFT.Handle'
   -- is a plain newtype, the plan may be destroyed while the 'Handle'
   -- below still refers to it. Also no CUDA context is pushed around
   -- 'CUFFT.destroy' here. Consider a ForeignPtr-like key or explicit
   -- destruction — confirm.
   Weak.addFinalizer plan (CUFFT.destroy plan)
   return $ Handle (fallback $ fsign mode) mode width plan

{- |
Initialise the CUDA driver and create a PTX target for a device.

NOTE(review): despite the name, this always uses device 0.
The best-device selection is commented out below.
-}
getBestTarget :: IO AF.PTX
getBestTarget = do
   CUDA.initialise []
   -- (dev,prop) <- PTX.selectBestDevice
   dev <- CUDA.device 0
   prop <- CUDA.props dev
   PTX.createTargetForDevice dev prop [CUDA.SchedAuto]

{- |
Run an IO action with the target's CUDA context pushed as the current
context. The context is popped afterwards, also if the action throws
('bracket_').
-}
atTarget :: AF.PTX -> IO a -> IO a
atTarget target act =
   withLifetime (AF.deviceContext $ AF.ptxContext target) $ \ctx ->
   bracket_ (CUDA.push ctx) CUDA.pop act


-- | Shapes with 0 to 3 extra least-significant dimensions appended to @sh@.
type Batch0 sh = sh
type Batch1 sh = Batch0 sh :. Int
type Batch2 sh = Batch1 sh :. Int
type Batch3 sh = Batch2 sh :. Int

{- |
Pure Accelerate 2D transform over the two least-significant dimensions,
built from 1D transforms of the given @width@ and @height@.
The leading shape is ignored.
-}
transform2D ::
   (Shape sh, Slice sh, RC.Real a) =>
   Batch2 sh -> Fourier.Sign a -> Fourier.Transform (Batch2 sh) (Complex a)
transform2D (_shape:.height:.width) sign =
   Prep.transform2d $
   Prep.SubTransformPair
      (Fourier.transform sign width)
      (Fourier.transform sign height)

{- |
Pure Accelerate 3D transform over the three least-significant dimensions,
built from 1D transforms of the given @width@, @height@ and @depth@.
The leading shape is ignored.
-}
transform3D ::
   (Shape sh, Slice sh, RC.Real a) =>
   Batch3 sh -> Fourier.Sign a -> Fourier.Transform (Batch3 sh) (Complex a)
transform3D (_shape:.depth:.height:.width) sign =
   Prep.transform3d $
   Prep.SubTransformTriple
      (Fourier.transform sign width)
      (Fourier.transform sign height)
      (Fourier.transform sign depth)

{- |
The implementation works on all arrays of rank less than or equal to 3.
The result is un-normalised.

The transform is issued via 'A.foreignAcc'.
A backend that knows the foreign function runs 'transformForeign'
(the CUFFT plan). Otherwise the handle's Accelerate fallback is used.
-}
transform ::
   (Shape sh, Slice sh, RC.Real e) =>
   Handle (sh:.Int) e a b -> Transform (sh:.Int) a b
transform hndl@(Handle fallback mode width _) =
   {-
   Unfortunately the fallback version of the function needs to be wrapped
   in 'interleave' and 'deinterleave' to match the data layout
   as expected by the foreign version.
   Fusion might remove redundant transformations.
   The optimal solution is to make the backend explicit in the type,
   which allows us to declare back-end specific functions
   without a fall-back implementation.
   -}
   wrap mode (A.constant width) $
   A.foreignAcc
      (AF.ForeignAcc "transformForeign" $ transformForeign hndl)
      (unwrap mode (A.constant width) fallback)


-- | Complex-to-complex modes, dispatched on the element type via 'RC.switch'.
forwardComplex, inverseComplex ::
   (Shape sh, Slice sh, RC.Real e) => Mode sh e (Complex e) (Complex e)
forwardComplex =
   getModeC2C $
   RC.switch (modeC2CFloat forwardSign) (modeC2CDouble forwardSign)
inverseComplex =
   getModeC2C $
   RC.switch (modeC2CFloat inverseSign) (modeC2CDouble inverseSign)

{- |
In contrast to plain CUFFT functions, the data here is redundant:
the full spectrum is produced (forward) or consumed (inverse),
not only the non-redundant half.
That is, an array of shape @sh@ is transformed to an array of shape @sh@.
This way, all dimensions of an array are handled the same way.
Chances are good that the internal post processing is fused
with following array operations
and thus the redundant data will not be stored in a manifest array.
-}
forwardReal ::
   (Shape sh, Slice sh, RC.Real e) => Mode (sh:.Int) e e (Complex e)
forwardReal =
   getModeR2C $
   RC.switch
      (modeR2C CUFFT.R2C CUFFT.execR2C)
      (modeR2C CUFFT.D2Z CUFFT.execD2Z)

-- | Complex-to-real (inverse) mode. See 'forwardReal' about redundant data.
inverseReal ::
   (Shape sh, Slice sh, RC.Real e) => Mode (sh:.Int) e (Complex e) e
inverseReal =
   getModeC2R $
   RC.switch
      (modeC2R CUFFT.C2R CUFFT.execC2R)
      (modeC2R CUFFT.Z2D CUFFT.execZ2D)


-- | Precision-independent kind of a transform, used to size the output.
data Types = R2C | C2R | C2C
   deriving (Eq, Ord, Enum, Show)

{- |
Description of a transform kind.

@e@ is the real element type. @a@ and @b@ are the user-facing input
and output element types.
The foreign function works on arrays of @e@ with one extra
least-significant dimension: extent 2 for interleaved complex numbers
and extent 1 for real numbers.
-}
data Mode sh e a b =
   Mode {
      -- | plan type handed to the planner in 'makeHandle'
      types :: CUFFT.Type,
      -- | selects the output shape in 'transformForeign'
      plainTypes :: Types,
      -- | run the plan from the input to the output device pointer
      execute ::
         CUFFT.Handle ->
         CUDA.DevicePtr (Sugar.EltRepr e) ->
         CUDA.DevicePtr (Sugar.EltRepr e) -> IO (),
      -- | turn a transform on the foreign layout into a user-facing one,
      -- given the transform width
      wrap :: Exp Int -> Fourier.Transform (sh:.Int) e -> Transform sh a b,
      -- | turn a user-facing transform into one on the foreign layout,
      -- given the transform width
      unwrap :: Exp Int -> Transform sh a b -> Fourier.Transform (sh:.Int) e,
      -- | adapt a complex-to-complex transform to this mode's types
      -- (not used in this module)
      wrapFallback :: Fourier.Transform sh (Complex e) -> Transform sh a b,
      -- | transform direction used for the fallback in 'makeHandle'
      fsign :: Fourier.Sign e
   }

-- The newtypes below put the element type @e@ last,
-- presumably so that 'RC.switch' can pick the Float or Double variant.
newtype ModeC2C sh e =
   ModeC2C {getModeC2C :: Mode sh e (Complex e) (Complex e)}
newtype ModeR2C sh e =
   ModeR2C {getModeR2C :: Mode sh e e (Complex e)}
newtype ModeC2R sh e =
   ModeC2R {getModeC2R :: Mode sh e (Complex e) e}

-- | Signature of the CUFFT executors for real transforms.
type Execute e =
   CUFFT.Handle -> DevicePtr e -> DevicePtr e -> IO ()
-- | Signature of the CUFFT executors for complex transforms;
-- the 'Int' is the transform direction.
type ExecuteSign e =
   CUFFT.Handle -> DevicePtr e -> DevicePtr e -> Int -> IO ()

{- |
Complex-to-complex mode.
The complex arrays are interleaved into a trailing dimension of extent 2
for the foreign call and deinterleaved afterwards.
The integer part of the 'Sign' is passed to the executor.
-}
modeC2C ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> ExecuteSign (Sugar.EltRepr e) -> Sign e -> ModeC2C sh e
modeC2C typ exec (isign,fsign0) =
   ModeC2C $
   Mode {
      types = typ,
      execute = \hndl iptr optr -> exec hndl iptr optr isign,
      plainTypes = C2C,
      wrap = \ _width f -> deinterleave . f . interleave,
      unwrap = \ _width f -> interleave . f . deinterleave,
      wrapFallback = id,
      fsign = fsign0
   }

-- | Single precision complex-to-complex mode.
modeC2CFloat ::
   (Shape sh, Slice sh) => Sign Float -> ModeC2C sh Float
modeC2CFloat = modeC2C CUFFT.C2C CUFFT.execC2C

-- | Double precision complex-to-complex mode.
modeC2CDouble ::
   (Shape sh, Slice sh) => Sign Double -> ModeC2C sh Double
modeC2CDouble = modeC2C CUFFT.Z2Z CUFFT.execZ2Z

{-
The fallback implementation is inefficient
because it does not benefit from occurring symmetries.
However, it works generally for all dimensions and also for odd data set sizes.
-}
{- |
Real-to-complex (forward) mode.
The real input gets a trailing dimension of extent 1.
The foreign result holds @div width 2 + 1@ interleaved complex values
per row, which 'mirror' extends back to @width@.
-}
modeR2C ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> Execute (Sugar.EltRepr e) -> ModeR2C (sh:.Int) e
modeR2C typ exec =
   ModeR2C $
   Mode {
      types = typ,
      execute = exec,
      plainTypes = R2C,
      wrap = \width f -> mirror width . deinterleave . f . addDim,
      unwrap = \width f -> interleave . takeHalf width . f . removeDim,
      -- real input is embedded as complex numbers with zero imaginary part
      wrapFallback = (. A.map (Exp.modify expr (:+0))),
      fsign = Fourier.forward
   }

{- |
Complex-to-real (inverse) mode.
Only the first @div width 2 + 1@ complex values per row are handed to
CUFFT. The real result has a trailing dimension of extent 1, which is removed.
-}
modeC2R ::
   (Shape sh, Slice sh, RC.Real e) =>
   CUFFT.Type -> Execute (Sugar.EltRepr e) -> ModeC2R (sh:.Int) e
modeC2R typ exec =
   ModeC2R $
   Mode {
      types = typ,
      execute = exec,
      plainTypes = C2R,
      wrap = \width f -> removeDim . f . interleave . takeHalf width,
      unwrap = \width f -> addDim . f . mirror width . deinterleave,
      -- the fallback's complex result is reduced to its real part
      wrapFallback = (A.map real .),
      fsign = Fourier.inverse
   }


{- |
Foreign implementation behind 'transform'.

The function performs these steps:

* Allocate the output array on the device. Its shape is derived from
  the mode and the handle's width:
  @div width 2 + 1@ complex values for R2C,
  @width@ reals for C2R and @width@ complex values for C2C.
* Run the CUFFT plan on the raw device pointers of input and output.
* Record an event on the stream.

Only the leading shape of the input is used. The input's width is not
checked against the plan — NOTE(review): presumably guaranteed by the caller.
-}
transformForeign ::
   (Shape sh, RC.Real e) =>
   Handle (sh:.Int) e a b ->
   AF.Stream -> Array (sh:.Int:.Int) e -> AF.LLVM AF.PTX (Array (sh:.Int:.Int) e)
transformForeign (Handle _ mode width hndl) stream input = do
   let (shape :. _width :. _tupleSize) = A.arrayShape input
       outputSh =
          case plainTypes mode of
             R2C -> shape :. div width 2 + 1 :. (2::Int)
             C2R -> shape :. width :. 1
             C2C -> shape :. width :. 2
   output <- AF.allocateRemote outputSh
   withDevicePtr input $ \iptr ->
      withDevicePtr output $ \optr -> do
         -- NOTE(review): the plan is not bound to 'stream' here
         -- (no CUFFT setStream call). Confirm that work issued through the
         -- plan is ordered before the event recorded on 'stream' below.
         AF.liftIO $ execute mode hndl iptr optr
         ev <- AF.checkpoint stream
         -- The inner and outer 'withDevicePtr' each consume one
         -- (Maybe Event, result) pair, so the final result is 'output'.
         return (Just ev, (Just ev, output))

-- | Wrapper that lets 'RC.switch' select a 'withDevicePtr' implementation
-- by element type.
newtype WithDevicePtr target r sh e =
   WithDevicePtr {
      runWithDevicePtr ::
         Array sh e ->
         (CUDA.DevicePtr (Sugar.EltRepr e) -> AF.LLVM target (Maybe AF.Event, r)) ->
         AF.LLVM target r
   }

{- |
Give the continuation the raw device pointer of a Float or Double array.
The two branches are textually identical. The switch presumably exists
so that each branch is type-checked at a concrete element type.
-}
withDevicePtr ::
   (RC.Real e) =>
   Array sh e ->
   (CUDA.DevicePtr (Sugar.EltRepr e) -> AF.LLVM AF.PTX (Maybe AF.Event, r)) ->
   AF.LLVM AF.PTX r
withDevicePtr =
   runWithDevicePtr
      (RC.switch
         (WithDevicePtr $ \(Sugar.Array _ dat) -> AF.withDevicePtr dat)
         (WithDevicePtr $ \(Sugar.Array _ dat) -> AF.withDevicePtr dat))


{-
These rules cancel round trips through the layout conversions below.
The conversion functions are NOINLINE[1],
so that they are not inlined before the rules had a chance to match.

The rules removing @deinterleave . interleave@ ("interleave/deinterleave")
and @removeDim . addDim@ ("addDim/removeDim") are exact,
because 'interleave' always creates a least-significant extent of 2
and 'addDim' always creates one of extent 1.
The rules removing @interleave . deinterleave@ ("deinterleave/interleave")
and @addDim . removeDim@ ("removeDim/addDim") may turn a bottom
into the identity, if the input array does not have extent 2 (respectively 1)
at the least-significant dimension.
The rules are only safe for the usage in this module.
-}
{-# RULES
      "interleave/deinterleave" forall x. deinterleave (interleave x) = x;
      "deinterleave/interleave" forall x. interleave (deinterleave x) = x;
      "addDim/removeDim" forall x. removeDim (addDim x) = x;
      "removeDim/addDim" forall x. addDim (removeDim x) = x;
   #-}

{- |
Imitate cuComplex types by interleaving real and imaginary components.
Adds a least-significant dimension of extent 2:
index 0 holds the real part and index 1 the imaginary part.
-}
{-# NOINLINE[1] interleave #-}
interleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh (Complex a)) -> Acc (Array (sh:.Int) a)
interleave arr =
   A.generate (A.lift $ A.shape arr :. (2::Int))
      (\ix ->
         let x = arr ! A.indexTail ix
         in  A.indexHead ix A.== 0 ? (real x, imag x))

{- |
Inverse of 'interleave': combine indices 0 and 1 of the least-significant
dimension into complex numbers and drop that dimension.
Only indices 0 and 1 are read.
-}
{-# NOINLINE[1] deinterleave #-}
deinterleave ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh (Complex a))
deinterleave arr =
   A.generate (A.indexTail $ A.shape arr)
      (\ix ->
         let get n = arr ! A.lift (ix :. (n::Int))
         in  A.lift $ get 0 :+ get 1)

-- | Append a least-significant dimension of extent 1 (a reshape).
{-# NOINLINE[1] addDim #-}
addDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array sh a) -> Acc (Array (sh:.Int) a)
addDim arr = A.reshape (A.lift $ A.shape arr :. (1::Int)) arr

-- | Drop the least-significant dimension (a reshape).
-- Presumably only valid if that dimension has extent 1.
{-# NOINLINE[1] removeDim #-}
removeDim ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) -> Acc (Array sh a)
removeDim arr = A.reshape (A.indexTail $ A.shape arr) arr

-- | Keep the first @div width 2 + 1@ elements of the least-significant
-- dimension, i.e. the non-redundant half of a real signal's spectrum.
takeHalf ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Fourier.Transform (sh:.Int) a
takeHalf width = Sliced.take (div width 2 + 1)

{- |
Extend a half spectrum to @newWidth@ along the least-significant dimension.
Indices @k@ below the input width are copied. Every other @k@ becomes
the conjugate of the element at @newWidth - k@.
Presumably the input width is @div newWidth 2 + 1@, as produced by
'takeHalf' or a CUFFT R2C transform, so that @newWidth - k@ is in range.
-}
mirror ::
   (Shape sh, Slice sh, A.Num a) =>
   Exp Int -> Fourier.Transform (sh:.Int) (Complex a)
mirror newWidth arr =
   let (sh:.width) = Exp.unlift (expr:.expr) $ A.shape arr
   in  A.generate (A.lift $ sh :. newWidth) $
       Exp.modify (expr:.expr) $ \(ix:.k) ->
          k A.< width
          ?
          (arr ! Exp.indexCons ix k,
           conjugate (arr ! Exp.indexCons ix (newWidth - k)))