{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Array.Repa.Repr.Accelerate (
A, Shapes,
fromRepa, toRepa,
computeAccS, computeAccP
) where
import Control.Monad
import qualified Data.Array.Accelerate.Array.Data as A
import qualified Data.Array.Accelerate.Sugar.Array as A
import qualified Data.Array.Accelerate.Sugar.Elt as A
import qualified Data.Array.Accelerate.Sugar.Shape as A
import qualified Data.Array.Accelerate.Representation.Array as AR
import qualified Data.Array.Repa as R
import qualified Data.Array.Repa.Eval as R
-- | Pairs a Repa shape type with its Accelerate counterpart.
--
-- The functional dependencies make either shape type determine the other, so
-- conversions between the two libraries never need a type annotation on the
-- shape.
class (R.Shape rsh, A.Shape ash) => Shapes rsh ash | ash -> rsh, rsh -> ash where
  -- | Convert an Accelerate shape into the matching Repa shape.
  toR :: ash -> rsh
  -- | Convert a Repa shape into the matching Accelerate shape.
  toA :: rsh -> ash
-- | Base case: the rank-0 shapes of the two libraries correspond.
instance Shapes R.Z A.Z where
  {-# INLINE toA #-}
  toA R.Z = A.Z

  {-# INLINE toR #-}
  toR A.Z = R.Z
-- | Inductive step: convert the inner (lower-rank) part recursively and carry
-- the outermost 'Int' extent across unchanged.
instance Shapes sr sa => Shapes (sr R.:. Int) (sa A.:. Int) where
  {-# INLINE toA #-}
  toA (rest R.:. n) = toA rest A.:. n

  {-# INLINE toR #-}
  toR (rest A.:. n) = toR rest R.:. n
-- | Representation tag for Repa arrays whose elements are stored in an
-- Accelerate 'A.ArrayData' buffer (used as @R.Array A sh e@).
data A
-- | Reading an 'A'-represented array indexes the Accelerate 'A.ArrayData'
-- buffer directly and converts each element back with 'A.toElt'.
instance A.Elt e => R.Source A e where
  -- The shape is stored in Repa form next to the raw Accelerate storage;
  -- both fields are strict.
  data Array A sh e
    = AAccelerate !sh !(A.ArrayData (A.EltR e))

  {-# INLINE extent #-}
  extent (AAccelerate sh _) = sh

  -- Bounds-checked read: indices outside [0, R.size sh) raise an 'error'.
  {-# INLINE linearIndex #-}
  linearIndex (AAccelerate sh payload) i
    | i < 0 || i >= R.size sh = error "Repa: accelerate array out of bounds"
    | otherwise               = A.toElt (A.indexArrayData (A.eltR @e) payload i)

  -- Unchecked read; the caller must guarantee that the index is in range.
  {-# INLINE unsafeLinearIndex #-}
  unsafeLinearIndex (AAccelerate _ payload) i =
    A.toElt (A.indexArrayData (A.eltR @e) payload i)

  -- Forces the shape with 'R.deepSeq' and the storage with 'seq' (weak head
  -- normal form only) before yielding the continuation value.
  {-# INLINE deepSeqArray #-}
  deepSeqArray (AAccelerate sh payload) k =
    R.deepSeq sh (payload `seq` k)
-- | Repa fills 'A'-represented arrays through an Accelerate
-- 'A.MutableArrayData' buffer; freezing wraps that same buffer without copying.
instance A.Elt e => R.Target A e where
  data MVec A e
    = MAVec (A.MutableArrayData (A.EltR e))

  -- Allocate an uninitialised buffer holding @len@ elements.
  {-# INLINE newMVec #-}
  newMVec len = fmap MAVec (A.newArrayData (A.eltR @e) len)

  -- Store one element at index @i@, converted to its representation type.
  {-# INLINE unsafeWriteMVec #-}
  unsafeWriteMVec (MAVec buf) i v =
    A.writeArrayData (A.eltR @e) buf i (A.fromElt v)

  -- Reuse the mutable buffer as the immutable array's storage.
  {-# INLINE unsafeFreezeMVec #-}
  unsafeFreezeMVec sh (MAVec buf) =
    pure $! AAccelerate sh buf

  {-# INLINE deepSeqMVec #-}
  deepSeqMVec (MAVec buf) k = seq buf k

  -- Intentionally a no-op.
  -- NOTE(review): assumes Accelerate's array-data operations keep the
  -- buffer alive on their own — confirm.
  {-# INLINE touchMVec #-}
  touchMVec _ = pure ()
-- | View an Accelerate array as a Repa array with representation 'A'.
--
-- The element buffer is reused as-is (no copy); only the shape is translated
-- into its Repa form via 'toR'.
toRepa :: Shapes sh sh' => A.Array sh' e -> R.Array A sh e
{-# INLINE toRepa #-}
toRepa accArr =
  case accArr of
    A.Array (AR.Array _ payload) -> AAccelerate (toR (A.shape accArr)) payload
-- | Turn an 'A'-represented Repa array back into an Accelerate array.
--
-- The element buffer is shared rather than copied; the Repa shape is
-- converted with 'toA' and encoded into Accelerate's representation with
-- 'A.fromElt'.
fromRepa :: (Shapes sh sh', A.Elt e) => R.Array A sh e -> A.Array sh' e
{-# INLINE fromRepa #-}
fromRepa (AAccelerate sh payload) = A.Array reprArr
  where
    reprArr = AR.Array (A.fromElt (toA sh)) payload
-- | Sequentially materialise any loadable Repa array into the 'A'
-- representation. This is 'R.computeS' with the target fixed to 'A'.
computeAccS :: (R.Load r sh e, A.Elt e) => R.Array r sh e -> R.Array A sh e
{-# INLINE computeAccS #-}
computeAccS arr = R.computeS arr
-- | Parallel counterpart of 'computeAccS': 'R.computeP' with the target
-- representation fixed to 'A'. The result is returned in a monad, following
-- 'R.computeP', so that successive parallel computations can be sequenced.
computeAccP
  :: (R.Load r sh e, A.Elt e, Monad m)
  => R.Array r sh e
  -> m (R.Array A sh e)
{-# INLINE computeAccP #-}
computeAccP arr = R.computeP arr