{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Math.FFT.LLVM.Native.Base
where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Math.FFT.Mode
import Data.Array.Accelerate.Math.FFT.Type
import Data.Array.Accelerate.Math.FFT.LLVM.Native.Ix
import Data.Array.CArray.Base ( CArray(..) )
import Math.FFT.Base ( Sign(..), Flag, measure, preserveInput )
import Data.Bits
import Data.Typeable
import Foreign.ForeignPtr
import Text.Printf
import Prelude as P
-- | Choose the FFTW transform sign for a given transform 'Mode'.
--
-- Only 'Forward' maps to 'DFTForward'; every other mode is treated as a
-- backward transform.
signOf :: Mode -> Sign
signOf mode =
  case mode of
    Forward -> DFTForward
    _       -> DFTBackward
-- | Planner flags passed to FFTW: the bitwise union of 'preserveInput' and
-- 'measure' from "Math.FFT.Base". Per the FFTW documentation,
-- 'preserveInput' asks the planner not to overwrite the input array.
flags :: Flag
flags = preserveInput .|. measure
-- | A human-readable name for the transform, e.g. @FFTW.dft2D@ for a
-- forward transform over a rank-2 shape, or @FFTW.idft2D@ for any other mode.
-- The shape argument is never inspected; only its type (via 'rank') matters.
nameOf :: forall sh. Shape sh => Mode -> sh -> String
nameOf mode _ =
  case mode of
    Forward -> printf "FFTW.dft%dD" dims
    _       -> printf "FFTW.idft%dD" dims
  where
    dims :: Int
    dims = rank (undefined :: sh)
{-# INLINE fromCArray #-}
-- | Reinterpret an FFTW result 'CArray' of complex numbers as an Accelerate
-- 'Array'.
--
-- The Accelerate shape is computed from the 'CArray' bounds with
-- 'rangeToShape'. The element-count field of the 'CArray' is ignored.
-- The 'CArray' 'ForeignPtr' is cast to @ForeignPtr e@ and wrapped with
-- 'newUniqueArray' to form the payload, so no elements are copied here.
-- The 'NumericR' witness selects the float32 or float64 payload constructor.
fromCArray
    :: forall ix sh e. (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Shape sh, Elt ix, Numeric e)
    => CArray ix (Complex e)
    -> IO (Array sh (Complex e))
fromCArray (CArray lo hi _ fp) = do
  -- Pure computation: bind with 'let' rather than the redundant
  -- @sh <- return $ ...@ round-trip through IO.
  let sh  = rangeToShape (toIxShapeRepr lo, toIxShapeRepr hi) :: sh
      shR = fromElt sh
  -- Complex e is laid out as consecutive e values, so the buffer is
  -- reused as a flat array of e.
  ua <- newUniqueArray (castForeignPtr fp :: ForeignPtr e)
  case numericR :: NumericR e of
    NumericRfloat32 -> return $ Array shR (AD_V2 (AD_Float ua))
    NumericRfloat64 -> return $ Array shR (AD_V2 (AD_Double ua))
{-# INLINE withCArray #-}
-- | Present an Accelerate complex 'Array' to an action as a 'CArray'.
--
-- The 'CArray' bounds come from 'shapeToRange' of the array's shape, and its
-- element count is 'size' of that shape. Its 'ForeignPtr' is the array's
-- underlying payload pointer, obtained via 'withArray' and cast to
-- @ForeignPtr (Complex e)@.
withCArray
    :: forall ix sh e a. (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Shape sh, Elt ix, Numeric e)
    => Array sh (Complex e)
    -> (CArray ix (Complex e) -> IO a)
    -> IO a
withCArray arr action = withArray arr (action . toCArray)
  where
    extent         = shape arr
    (lower, upper) = shapeToRange extent
    toCArray fptr  =
      CArray (fromIxShapeRepr lower)
             (fromIxShapeRepr upper)
             (size extent)
             (castForeignPtr fptr)
{-# INLINE withArray #-}
-- | Run an action on the raw 'ForeignPtr' backing a complex-valued
-- Accelerate 'Array'. The shape component is ignored; the payload is
-- handed to 'withArrayData' together with the 'NumericR' witness for @e@.
withArray
    :: forall sh e a. Numeric e
    => Array sh (Complex e)
    -> (ForeignPtr e -> IO a)
    -> IO a
withArray arr k =
  case arr of
    Array _ payload -> withArrayData (numericR :: NumericR e) payload k
{-# INLINE withArrayData #-}
-- | Run an action on the 'ForeignPtr' underlying complex-valued 'ArrayData'.
--
-- The 'NumericR' witness fixes @e@ to 'Float' or 'Double'. That in turn
-- determines whether the payload is @AD_V2 (AD_Float _)@ or
-- @AD_V2 (AD_Double _)@. In both cases a single unique array of @e@ is
-- wrapped. Access goes through 'withLifetime' on 'uniqueArrayData',
-- presumably keeping the array alive while the action runs.
withArrayData
    :: NumericR e
    -> ArrayData (EltRepr (Complex e))
    -> (ForeignPtr e -> IO a)
    -> IO a
withArrayData numR payload k =
  case numR of
    NumericRfloat32 ->
      case payload of
        AD_V2 (AD_Float ua)  -> withLifetime (uniqueArrayData ua) k
    NumericRfloat64 ->
      case payload of
        AD_V2 (AD_Double ua) -> withLifetime (uniqueArrayData ua) k
{-# INLINE matchShapeType #-}
-- | Runtime test for whether two shape types are the same type.
--
-- The element-type representations are compared first with
-- 'matchTupleType'. Only if they agree is 'gcast' used to obtain the
-- equality proof @sh :~: sh'@ via 'Typeable'. The value arguments are
-- never inspected.
matchShapeType
    :: forall sh sh'. (Shape sh, Shape sh')
    => sh
    -> sh'
    -> Maybe (sh :~: sh')
matchShapeType _ _ =
  case matchTupleType (eltType (undefined :: sh)) (eltType (undefined :: sh')) of
    Just Refl -> gcast Refl
    Nothing   -> Nothing