module Data.Array.Accelerate.AST (
Idx(..),
Val(..), prj, idxToInt,
Arrays(..), ArraysR(..),
PreOpenAfun(..), OpenAfun, PreAfun, Afun, PreOpenAcc(..), OpenAcc(..), Acc,
Stencil(..), StencilR(..),
PreOpenFun(..), OpenFun, PreFun, Fun, PreOpenExp(..), OpenExp, PreExp, Exp, PrimConst(..),
PrimFun(..)
) where
import Data.Typeable
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Array.Representation (SliceIndex)
import Data.Array.Accelerate.Array.Delayed (Delayable)
import Data.Array.Accelerate.Array.Sugar as Sugar
#include "accelerate.h"
-- | Typed de Bruijn index into an environment represented as a left-nested
-- tuple such as @(((), a), b)@.
--
-- 'ZeroIdx' refers to the rightmost (most recently added) component of the
-- environment; 'SuccIdx' skips that component and indexes into the rest.
data Idx env t where
  ZeroIdx :: Idx (env, t) t
  SuccIdx :: Idx env t -> Idx (env, s) t
-- | A valuation: run-time values for every variable of an environment @env@.
--
-- 'Empty' is the valuation of the empty environment @()@, and
-- @'Push' val x@ extends @val@ on the right with one more value @x@. This
-- matches the @(env, t)@ shape that 'Idx' indexes into.
data Val env where
  Empty :: Val ()
  Push  :: Val env -> t -> Val (env, t)

-- Standalone-derived 'Typeable1' instance for the valuation type constructor.
deriving instance Typeable1 Val
-- | Look up the value that a de Bruijn index refers to in a valuation.
--
-- Every 'SuccIdx' discards the rightmost binding, and 'ZeroIdx' returns it.
-- For well-typed arguments the last equation cannot match. Both index
-- constructors have an environment of the form @(env, s)@, and only 'Push'
-- builds a valuation of that form.
prj :: Idx env t -> Val env -> t
prj (SuccIdx ix) (Push rest _) = prj ix rest
prj ZeroIdx      (Push _ x)    = x
prj _            _             = INTERNAL_ERROR(error) "prj" "inconsistent valuation"
-- | Numeric depth of a typed de Bruijn index: 'ZeroIdx' is 0 and each
-- 'SuccIdx' adds one.
idxToInt :: Idx env t -> Int
idxToInt = walk 0
  where
    -- The explicit signature is required: the recursive call is at a
    -- different environment type (polymorphic recursion). The accumulator is
    -- forced on every step so no chain of additions is built up.
    walk :: Int -> Idx env' t' -> Int
    walk depth ix = depth `seq` case ix of
      ZeroIdx      -> depth
      SuccIdx next -> walk (depth + 1) next
-- | Types that can be the result of an array computation.
--
-- 'arrays' gives a value-level 'ArraysR' description of the type, so code can
-- recover its structure (unit, single array, or pair) by pattern matching.
class (Delayable arrs, Typeable arrs) => Arrays arrs where
  -- | Reified structure of @arrs@.
  arrays :: ArraysR arrs
-- | Reified structure of an 'Arrays' type.
data ArraysR arrs where
  -- | No arrays.
  ArraysRunit  :: ArraysR ()
  -- | A single array, carrying its 'Shape' and 'Elt' dictionaries.
  ArraysRarray :: (Shape sh, Elt e) => ArraysR (Array sh e)
  -- | A pair of array structures.
  ArraysRpair  :: ArraysR arrs1 -> ArraysR arrs2 -> ArraysR (arrs1, arrs2)
-- | The empty collection of arrays.
instance Arrays () where
  arrays = ArraysRunit

-- | A single array.
instance (Shape sh, Elt e) => Arrays (Array sh e) where
  arrays = ArraysRarray

-- | A pair of array collections; each component's 'arrays' is taken from its
-- own instance.
instance (Arrays arrs1, Arrays arrs2) => Arrays (arrs1, arrs2) where
  arrays = ArraysRpair arrays arrays
-- | Array-level function abstraction, parameterised by the type @acc@ of
-- array computations used in the body.
--
-- 'Alam' binds one array-valued argument @as@ by pushing it onto the array
-- environment @aenv@; 'Abody' is the function body.
data PreOpenAfun acc aenv t where
  Abody :: acc aenv t -> PreOpenAfun acc aenv t
  Alam  :: (Arrays as, Arrays t)
        => PreOpenAfun acc (aenv, as) t -> PreOpenAfun acc aenv (as -> t)

