{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.Math.FFT.LLVM.PTX
-- Copyright   : [2017] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Foreign implementations of the 1D, 2D and 3D FFT for the PTX backend, which
-- call into the cuFFT library. Accelerate stores complex arrays unzipped (one
-- array of real parts, one of imaginary parts), so every transform is
-- bracketed by an interleave ('a2c') and a deinterleave ('c2a') kernel.
--

module Data.Array.Accelerate.Math.FFT.LLVM.PTX (

  fft1D,
  fft2D,
  fft3D,

) where

import Data.Array.Accelerate.Math.FFT.Mode
import Data.Array.Accelerate.Math.FFT.Twine

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.PTX.Foreign

import Foreign.CUDA.Ptr                                   ( DevicePtr )
import Foreign.Ptr
import Foreign.Storable

import Foreign.CUDA.Analysis
import qualified Foreign.CUDA.FFT                         as FFT
import qualified Foreign.CUDA.Driver                      as CUDA hiding ( device )
import qualified Foreign.CUDA.Driver.Context              as CUDA ( device )

import Control.Concurrent.MVar
import Control.Exception
import Control.Monad
import Data.Maybe
import Data.Typeable
import System.IO.Unsafe


-- | Foreign implementation of the one-dimensional FFT, registered under the
-- name @"fft1D"@.
--
fft1D :: IsFloating e
      => Mode
      -> ForeignAcc (Vector (Complex e) -> (Vector (Complex e)))
fft1D mode = ForeignAcc "fft1D" $ liftAtoC (cuFFT mode)

-- | Foreign implementation of the two-dimensional FFT, registered under the
-- name @"fft2D"@.
--
fft2D :: IsFloating e
      => Mode
      -> ForeignAcc (Array DIM2 (Complex e) -> (Array DIM2 (Complex e)))
fft2D mode = ForeignAcc "fft2D" $ liftAtoC (cuFFT mode)

-- | Foreign implementation of the three-dimensional FFT, registered under the
-- name @"fft3D"@.
--
fft3D :: IsFloating e
      => Mode
      -> ForeignAcc (Array DIM3 (Complex e) -> (Array DIM3 (Complex e)))
fft3D mode = ForeignAcc "fft3D" $ liftAtoC (cuFFT mode)


-- | Lift an operation on packed (interleaved) arrays to one on unzipped
-- Accelerate complex arrays: interleave with 'a2c', apply the operation, then
-- deinterleave with 'c2a'.
--
-- The right-hand sides of the four alternatives are textually identical; the
-- match on the 'FloatingType' witness is what brings the concrete element
-- type into scope so that the class constraints of 'a2c' and 'c2a' can be
-- solved.
--
liftAtoC
    :: forall sh e. (Shape sh, IsFloating e)
    => (Stream -> Array (sh:.Int) e -> LLVM PTX (Array (sh:.Int) e))
    -> Stream
    -> Array (sh:.Int) (Complex e)
    -> LLVM PTX (Array (sh:.Int) (Complex e))
liftAtoC f s =
  case floatingType :: FloatingType e of
    TypeFloat{}   -> c2a s <=< f s <=< a2c s
    TypeDouble{}  -> c2a s <=< f s <=< a2c s
    TypeCFloat{}  -> c2a s <=< f s <=< a2c s
    TypeCDouble{} -> c2a s <=< f s <=< a2c s


-- | Call the cuFFT library to execute the FFT (inplace)
--
-- The input is a packed array: the innermost dimension holds interleaved
-- (real, imaginary) pairs, so the transform extent in that dimension is half
-- of the array extent. The same array is returned after being overwritten.
--
-- An array with no elements is returned unchanged without creating a plan,
-- since there is nothing to transform.
--
cuFFT :: forall sh e. (Shape sh, IsFloating e)
      => Mode
      -> Stream
      -> Array (sh:.Int) e
      -> LLVM PTX (Array (sh:.Int) e)
cuFFT mode stream arr
  | size (shape arr) == 0 = return arr
  | otherwise             =
      withScalarArrayPtr arr stream $ \d_arr ->
        liftIO $
          withLifetime stream $ \st -> do
            let sh :. sz = shape arr
            p <- plan (sh :. sz `quot` 2) (undefined::e)  -- recall this is an array of packed (Vec2 e)
            FFT.setStream p st
            case floatingType :: FloatingType e of
              TypeFloat{}   -> FFT.execC2C p d_arr d_arr (signOfMode mode) >> return arr
              TypeDouble{}  -> FFT.execZ2Z p d_arr d_arr (signOfMode mode) >> return arr
              TypeCFloat{}  -> FFT.execC2C p d_arr d_arr (signOfMode mode) >> return arr
              TypeCDouble{} -> FFT.execZ2Z p d_arr d_arr (signOfMode mode) >> return arr


-- | Convert an unzipped Accelerate array of complex numbers into a (new) packed
-- array suitable for use with CUFFT.
--
-- The result has twice the innermost extent of the input. The conversion is
-- done on the device by the @"interleave"@ kernel of the 'twine' module, which
-- is launched on the given stream. No kernel is launched for an empty input;
-- the freshly allocated (empty) result is returned directly.
--
a2c :: forall sh e. (Shape sh, Elt e, IsFloating e, Storable (DevicePtrs e))
    => Stream
    -> Array (sh:.Int) (Complex e)
    -> LLVM PTX (Array (sh:.Int) e)   -- this is really a packed array of (Vec2 e) type
a2c stream arr
  | FloatingDict <- floatingDict (floatingType :: FloatingType e)
  = do
    let
        sh :. sz  = shape arr
        n         = size sh * sz        -- number of complex elements
    --
    cs <- allocateRemote (sh :. 2*sz)
    -- A launch with zero blocks is rejected by the driver, so skip it entirely
    -- when there is nothing to convert.
    when (n > 0) $
      withComplexArrayPtrs arr stream $ \d_re d_im ->
        withScalarArrayPtr cs stream  $ \d_cs      -> liftIO $
          withLifetime stream         $ \st        -> do
            mdl   <- twine (sizeOf (undefined::e))
            pack  <- CUDA.getFun mdl "interleave"
            dev   <- CUDA.device
            prp   <- CUDA.props dev
            regs  <- CUDA.requires pack CUDA.NumRegs
            let
                blockSize   = 256
                sharedMem   = 0
                maxBlocks   = maxResidentBlocks prp blockSize regs sharedMem
                -- enough blocks to cover n elements, capped at maxBlocks
                numBlocks   = maxBlocks `min` ((n + blockSize - 1) `div` blockSize)
            --
            CUDA.launchKernel pack (numBlocks,1,1) (blockSize,1,1) sharedMem (Just st)
              [ CUDA.VArg d_cs, CUDA.VArg d_re, CUDA.VArg d_im, CUDA.IArg (fromIntegral n) ]
    --
    return cs


-- | Convert a packed array of complex numbers into a (new) unzipped Accelerate
-- array.
--
-- The result has half the innermost extent of the input. The conversion is
-- done on the device by the @"deinterleave"@ kernel of the 'twine' module,
-- which is launched on the given stream. No kernel is launched for an empty
-- input; the freshly allocated (empty) result is returned directly.
--
c2a :: forall sh e. (Shape sh, Elt e, IsFloating e, Storable (DevicePtrs e))
    => Stream
    -> Array (sh:.Int) e
    -> LLVM PTX (Array (sh:.Int) (Complex e))
c2a stream cs
  | FloatingDict <- floatingDict (floatingType :: FloatingType e)
  = do
    let
        sh :. sz2 = shape cs
        sz        = sz2 `quot` 2
        n         = size sh * sz        -- number of complex elements
    --
    arr <- allocateRemote (sh :. sz)
    -- A launch with zero blocks is rejected by the driver, so skip it entirely
    -- when there is nothing to convert.
    when (n > 0) $
      withComplexArrayPtrs arr stream $ \d_re d_im ->
        withScalarArrayPtr cs stream  $ \d_cs      -> liftIO $
          withLifetime stream         $ \st        -> do
            mdl     <- twine (sizeOf (undefined::e))
            unpack  <- CUDA.getFun mdl "deinterleave"
            dev     <- CUDA.device
            prp     <- CUDA.props dev
            regs    <- CUDA.requires unpack CUDA.NumRegs
            let
                blockSize   = 256
                sharedMem   = 0
                maxBlocks   = maxResidentBlocks prp blockSize regs sharedMem
                -- enough blocks to cover n elements, capped at maxBlocks
                numBlocks   = maxBlocks `min` ((n + blockSize - 1) `div` blockSize)
            --
            CUDA.launchKernel unpack (numBlocks,1,1) (blockSize,1,1) sharedMem (Just st)
              [ CUDA.VArg d_re, CUDA.VArg d_im, CUDA.VArg d_cs, CUDA.IArg (fromIntegral n) ]
    --
    return arr


-- | Generate an execute plan for a given type and size of FFT. These plans are
-- cached so that subsequent invocations are quicker.
--
-- The cache key includes the CUDA context that is current when the plan is
-- requested, mirroring the key of 'ptx_twine_modules', so that a plan created
-- under one context is never handed out while a different one is current.
--
-- The second argument is used only as a type proxy and is never evaluated.
-- Calls 'error' if the shape is not of rank 1, 2 or 3, or if no CUDA context
-- is current.
--
plan :: forall sh e. (Shape sh, IsFloating e) => sh -> e -> IO FFT.Handle
plan (shapeToList -> sh) _ = do
  ctx <- fromMaybe (error "could not determine current CUDA context") `fmap` CUDA.get
  let key = (ctx, ty, sh)
  modifyMVar fft_plans $ \ps ->
    case lookup key ps of
      Just p  -> return (ps, p)
      Nothing -> do
        p <- case sh of
               [w]     -> FFT.plan1D     w ty 1
               [w,h]   -> FFT.plan2D   h w ty
               [w,h,d] -> FFT.plan3D d h w ty
               _       -> error "cuFFT only supports 1D, 2D, and 3D transforms"
        return ((key,p) : ps, p)
  where
    ty = case floatingType :: FloatingType e of
           TypeFloat{}   -> FFT.C2C
           TypeDouble{}  -> FFT.Z2Z
           TypeCFloat{}  -> FFT.C2C
           TypeCDouble{} -> FFT.Z2Z


-- | Load the module to convert between SoA and AoS representation for the given
-- type. This is cached for subsequent reuse.
--
-- Despite its name, the argument is the element size in /bytes/, as returned
-- by 'sizeOf': 4 selects the single-precision module and 8 the
-- double-precision one. Modules are cached per (size, current context) pair.
-- Calls 'error' for any other size, or if no CUDA context is current.
--
twine :: Int -> IO CUDA.Module
twine bitsize = do
  ctx <- fromMaybe (error "could not determine current CUDA context") `fmap` CUDA.get
  modifyMVar ptx_twine_modules $ \ms -> do
    case lookup (bitsize,ctx) ms of
      Just m  -> return (ms, m)
      Nothing -> do
        m <- CUDA.loadData $ case bitsize of
                               4 -> ptx_twine_f32
                               8 -> ptx_twine_f64
                               _ -> error "cuFFT only supports Float and Double"
        return (((bitsize,ctx), m) : ms, m)


-- | Dig out the two device pointers for an unzipped array of complex numbers.
--
-- The continuation receives the pointer to the first component array and then
-- the pointer to the second, in the order they are stored in the array data
-- (callers in this module treat them as the real and imaginary parts
-- respectively).
--
withComplexArrayPtrs
    :: forall sh e a. IsFloating e
    => Array sh (Complex e)
    -> Stream
    -> (DevicePtrs e -> DevicePtrs e -> LLVM PTX a)
    -> LLVM PTX a
withComplexArrayPtrs (Array _ adata) st k
  | AD_Pair (AD_Pair AD_Unit ad1) ad2 <- adata
  = case floatingType :: FloatingType e of
      TypeFloat{}   -> withArrayData arrayElt ad1 st $ \p1 -> withArrayData arrayElt ad2 st $ \p2 -> k p1 p2
      TypeDouble{}  -> withArrayData arrayElt ad1 st $ \p1 -> withArrayData arrayElt ad2 st $ \p2 -> k p1 p2
      TypeCDouble{} -> withArrayData arrayElt ad1 st $ \p1 -> withArrayData arrayElt ad2 st $ \p2 -> k p1 p2
      TypeCFloat{}  -> withArrayData arrayElt ad1 st $ \p1 -> withArrayData arrayElt ad2 st $ \p2 -> k p1 p2


-- | Dig out the device pointer for a scalar array
--
withScalarArrayPtr
    :: forall sh e a. IsFloating e
    => Array sh e
    -> Stream
    -> (DevicePtrs e -> LLVM PTX a)
    -> LLVM PTX a
withScalarArrayPtr (Array _ ad) st k
  = case floatingType :: FloatingType e of
      TypeFloat{}   -> withArrayData arrayElt ad st $ \p -> k p
      TypeDouble{}  -> withArrayData arrayElt ad st $ \p -> k p
      TypeCDouble{} -> withArrayData arrayElt ad st $ \p -> k p
      TypeCFloat{}  -> withArrayData arrayElt ad st $ \p -> k p


-- | Run a continuation with the device pointer of a single primitive array.
--
-- After the continuation finishes, a checkpoint is taken on the stream and
-- returned to 'withDevicePtr' alongside the continuation's result. The
-- 'ArrayEltR' argument is ignored at the value level; it only helps fix the
-- element type at the call site.
--
withArrayData
    :: (Typeable e, Typeable a, ArrayElt e, Storable a, ArrayPtrs e ~ Ptr a)
    => ArrayEltR e
    -> ArrayData e
    -> Stream
    -> (DevicePtr a -> LLVM PTX b)
    -> LLVM PTX b
withArrayData _ ad s k =
  withDevicePtr ad $ \p -> do
    r <- k p
    e <- checkpoint s
    return (Just e,r)


-- | The device pointer type used for arrays of each supported floating point
-- element type. The C types share the pointer type of their Haskell
-- counterparts.
--
type family DevicePtrs e :: *
type instance DevicePtrs Float   = DevicePtr Float
type instance DevicePtrs Double  = DevicePtr Double
type instance DevicePtrs CFloat  = DevicePtr Float
type instance DevicePtrs CDouble = DevicePtr Double


-- Cache the FFT planning step for faster repeat evaluations.
--
-- Entries are keyed by (context, transform type, extents). The cache is a
-- process-global created with 'unsafePerformIO'; the NOINLINE pragma is
-- required so that only one MVar is ever created. The finalizer destroys each
-- plan with its owning context pushed, in the same way that
-- 'ptx_twine_modules' unloads its modules.
--
{-# NOINLINE fft_plans #-}
fft_plans :: MVar [((CUDA.Context, FFT.Type, [Int]), FFT.Handle)]
fft_plans = unsafePerformIO $ do
  mv <- newMVar []
  _  <- mkWeakMVar mv
      $ withMVar mv
      $ mapM_ (\((ctx,_,_),p) -> bracket_ (CUDA.push ctx) CUDA.pop (FFT.destroy p))
  return mv

-- Cache the functions which convert between SoA and AoS format.
--
-- Entries are keyed by (element size in bytes, context). The cache is a
-- process-global created with 'unsafePerformIO'; the NOINLINE pragma is
-- required so that only one MVar is ever created. The finalizer unloads each
-- module with its owning context pushed.
--
{-# NOINLINE ptx_twine_modules #-}
ptx_twine_modules :: MVar [((Int, CUDA.Context), CUDA.Module)]
ptx_twine_modules = unsafePerformIO $ do
  mv <- newMVar []
  _  <- mkWeakMVar mv
      $ withMVar mv
      $ mapM_ (\((_,ctx),mdl) -> bracket_ (CUDA.push ctx) CUDA.pop (CUDA.unload mdl))
  return mv