{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Math.FFT.LLVM.Native (
fft1D,
fft2D,
fft3D,
) where
import Data.Array.Accelerate.Math.FFT.Mode
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Type as A
import Data.Array.Accelerate.Array.Sugar as S
import Data.Array.Accelerate.Error as A
import Data.Array.Accelerate.Array.Data as A
import Data.Array.Accelerate.Array.Unique as A
import Data.Array.Accelerate.Data.Complex as A
import Data.Array.Accelerate.LLVM.Native.Foreign
import Data.Ix ( Ix )
import Data.Array.CArray ( CArray )
import qualified Data.Array.CArray as C
import Math.FFT.Base ( FFTWReal, Sign(..), Flag, measure, destroyInput )
import qualified Math.FFT as FFT
import Foreign.Ptr
import Foreign.Storable
import Foreign.Storable.Complex ()
import Data.Bits
import Text.Printf
import Prelude as P
-- | Foreign (host-side) one-dimensional complex-to-complex DFT, computed
-- with 'FFT.dftGU' from the @fft@ package (FFTW bindings).
--
-- The 'Mode' only selects the transform direction through 'signOf'; no
-- scaling is performed in this function.
fft1D :: forall e. (Elt e, IsFloating e)
      => Mode
      -> ForeignAcc (Vector (Complex e) -> Vector (Complex e))
fft1D mode =
  ForeignAcc (nameOf mode (undefined :: DIM1)) (
    case floatingType :: FloatingType e of
      -- Each constructor is matched separately: the match refines @e@ to a
      -- concrete type, which is what allows the constraints of 'liftAtoC'
      -- and 'transform' to be resolved in that alternative.
      TypeFloat{}   -> \arr -> liftIO (liftAtoC transform arr)
      TypeDouble{}  -> \arr -> liftIO (liftAtoC transform arr)
      TypeCFloat{}  -> \arr -> liftIO (liftAtoC transform arr)
      TypeCDouble{} -> \arr -> liftIO (liftAtoC transform arr)
  )
  where
    -- transform along the single dimension (index 0)
    transform :: FFTWReal r => CArray Int (Complex r) -> CArray Int (Complex r)
    transform = FFT.dftGU (signOf mode) flags [0]
-- | Foreign (host-side) two-dimensional complex-to-complex DFT, computed
-- with 'FFT.dftGU' from the @fft@ package (FFTW bindings) over both
-- dimensions of the array.
--
-- The 'Mode' only selects the transform direction through 'signOf'; no
-- scaling is performed in this function.
fft2D :: forall e. (Elt e, IsFloating e)
      => Mode
      -> ForeignAcc (Array DIM2 (Complex e) -> Array DIM2 (Complex e))
fft2D mode =
  ForeignAcc (nameOf mode (undefined :: DIM2)) (
    case floatingType :: FloatingType e of
      -- Each constructor is matched separately: the match refines @e@ to a
      -- concrete type, which is what allows the constraints of 'liftAtoC'
      -- and 'transform' to be resolved in that alternative.
      TypeFloat{}   -> \arr -> liftIO (liftAtoC transform arr)
      TypeDouble{}  -> \arr -> liftIO (liftAtoC transform arr)
      TypeCFloat{}  -> \arr -> liftIO (liftAtoC transform arr)
      TypeCDouble{} -> \arr -> liftIO (liftAtoC transform arr)
  )
  where
    -- transform along both dimensions (indices 0 and 1)
    transform :: FFTWReal r => CArray (Int,Int) (Complex r) -> CArray (Int,Int) (Complex r)
    transform = FFT.dftGU (signOf mode) flags [0,1]
-- | Foreign (host-side) three-dimensional complex-to-complex DFT, computed
-- with 'FFT.dftGU' from the @fft@ package (FFTW bindings) over all three
-- dimensions of the array.
--
-- The 'Mode' only selects the transform direction through 'signOf'; no
-- scaling is performed in this function.
fft3D :: forall e. (Elt e, IsFloating e)
      => Mode
      -> ForeignAcc (Array DIM3 (Complex e) -> Array DIM3 (Complex e))
fft3D mode =
  ForeignAcc (nameOf mode (undefined :: DIM3)) (
    case floatingType :: FloatingType e of
      -- Each constructor is matched separately: the match refines @e@ to a
      -- concrete type, which is what allows the constraints of 'liftAtoC'
      -- and 'transform' to be resolved in that alternative.
      TypeFloat{}   -> \arr -> liftIO (liftAtoC transform arr)
      TypeDouble{}  -> \arr -> liftIO (liftAtoC transform arr)
      TypeCFloat{}  -> \arr -> liftIO (liftAtoC transform arr)
      TypeCDouble{} -> \arr -> liftIO (liftAtoC transform arr)
  )
  where
    -- transform along all three dimensions (indices 0, 1 and 2)
    transform :: FFTWReal r => CArray (Int,Int,Int) (Complex r) -> CArray (Int,Int,Int) (Complex r)
    transform = FFT.dftGU (signOf mode) flags [0,1,2]
-- | FFTW transform direction for a 'Mode': 'Forward' selects 'DFTForward',
-- and every other mode selects 'DFTBackward'.
signOf :: Mode -> Sign
signOf m =
  case m of
    Forward -> DFTForward
    _       -> DFTBackward
-- | Planner flags used for every FFTW transform in this module.
--
-- 'destroyInput' permits FFTW to overwrite the transform's input. That is
-- harmless here: the input handed to the transform is always the temporary
-- 'CArray' built by 'a2c', never the caller's Accelerate array.
-- 'measure' asks the FFTW planner to time candidate algorithms rather than
-- merely estimate them.
flags :: Flag
flags = destroyInput .|. measure
-- | Name under which a foreign transform is registered. Examples are
-- @FFTW.dft2D@ for a forward transform on 'DIM2' and @FFTW.idft2D@ for any
-- other mode.
--
-- Only the /type/ of the shape argument is used (through its 'rank'); the
-- argument value itself is never inspected.
nameOf :: forall sh. Shape sh => Mode -> sh -> String
nameOf mode _ = printf template (rank (undefined :: sh))
  where
    template = case mode of
      Forward -> "FFTW.dft%dD"
      _       -> "FFTW.idft%dD"
-- | Lift a pure 'CArray' transform to one over complex Accelerate arrays.
--
-- The steps are:
--
-- 1. Copy the input into a fresh 'CArray' ('a2c').
-- 2. Apply the transform.
-- 3. Copy the result into a newly allocated Accelerate array ('c2a').
liftAtoC
  :: (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Shape sh, Ix ix, Elt ix, Elt e, IsFloating e, Storable e', ArrayPtrs e ~ Ptr e')
  => (CArray ix (Complex e') -> CArray ix (Complex e'))
  -> Array sh (Complex e)
  -> IO (Array sh (Complex e))
liftAtoC f arr = do
  carr <- a2c arr
  c2a (f carr)
-- | Copy a complex Accelerate array into a freshly created 'CArray'.
--
-- The two component arrays exposed by 'withComplexArrayPtrs' are read
-- element by element. They are interleaved as @re :+ im@ into the CArray's
-- storage at the same linear offset. The CArray bounds come from the
-- Accelerate shape via 'shapeToRange' and 'fromIxShapeRepr'.
a2c :: forall ix sh e e'. (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Ix ix, Elt ix, Shape sh, IsFloating e, Storable e', ArrayPtrs e ~ Ptr e')
    => Array sh (Complex e)
    -> IO (CArray ix (Complex e'))
a2c arr
  -- matching FloatingDict brings the class dictionaries for @e@ into scope
  | FloatingDict <- floatingDict (floatingType :: FloatingType e)
  = C.createCArray ixBounds $ \dst ->
      withComplexArrayPtrs arr $ \src_re src_im ->
        let copyAt k = do
              r <- peekElemOff src_re k
              j <- peekElemOff src_im k
              pokeElemOff dst k (r :+ j)
        in  P.mapM_ copyAt [0 .. total P.- 1]
  where
    extent   = arrayShape arr
    (lo, hi) = shapeToRange extent
    ixBounds = (fromIxShapeRepr lo, fromIxShapeRepr hi)
    total    = S.size extent
-- | Copy a 'CArray' of complex numbers into a newly allocated Accelerate
-- array.
--
-- Each @re :+ im@ element is split into the separate real and imaginary
-- component arrays exposed by 'withComplexArrayPtrs', at the same linear
-- offset. The Accelerate shape is derived from the CArray bounds via
-- 'toIxShapeRepr' and 'rangeToShape'.
c2a :: forall ix sh e e'. (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Ix ix, Elt ix, Shape sh, Elt e, IsFloating e, Storable e', ArrayPtrs e ~ Ptr e')
    => CArray ix (Complex e')
    -> IO (Array sh (Complex e))
c2a carr
  -- matching FloatingDict brings the class dictionaries for @e@ into scope
  | FloatingDict <- floatingDict (floatingType :: FloatingType e)
  = do
      out <- allocateArray extent
      C.withCArray carr $ \src ->
        withComplexArrayPtrs out $ \dst_re dst_im ->
          let copyAt k = do
                r :+ j <- peekElemOff src k
                pokeElemOff dst_re k r
                pokeElemOff dst_im k j
          in  P.mapM_ copyAt [0 .. count P.- 1]
      return out
  where
    (lo, hi) = C.bounds carr
    count    = C.rangeSize (lo, hi)
    extent   = rangeToShape (toIxShapeRepr lo, toIxShapeRepr hi)
-- | Maps the 'EltRepr' of a "Data.Ix" index type to the nested-pair form
-- used for Accelerate shapes.
--
-- Signatures in this module relate the two with the constraint
-- @IxShapeRepr (EltRepr ix) ~ EltRepr sh@. That constraint is what lets a
-- 'CArray' indexed by @ix@ take its bounds from an Accelerate array of
-- shape @sh@, and vice versa.
type family IxShapeRepr e where
  IxShapeRepr () = ()
  -- a bare Int index becomes a one-element chain ending in ()
  IxShapeRepr Int = ((),Int)
  -- a pair recurses on its left component and keeps the right one as-is
  IxShapeRepr (t,h) = (IxShapeRepr t, h)
-- | Convert an Accelerate shape into the "Data.Ix" index whose 'EltRepr'
-- corresponds to it under 'IxShapeRepr'. This is the opposite direction
-- of 'toIxShapeRepr'.
--
-- @go@ works on the 'EltRepr' representations and is driven by the
-- 'TupleType' of @ix@. It maps unit to unit, rebuilds a pair from its
-- converted left component and untouched right component, and accepts
-- 'Int' as the only scalar. Any other element type raises an internal
-- error. 'liftToElt' lifts @go@ to a function from @sh@ to @ix@.
fromIxShapeRepr
  :: forall ix sh. (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Shape sh, Elt ix)
  => sh
  -> ix
fromIxShapeRepr = liftToElt (go (eltType (undefined::ix)))
  where
    -- Matching on the TupleType GADT refines ix', so IxShapeRepr ix'
    -- reduces to a concrete form in each equation.
    go :: forall ix'. TupleType ix' -> IxShapeRepr ix' -> ix'
    go UnitTuple () = ()
    -- NOTE(review): the right component h is passed through without
    -- checking that it is an Int.
    go (PairTuple tt _) (t, h) = (go tt t, h)
    go (SingleTuple (NumScalarType (IntegralNumType TypeInt{}))) ((),h) = h
    go _ _
      = $internalError "fromIxShapeRepr" "expected Int dimensions"
-- | Convert a "Data.Ix" index into the Accelerate shape whose 'EltRepr'
-- corresponds to it under 'IxShapeRepr'. This is the opposite direction
-- of 'fromIxShapeRepr'.
--
-- @go@ is driven by the 'TupleType' of @ix@. It maps unit to unit, turns a
-- bare 'Int' into the singleton @((), h)@, and converts a pair on its left
-- component while keeping the right component as-is. Any other element
-- type is reported through 'internalError', like the other failure paths
-- in this module, rather than through a bare 'error'.
toIxShapeRepr
  :: forall ix sh. (IxShapeRepr (EltRepr ix) ~ EltRepr sh, Shape sh, Elt ix)
  => ix
  -> sh
toIxShapeRepr = liftToElt (go (eltType (undefined::ix)))
  where
    -- Matching on the TupleType GADT refines ix', so IxShapeRepr ix'
    -- reduces to a concrete form in each equation.
    go :: forall ix'. TupleType ix' -> ix' -> IxShapeRepr ix'
    go UnitTuple () = ()
    go (SingleTuple (NumScalarType (IntegralNumType TypeInt{}))) h = ((), h)
    go (PairTuple tt _) (t, h) = (go tt t, h)
    go _ _
      = $internalError "toIxShapeRepr" "not a valid Data.Ix index"
-- | Run a continuation with raw pointers to the two component arrays that
-- back an array of complex numbers.
--
-- The array data of @Complex e@ is destructured into two component arrays,
-- @ad1@ and @ad2@. The first pointer given to @k@ comes from @ad1@ and the
-- second from @ad2@. The callers in this module ('a2c', 'c2a') treat them
-- as the real and imaginary parts respectively.
--
-- The case on 'floatingType' refines @e@ to a concrete type in each
-- alternative, so a matching 'arrayElt' witness can be passed to
-- 'withArrayData'.
-- NOTE(review): the pointers are obtained via 'withUniqueArrayPtr' and
-- presumably must not escape @k@ — confirm.
withComplexArrayPtrs
  :: forall sh e a. IsFloating e
  => Array sh (Complex e)
  -> (ArrayPtrs e -> ArrayPtrs e -> IO a)
  -> IO a
withComplexArrayPtrs (Array _ adata) k
  | AD_Pair (AD_Pair AD_Unit ad1) ad2 <- adata
  = case floatingType :: FloatingType e of
      TypeFloat{} -> withArrayData arrayElt ad1 $ \p1 -> withArrayData arrayElt ad2 $ \p2 -> k p1 p2
      TypeDouble{} -> withArrayData arrayElt ad1 $ \p1 -> withArrayData arrayElt ad2 $ \p2 -> k p1 p2
      TypeCFloat{} -> withArrayData arrayElt ad1 $ \p1 -> withArrayData arrayElt ad2 $ \p2 -> k p1 p2
      TypeCDouble{} -> withArrayData arrayElt ad1 $ \p1 -> withArrayData arrayElt ad2 $ \p2 -> k p1 p2
-- | Pass the raw pointer underlying an 'ArrayData' of 'Float', 'Double',
-- 'CFloat' or 'CDouble' to a continuation, via 'withUniqueArrayPtr'.
--
-- The 'ArrayEltR' witness is matched together with the data constructor,
-- so the pointer type @Ptr a@ (tied to @e@ by @ArrayPtrs e ~ Ptr a@) is
-- known in each equation. Any other element representation raises an
-- internal error.
withArrayData
  :: (ArrayPtrs e ~ Ptr a)
  => ArrayEltR e
  -> ArrayData e
  -> (Ptr a -> IO b)
  -> IO b
withArrayData ArrayEltRfloat (AD_Float ua) = withUniqueArrayPtr ua
withArrayData ArrayEltRdouble (AD_Double ua) = withUniqueArrayPtr ua
withArrayData ArrayEltRcfloat (AD_CFloat ua) = withUniqueArrayPtr ua
withArrayData ArrayEltRcdouble (AD_CDouble ua) = withUniqueArrayPtr ua
withArrayData _ _ =
  $internalError "withArrayData" "expected array of [C]Float or [C]Double"