{-# LANGUAGE TypeFamilies #-}
{- |
Simple manual implementation of embedding cuFFT functionality in the
@accelerate@ framework.

In this example, a single cuFFT plan is created once up front and shared by
every invocation of the foreign function. Plan creation and destruction are
performed with the target's CUDA context pushed (see 'atTarget').
-}
module Main where

import qualified Data.Array.Accelerate.LLVM.PTX.Foreign as AF
import qualified Data.Array.Accelerate.LLVM.PTX as PTX
import qualified Data.Array.Accelerate.Array.Sugar as Sugar

import qualified Foreign.CUDA.FFT as CUFFT
import qualified Foreign.CUDA.Driver as CUDA

import Foreign.Storable (Storable)
import Foreign.Ptr (Ptr)

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Lifetime (withLifetime)
import Data.Array.Accelerate (Acc, Vector, Z(Z), (:.)((:.)))

import Control.Exception (bracket, bracket_)


-- | PTX-only implementation of a 1D real-to-complex FFT using a
-- pre-built cuFFT plan.
--
-- For an input of @n@ real samples, the output vector has length
-- @(n/2 + 1) * 2@ (integer division), i.e. room for @n/2 + 1@ complex
-- coefficients stored as consecutive 'Float' pairs.
--
-- The plan is bound to Accelerate's @stream@ before executing, so the FFT is
-- queued on the same stream on which 'withArray' records its completion
-- event, instead of relying on implicit default-stream synchronisation.
transformForeign
  :: CUFFT.Handle
  -> AF.Stream
  -> Vector Float
  -> AF.LLVM AF.PTX (Vector Float)
transformForeign h stream input = do
  let Z :. inlen = A.arrayShape input
      outlen     = (inlen `div` 2 + 1) * 2
  output <- AF.allocateRemote (Z :. outlen)
  withArray input stream $ \iptr ->
    withArray output stream $ \optr ->
      AF.liftIO $
        withLifetime stream $ \st -> do
          CUFFT.setStream h st
          CUFFT.execR2C h iptr optr
  return output

-- | Run an action with the device pointer backing an Accelerate array.
--
-- After the action completes, an event is recorded on the given stream with
-- 'AF.checkpoint' and returned to 'AF.withDevicePtr' together with the
-- action's result.
withArray
  :: (Sugar.EltRepr e ~ er, AF.ArrayPtrs er ~ Ptr er, Storable er)
  => A.Array sh e
  -> AF.Stream
  -> (CUDA.DevicePtr er -> AF.LLVM AF.PTX r)
  -> AF.LLVM AF.PTX r
withArray (Sugar.Array _ adata) s k =
  AF.withDevicePtr adata $ \p -> do
    r <- k p
    e <- AF.checkpoint s
    return (Just e, r)

-- | Embed the foreign FFT into an Accelerate computation.
--
-- No fallback implementation is provided: the computation can only be run
-- by the PTX backend, and any other backend hits the 'error'.
transform :: CUFFT.Handle -> Acc (Vector Float) -> Acc (Vector Float)
transform h =
  A.foreignAcc
    (AF.ForeignAcc "transformForeign" $ transformForeign h)
    (error "no fft fallback implemented")

-- | Initialise CUDA and build a PTX target for device 0.
--
-- Despite the name, no device selection is performed: device 0 is always
-- used.
getBestTarget :: IO PTX.PTX
getBestTarget = do
  CUDA.initialise []
  dev  <- CUDA.device 0
  prop <- CUDA.props dev
  PTX.createTargetForDevice dev prop [CUDA.SchedAuto]

-- | Run an IO action with the target's CUDA context pushed onto the context
-- stack, popping it again afterwards (also if the action throws).
atTarget :: PTX.PTX -> IO a -> IO a
atTarget target act =
  withLifetime (AF.deviceContext $ AF.ptxContext target) $ \ctx ->
    bracket_ (CUDA.push ctx) CUDA.pop act

-- | Transform a unit impulse of length 'inlen' and print the result.
--
-- The cuFFT plan is created and destroyed with the target's context pushed.
-- The 'print' inside the bracket forces the computation before the plan is
-- released.
main :: IO ()
main = do
  let inlen = 5
      -- unit impulse whose length always matches the plan size
      impulse = 1 : replicate (inlen - 1) 0
  target <- getBestTarget
  bracket
    (atTarget target $ CUFFT.plan1D inlen CUFFT.R2C 1)
    (atTarget target . CUFFT.destroy)
    $ \h ->
      print $
        PTX.run1With target (transform h) $
          A.fromList (Z :. inlen) impulse