-- | Open array function over the concrete 'OpenAcc' representation.
type OpenAfun = PreOpenAfun OpenAcc

-- | Closed array function: its array environment is empty.
type PreAfun acc = PreOpenAfun acc ()

-- | Closed array function over 'OpenAcc'.
type Afun = OpenAfun ()
-- | Collective array computations.
--
-- @acc@ is the type of nested array computations; the 'OpenAcc' newtype ties
-- it back to 'PreOpenAcc'. @aenv@ is the array environment, a left-nested
-- tuple of bound arrays indexed by 'Idx', and @a@ is the result type.
--
-- Scalar arguments are 'PreExp' and 'PreFun' terms. These have an empty
-- scalar environment but may refer to arrays bound in @aenv@.
data PreOpenAcc acc aenv a where

  -- | Bind an array computation. The body sees it as the rightmost entry of
  -- its array environment.
  Let         :: (Arrays bndArrs, Arrays bodyArrs)
              => acc aenv bndArrs
              -> acc (aenv, bndArrs) bodyArrs
              -> PreOpenAcc acc aenv bodyArrs

  -- | Bind the two components of a pair of arrays as two separate
  -- environment entries.
  Let2        :: (Arrays bndArrs1, Arrays bndArrs2, Arrays bodyArrs)
              => acc aenv (bndArrs1, bndArrs2)
              -> acc ((aenv, bndArrs1), bndArrs2) bodyArrs
              -> PreOpenAcc acc aenv bodyArrs

  -- | Pair two array computations.
  PairArrays  :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
              => acc aenv (Array sh1 e1)
              -> acc aenv (Array sh2 e2)
              -> PreOpenAcc acc aenv (Array sh1 e1, Array sh2 e2)

  -- | Array variable, looked up in the array environment.
  Avar        :: Arrays arrs
              => Idx aenv arrs
              -> PreOpenAcc acc aenv arrs

  -- | Apply a closed array function (empty array environment) to an
  -- argument computation.
  Apply       :: (Arrays arrs1, Arrays arrs2)
              => PreAfun acc (arrs1 -> arrs2)
              -> acc aenv arrs1
              -> PreOpenAcc acc aenv arrs2

  -- | Choose between two array computations using a scalar condition.
  Acond       :: (Arrays arrs)
              => PreExp acc aenv Bool
              -> acc aenv arrs
              -> acc aenv arrs
              -> PreOpenAcc acc aenv arrs

  -- | Embed an existing array value.
  Use         :: Array dim e
              -> PreOpenAcc acc aenv (Array dim e)

  -- | Lift a scalar expression into a 'Scalar' array.
  Unit        :: Elt e
              => PreExp acc aenv e
              -> PreOpenAcc acc aenv (Scalar e)

  -- | Give an array a new shape, computed by a scalar expression. The
  -- element type is unchanged.
  Reshape     :: (Shape sh, Shape sh', Elt e)
              => PreExp acc aenv sh
              -> acc aenv (Array sh' e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- | Build an array of the given shape from a function of the index.
  Generate    :: (Shape sh, Elt e)
              => PreExp acc aenv sh
              -> PreFun acc aenv (sh -> e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- | Replicate an array of shape @sl@ into one of shape @sh@.
  -- The 'SliceIndex' relates the representations of the slice specifier
  -- @slix@, @sl@ and @sh@. The 'PreExp' gives the specifier's value.
  Replicate   :: (Shape sh, Shape sl, Elt slix, Elt e)
              => SliceIndex (EltRepr slix)
                            (EltRepr sl)
                            co'
                            (EltRepr sh)
              -> PreExp acc aenv slix
              -> acc aenv (Array sl e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- | Extract a slice of shape @sl@ from an array of shape @sh@. This is the
  -- opposite direction to 'Replicate' and uses the same 'SliceIndex' relation.
  Index       :: (Shape sh, Shape sl, Elt slix, Elt e)
              => SliceIndex (EltRepr slix)
                            (EltRepr sl)
                            co'
                            (EltRepr sh)
              -> acc aenv (Array sh e)
              -> PreExp acc aenv slix
              -> PreOpenAcc acc aenv (Array sl e)

  -- | Apply a scalar function to every element.
  Map         :: (Shape sh, Elt e, Elt e')
              => PreFun acc aenv (e -> e')
              -> acc aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e')

  -- | Combine two arrays of the same shape type element by element.
  ZipWith     :: (Shape sh, Elt e1, Elt e2, Elt e3)
              => PreFun acc aenv (e1 -> e2 -> e3)
              -> acc aenv (Array sh e1)
              -> acc aenv (Array sh e2)
              -> PreOpenAcc acc aenv (Array sh e3)

  -- | Reduce along the rightmost dimension with a binary function and an
  -- initial element; the result drops that dimension.
  Fold        :: (Shape sh, Elt e)
              => PreFun acc aenv (e -> e -> e)
              -> PreExp acc aenv e
              -> acc aenv (Array (sh:.Int) e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- | Like 'Fold', but without an initial element.
  Fold1       :: (Shape sh, Elt e)
              => PreFun acc aenv (e -> e -> e)
              -> acc aenv (Array (sh:.Int) e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- | Segmented reduction along the rightmost dimension. Segments are
  -- described by the 'Segments' array, and the result keeps the rightmost
  -- dimension (presumably one entry per segment; confirm).
  FoldSeg     :: (Shape sh, Elt e)
              => PreFun acc aenv (e -> e -> e)
              -> PreExp acc aenv e
              -> acc aenv (Array (sh:.Int) e)
              -> acc aenv Segments
              -> PreOpenAcc acc aenv (Array (sh:.Int) e)

  -- | Like 'FoldSeg', but without an initial element.
  Fold1Seg    :: (Shape sh, Elt e)
              => PreFun acc aenv (e -> e -> e)
              -> acc aenv (Array (sh:.Int) e)
              -> acc aenv Segments
              -> PreOpenAcc acc aenv (Array (sh:.Int) e)

  -- | Left scan of a vector with an initial element (cf. Prelude 'scanl').
  Scanl       :: Elt e
              => PreFun acc aenv (e -> e -> e)
              -> PreExp acc aenv e
              -> acc aenv (Vector e)
              -> PreOpenAcc acc aenv (Vector e)

  -- | Left scan returning the scanned vector together with a 'Scalar'
  -- (presumably the overall reduction value; confirm).
  Scanl'      :: Elt e
              => PreFun acc aenv (e -> e -> e)
              -> PreExp acc aenv e
              -> acc aenv (Vector e)
              -> PreOpenAcc acc aenv (Vector e, Scalar e)

  -- | Left scan without an initial element (cf. Prelude 'scanl1').
  Scanl1      :: Elt e
              => PreFun acc aenv (e -> e -> e)
              -> acc aenv (Vector e)
              -> PreOpenAcc acc aenv (Vector e)

  -- | Right scan of a vector with an initial element (cf. Prelude 'scanr').
  Scanr       :: Elt e
              => PreFun acc aenv (e -> e -> e)
              -> PreExp acc aenv e
              -> acc aenv (Vector e)
              -> PreOpenAcc acc aenv (Vector e)

  -- | Right scan returning the scanned vector together with a 'Scalar'
  -- (presumably the overall reduction value; confirm).
  Scanr'      :: Elt e
              => PreFun acc aenv (e -> e -> e)
              -> PreExp acc aenv e
              -> acc aenv (Vector e)
              -> PreOpenAcc acc aenv (Vector e, Scalar e)

  -- | Right scan without an initial element (cf. Prelude 'scanr1').
  Scanr1      :: Elt e
              => PreFun acc aenv (e -> e -> e)
              -> acc aenv (Vector e)
              -> PreOpenAcc acc aenv (Vector e)

  -- | Forward permutation. The arguments are:
  --
  --   * a binary combination function;
  --   * a default array of the same type as the result;
  --   * an index mapping from source shape @sh@ to target shape @sh'@;
  --   * the source array.
  --
  -- Both 'Shape sh' and 'Shape sh'' are carried, as in 'Backpermute' and
  -- 'Reshape'. The target shape type occurs in the default array, the
  -- mapping and the result, and code consuming this constructor needs its
  -- dictionary.
  Permute     :: (Shape sh, Shape sh', Elt e)
              => PreFun acc aenv (e -> e -> e)
              -> acc aenv (Array sh' e)
              -> PreFun acc aenv (sh -> sh')
              -> acc aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh' e)

  -- | Backward permutation. The arguments are the result shape, an index
  -- mapping from result index to source index, and the source array.
  Backpermute :: (Shape sh, Shape sh', Elt e)
              => PreExp acc aenv sh'
              -> PreFun acc aenv (sh' -> sh)
              -> acc aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh' e)

  -- | Apply a function to the neighbourhood pattern @stencil@ of every
  -- element. A 'Boundary' specification over the element representation is
  -- supplied alongside the array.
  Stencil     :: (Elt e, Elt e', Stencil sh e stencil)
              => PreFun acc aenv (stencil -> e')
              -> Boundary (EltRepr e)
              -> acc aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e')

  -- | Like 'Stencil', but over two arrays of the same shape type. Each array
  -- has its own pattern and 'Boundary'.
  Stencil2    :: (Elt e1, Elt e2, Elt e',
                  Stencil sh e1 stencil1,
                  Stencil sh e2 stencil2)
              => PreFun acc aenv (stencil1 -> stencil2 -> e')
              -> Boundary (EltRepr e1)
              -> acc aenv (Array sh e1)
              -> Boundary (EltRepr e2)
              -> acc aenv (Array sh e2)
              -> PreOpenAcc acc aenv (Array sh e')
-- | Closes the recursion of 'PreOpenAcc': nested array computations are
-- themselves 'OpenAcc' terms.
newtype OpenAcc aenv t = OpenAcc (PreOpenAcc OpenAcc aenv t)

-- Standalone-derived 'Typeable2' instance for the two-parameter constructor.
deriving instance Typeable2 OpenAcc

-- | Closed array computation (empty array environment).
type Acc = OpenAcc ()
-- | Relates a shape @sh@ and element type @e@ to a neighbourhood pattern
-- @stencil@. The pattern is a tuple of elements in one dimension, and a tuple
-- of lower-dimensional patterns in higher dimensions.
class (Shape sh, Elt e, IsTuple stencil) => Stencil sh e stencil where
  -- | Reified description of the pattern.
  stencil       :: StencilR sh e stencil
  -- | Given a function that reads the source at an index, gather the whole
  -- pattern centred on the given index.
  stencilAccess :: (sh -> e) -> sh -> stencil
-- | Reified description of a stencil pattern.
--
-- The @StencilRunitN@ constructors describe one-dimensional patterns of N
-- elements. @StencilRtupN@ builds an @(sh:.Int)@ pattern from N patterns over
-- @sh@.
data StencilR sh e pat where
  StencilRunit3 :: (Elt e)
                => StencilR DIM1 e (e,e,e)
  StencilRunit5 :: (Elt e)
                => StencilR DIM1 e (e,e,e,e,e)
  StencilRunit7 :: (Elt e)
                => StencilR DIM1 e (e,e,e,e,e,e,e)
  StencilRunit9 :: (Elt e)
                => StencilR DIM1 e (e,e,e,e,e,e,e,e,e)
  StencilRtup3  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR (sh:.Int) e (pat1,pat2,pat3)
  StencilRtup5  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR sh e pat4
                -> StencilR sh e pat5
                -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5)
  StencilRtup7  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR sh e pat4
                -> StencilR sh e pat5
                -> StencilR sh e pat6
                -> StencilR sh e pat7
                -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5,pat6,pat7)
  StencilRtup9  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR sh e pat4
                -> StencilR sh e pat5
                -> StencilR sh e pat6
                -> StencilR sh e pat7
                -> StencilR sh e pat8
                -> StencilR sh e pat9
                -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5,pat6,pat7,pat8,pat9)
-- | One-dimensional stencil of width 3: the elements at offsets -1, 0 and +1
-- from the centre @y@.
instance Elt e => Stencil DIM1 e (e, e, e) where
  stencil = StencilRunit3
  -- Left-hand neighbours are at y - k; the minus is required because
  -- @(y 1)@ would apply an Int to a number.
  stencilAccess rf (Z:.y) = (rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1))
    where
      -- read the source at absolute position d
      rf' d = rf (Z:.d)
-- | One-dimensional stencil of width 5: the elements at offsets -2 .. +2 from
-- the centre @y@.
instance Elt e => Stencil DIM1 e (e, e, e, e, e) where
  stencil = StencilRunit5
  -- Left-hand neighbours are at y - k (the subtraction was missing).
  stencilAccess rf (Z:.y) = (rf' (y - 2),
                             rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1),
                             rf' (y + 2))
    where
      -- read the source at absolute position d
      rf' d = rf (Z:.d)
-- | One-dimensional stencil of width 7: the elements at offsets -3 .. +3 from
-- the centre @y@.
instance Elt e => Stencil DIM1 e (e, e, e, e, e, e, e) where
  stencil = StencilRunit7
  -- Left-hand neighbours are at y - k (the subtraction was missing).
  stencilAccess rf (Z:.y) = (rf' (y - 3),
                             rf' (y - 2),
                             rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1),
                             rf' (y + 2),
                             rf' (y + 3))
    where
      -- read the source at absolute position d
      rf' d = rf (Z:.d)
-- | One-dimensional stencil of width 9: the elements at offsets -4 .. +4 from
-- the centre @y@.
instance Elt e => Stencil DIM1 e (e, e, e, e, e, e, e, e, e) where
  stencil = StencilRunit9
  -- Left-hand neighbours are at y - k (the subtraction was missing).
  stencilAccess rf (Z:.y) = (rf' (y - 4),
                             rf' (y - 3),
                             rf' (y - 2),
                             rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1),
                             rf' (y + 2),
                             rf' (y + 3),
                             rf' (y + 4))
    where
      -- read the source at absolute position d
      rf' d = rf (Z:.d)
-- | Stencil of one more dimension built from three lower-dimensional
-- patterns.
--
-- The three tuple components use offsets -1, 0 and +1 on the rightmost index
-- component @i@. Each component recurses on the remaining index @ix@ with
-- @i@ fixed to that offset.
--
-- NOTE(review): the outer tuple therefore ranges over the rightmost index
-- component; confirm this ordering is what the backends expect.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3) => Stencil (sh:.Int:.Int) a (row1, row2, row3) where
  stencil = StencilRtup3 stencil stencil stencil
  -- Negative offsets are i - k (the subtraction was missing).
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix)
    where
      -- fix the rightmost component to d, read at the remaining index ds
      rf' d ds = rf (ds :. d)
-- | Stencil of one more dimension built from five lower-dimensional patterns.
--
-- The tuple components use offsets -2 .. +2 on the rightmost index component
-- @i@. Each component recurses on the remaining index @ix@.
--
-- NOTE(review): the outer tuple ranges over the rightmost index component;
-- confirm this ordering is what the backends expect.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5) => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5) where
  stencil = StencilRtup5 stencil stencil stencil stencil stencil
  -- Negative offsets are i - k (the subtraction was missing).
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 2)) ix,
                              stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix,
                              stencilAccess (rf' (i + 2)) ix)
    where
      -- fix the rightmost component to d, read at the remaining index ds
      rf' d ds = rf (ds :. d)
-- | Stencil of one more dimension built from seven lower-dimensional
-- patterns.
--
-- The tuple components use offsets -3 .. +3 on the rightmost index component
-- @i@. Each component recurses on the remaining index @ix@.
--
-- NOTE(review): the outer tuple ranges over the rightmost index component;
-- confirm this ordering is what the backends expect.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row7)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7) where
  stencil = StencilRtup7 stencil stencil stencil stencil stencil stencil stencil
  -- Negative offsets are i - k (the subtraction was missing).
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 3)) ix,
                              stencilAccess (rf' (i - 2)) ix,
                              stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix,
                              stencilAccess (rf' (i + 2)) ix,
                              stencilAccess (rf' (i + 3)) ix)
    where
      -- fix the rightmost component to d, read at the remaining index ds
      rf' d ds = rf (ds :. d)
-- | Stencil of one more dimension built from nine lower-dimensional patterns.
--
-- The tuple components use offsets -4 .. +4 on the rightmost index component
-- @i@. Each component recurses on the remaining index @ix@.
--
-- NOTE(review): the outer tuple ranges over the rightmost index component;
-- confirm this ordering is what the backends expect.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row7,
          Stencil (sh:.Int) a row8,
          Stencil (sh:.Int) a row9)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where
  stencil = StencilRtup9 stencil stencil stencil stencil stencil stencil stencil stencil stencil
  -- Negative offsets are i - k (the subtraction was missing).
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 4)) ix,
                              stencilAccess (rf' (i - 3)) ix,
                              stencilAccess (rf' (i - 2)) ix,
                              stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix,
                              stencilAccess (rf' (i + 2)) ix,
                              stencilAccess (rf' (i + 3)) ix,
                              stencilAccess (rf' (i + 4)) ix)
    where
      -- fix the rightmost component to d, read at the remaining index ds
      rf' d ds = rf (ds :. d)
-- | Scalar function abstraction over a scalar environment @env@ and an array
-- environment @aenv@.
--
-- 'Lam' binds one scalar argument. It is stored in @env@ at its
-- representation type @'EltRepr' a@. 'Body' is the function body.
data PreOpenFun (acc :: * -> * -> *) env aenv t where
  Body :: PreOpenExp acc env aenv t -> PreOpenFun acc env aenv t
  Lam  :: Elt a
       => PreOpenFun acc (env, EltRepr a) aenv t -> PreOpenFun acc env aenv (a -> t)

-- | Open scalar function over the concrete 'OpenAcc' representation.
type OpenFun = PreOpenFun OpenAcc

-- | Scalar function with an empty scalar environment.
type PreFun acc = PreOpenFun acc ()

-- | Scalar function over 'OpenAcc' with an empty scalar environment.
type Fun = OpenFun ()
-- | Scalar expressions over a scalar environment @env@ (extended by 'Lam')
-- and an array environment @aenv@. Scalar variables are stored at their
-- representation type 'EltRepr'.
data PreOpenExp (acc :: * -> * -> *) env aenv t where

  -- | Scalar variable; the index points into @env@ at the representation type.
  Var         :: Elt t
              => Idx env (EltRepr t)
              -> PreOpenExp acc env aenv t

  -- | Constant, given by its representation.
  Const       :: Elt t
              => EltRepr t
              -> PreOpenExp acc env aenv t

  -- | Build a tuple from its components.
  Tuple       :: (Elt t, IsTuple t)
              => Tuple (PreOpenExp acc env aenv) (TupleRepr t)
              -> PreOpenExp acc env aenv t

  -- | Project one component out of a tuple.
  Prj         :: (Elt t, IsTuple t)
              => TupleIdx (TupleRepr t) e
              -> PreOpenExp acc env aenv t
              -> PreOpenExp acc env aenv e

  -- | The empty index 'Z'.
  IndexNil    :: PreOpenExp acc env aenv Z

  -- | Extend an index with one more (rightmost) component.
  IndexCons   :: (Slice sl, Elt a)
              => PreOpenExp acc env aenv sl
              -> PreOpenExp acc env aenv a
              -> PreOpenExp acc env aenv (sl:.a)

  -- | Rightmost component of an index.
  IndexHead   :: (Slice sl, Elt a)
              => PreOpenExp acc env aenv (sl:.a)
              -> PreOpenExp acc env aenv a

  -- | Index with its rightmost component removed.
  IndexTail   :: (Slice sl, Elt a)
              => PreOpenExp acc env aenv (sl:.a)
              -> PreOpenExp acc env aenv sl

  -- | The 'Any' index for shape @sh@.
  IndexAny    :: Shape sh => PreOpenExp acc env aenv (Any sh)

  -- | Conditional expression on a 'Bool'.
  Cond        :: PreOpenExp acc env aenv Bool
              -> PreOpenExp acc env aenv t
              -> PreOpenExp acc env aenv t
              -> PreOpenExp acc env aenv t

  -- | Primitive constant (see 'PrimConst').
  PrimConst   :: Elt t
              => PrimConst t
              -> PreOpenExp acc env aenv t

  -- | Primitive operation applied to one argument. Binary primitives take
  -- their operands as a pair (see 'PrimFun').
  PrimApp     :: (Elt a, Elt r)
              => PrimFun (a -> r)
              -> PreOpenExp acc env aenv a
              -> PreOpenExp acc env aenv r

  -- | Read one element of an array at the given index.
  IndexScalar :: (Shape dim, Elt t)
              => acc aenv (Array dim t)
              -> PreOpenExp acc env aenv dim
              -> PreOpenExp acc env aenv t

  -- | Shape of an array.
  Shape       :: (Shape dim, Elt e)
              => acc aenv (Array dim e)
              -> PreOpenExp acc env aenv dim

  -- | Size of an array as an 'Int' (presumably its element count; confirm).
  Size        :: (Shape dim, Elt e)
              => acc aenv (Array dim e)
              -> PreOpenExp acc env aenv Int

-- | Open scalar expression over the concrete 'OpenAcc' representation.
type OpenExp = PreOpenExp OpenAcc

-- | Scalar expression with an empty scalar environment.
type PreExp acc = PreOpenExp acc ()

-- | Scalar expression over 'OpenAcc' with an empty scalar environment.
type Exp = OpenExp ()
-- | Primitive scalar constants. Each carries a type witness selecting the
-- type it is taken at.
data PrimConst ty where
  -- | Smallest value of a bounded type.
  PrimMinBound :: BoundedType a -> PrimConst a
  -- | Largest value of a bounded type.
  PrimMaxBound :: BoundedType a -> PrimConst a
  -- | The constant pi at a floating-point type.
  PrimPi       :: FloatingType a -> PrimConst a
-- | Primitive scalar operations.
--
-- The type-witness argument ('NumType', 'IntegralType', 'FloatingType',
-- 'ScalarType') selects the operand type. Binary operations take their
-- operands as a pair.
data PrimFun sig where

  -- operations on numeric types
  PrimAdd  :: NumType a -> PrimFun ((a, a) -> a)
  PrimSub  :: NumType a -> PrimFun ((a, a) -> a)
  PrimMul  :: NumType a -> PrimFun ((a, a) -> a)
  PrimNeg  :: NumType a -> PrimFun (a -> a)
  PrimAbs  :: NumType a -> PrimFun (a -> a)
  PrimSig  :: NumType a -> PrimFun (a -> a)

  -- operations on integral types
  PrimQuot :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimRem  :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimIDiv :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimMod  :: IntegralType a -> PrimFun ((a, a) -> a)

  -- bit operations; the shift/rotate amount is an 'Int'
  PrimBAnd     :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimBOr      :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimBXor     :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimBNot     :: IntegralType a -> PrimFun (a -> a)
  PrimBShiftL  :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBShiftR  :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a)

  -- operations on floating-point types
  PrimFDiv        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimRecip       :: FloatingType a -> PrimFun (a -> a)
  PrimSin         :: FloatingType a -> PrimFun (a -> a)
  PrimCos         :: FloatingType a -> PrimFun (a -> a)
  PrimTan         :: FloatingType a -> PrimFun (a -> a)
  PrimAsin        :: FloatingType a -> PrimFun (a -> a)
  PrimAcos        :: FloatingType a -> PrimFun (a -> a)
  PrimAtan        :: FloatingType a -> PrimFun (a -> a)
  PrimAsinh       :: FloatingType a -> PrimFun (a -> a)
  PrimAcosh       :: FloatingType a -> PrimFun (a -> a)
  PrimAtanh       :: FloatingType a -> PrimFun (a -> a)
  PrimExpFloating :: FloatingType a -> PrimFun (a -> a)
  PrimSqrt        :: FloatingType a -> PrimFun (a -> a)
  PrimLog         :: FloatingType a -> PrimFun (a -> a)
  PrimFPow        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimLogBase     :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimAtan2       :: FloatingType a -> PrimFun ((a, a) -> a)

  -- conversions from a floating-point type to an integral type
  PrimTruncate :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimRound    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimFloor    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimCeiling  :: FloatingType a -> IntegralType b -> PrimFun (a -> b)

  -- relational and equality operators, plus max/min, on scalar types
  PrimLt   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimGt   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimEq   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimNEq  :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimMax  :: ScalarType a -> PrimFun ((a, a) -> a )
  PrimMin  :: ScalarType a -> PrimFun ((a, a) -> a )

  -- logical operators on 'Bool'
  PrimLAnd :: PrimFun ((Bool, Bool) -> Bool)
  PrimLOr  :: PrimFun ((Bool, Bool) -> Bool)
  PrimLNot :: PrimFun (Bool -> Bool)

  -- conversions between 'Char' and 'Int'
  PrimOrd  :: PrimFun (Char -> Int)
  PrimChr  :: PrimFun (Int -> Char)

  -- other conversions
  PrimBoolToInt    :: PrimFun (Bool -> Int)
  PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)