{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Accelerate.Math.FFT.Twine
where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex
#if defined(ACCELERATE_CUDA_BACKEND) || defined(ACCELERATE_LLVM_PTX_BACKEND)
import Data.FileEmbed
import Data.ByteString ( ByteString )
#endif
-- | Flatten a vector of complex numbers into interleaved format, where each
-- real component is immediately followed by its imaginary component:
--
-- > [x0 :+ y0, x1 :+ y1, ...]  ~>  [x0, y0, x1, y1, ...]
--
-- The result holds twice as many elements as the input.
--
-- NOTE(review): the NOINLINE pragma is presumably what lets the
-- interleave/deinterleave rewrite rules in this module match -- confirm.
--
{-# NOINLINE interleave #-}
interleave :: Elt e => Acc (Vector (Complex e)) -> Acc (Vector e)
interleave arr = generate (index1 (2 * A.size arr)) pick
  where
    -- The two component planes of the input.
    res = A.map real arr
    ims = A.map imag arr

    -- Even output slots take the real part of source element @src@, odd
    -- output slots take its imaginary part.
    pick ix = cond (parity A.== 0) (res A.!! src) (ims A.!! src)
      where
        (src, parity) = unindex1 ix `quotRem` 2
-- | Rebuild a vector of complex numbers from interleaved format, where each
-- real component is immediately followed by its imaginary component:
--
-- > [x0, y0, x1, y1, ...]  ~>  [x0 :+ y0, x1 :+ y1, ...]
--
-- The result holds half as many elements as the input (rounded towards zero),
-- so a trailing unpaired element of an odd-length input is ignored.
--
-- NOTE(review): the NOINLINE pragma is presumably what lets the
-- interleave/deinterleave rewrite rules in this module match -- confirm.
--
{-# NOINLINE deinterleave #-}
deinterleave :: forall e. Elt e => Acc (Vector e) -> Acc (Vector (Complex e))
deinterleave arr = generate sh swizzle
  where
    -- Two input scalars are consumed per output element.
    sh          = index1 (A.size arr `quot` 2)
    swizzle ix  =
      -- Output element n is built from input slots 2n (real) and 2n+1
      -- (imaginary). The index must be doubled, not halved: halving pairs up
      -- overlapping neighbours (0,1), (0,1), (1,2), ... instead of the
      -- disjoint pairs (0,1), (2,3), (4,5), ...
      let i = indexHead ix * 2
      in  lift ( arr A.!! i :+ arr A.!! (i+1) ) :: Exp (Complex e)
-- Rewrite rules that cancel a round trip through the interleaved
-- representation: the first removes @deinterleave (interleave x)@, the second
-- removes @interleave (deinterleave x)@.
--
-- NOTE(review): the second rule is only an identity for even-length @x@,
-- because 'deinterleave' sizes its result as @size x `quot` 2@ and
-- 'interleave' doubles that again -- confirm no caller passes an odd length.
--
-- NOTE(review): each rule's name lists the functions in the opposite order to
-- the expression it matches; harmless, but worth knowing when reading
-- rule-firing output.
--
{-# RULES
"interleave/deinterleave" forall x. deinterleave (interleave x) = x;
"deinterleave/interleave" forall x. interleave (deinterleave x) = x
#-}
#if defined(ACCELERATE_CUDA_BACKEND) || defined(ACCELERATE_LLVM_PTX_BACKEND)
-- | Raw bytes of @cubits/twine_f32.ptx@, embedded into the module at compile
-- time by a Template Haskell splice. The path is resolved against the project
-- root rather than the compiler's working directory.
--
-- NOTE(review): presumably the single-precision build of the twine kernels --
-- confirm against the cubits sources.
--
ptx_twine_f32 :: ByteString
ptx_twine_f32 =
  $(embedFile =<< makeRelativeToProject "cubits/twine_f32.ptx")
-- | Raw bytes of @cubits/twine_f64.ptx@, embedded into the module at compile
-- time by a Template Haskell splice. The path is resolved against the project
-- root rather than the compiler's working directory.
--
-- NOTE(review): presumably the double-precision build of the twine kernels --
-- confirm against the cubits sources.
--
ptx_twine_f64 :: ByteString
ptx_twine_f64 =
  $(embedFile =<< makeRelativeToProject "cubits/twine_f64.ptx")
#endif