module Data.Array.Accelerate.AST (
Idx(..),
Val(..), prj,
OpenAcc(..), Acc, Stencil(..),
OpenFun(..), Fun, OpenExp(..), Exp, PrimConst(..), PrimFun(..)
) where
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Representation (SliceIndex)
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Tuple
#include "accelerate.h"
-- | Typed de Bruijn index into an environment.
--
-- An environment is a left-nested chain of pairs whose rightmost
-- component is the most recently bound entry; @Idx env t@ selects an
-- entry of type @t@ from @env@.
data Idx env t where
  -- | Select the most recently bound entry (the right component of the pair).
  ZeroIdx ::              Idx (env, t) t
  -- | Skip the most recently bound entry and index into the rest.
  SuccIdx :: Idx env t -> Idx (env, s) t
-- | Valuation: run-time values for every entry of the type-level
-- environment @env@ (the same left-nested pair shape that 'Idx' indexes).
data Val env where
  -- | The empty valuation.
  Empty :: Val ()
  -- | Extend a valuation with a new, most recently bound value.
  Push  :: Val env -> t -> Val (env, t)
-- | Look up the value that a de Bruijn index selects in a valuation.
--
-- The index is examined first. 'ZeroIdx' returns the newest value, and
-- each 'SuccIdx' drops one binding from the valuation and recurses. Any
-- index/valuation pair that does not line up reaches the internal error.
prj :: Idx env t -> Val env -> t
prj idx val =
  case idx of
    ZeroIdx ->
      case val of
        Push _ v -> v
        _        -> inconsistent
    SuccIdx idx' ->
      case val of
        Push rest _ -> prj idx' rest
        _           -> inconsistent
  where
    inconsistent = INTERNAL_ERROR(error) "prj" "inconsistent valuation"
-- | Array computations.
--
-- @aenv@ is the environment of array variables in scope (indexed by
-- 'Idx'); @a@ is the type of the computation's result.
data OpenAcc aenv a where

  -- | Bind the result of the first computation as the newest array
  -- variable, then evaluate the second computation in that extended
  -- environment.
  Let         :: OpenAcc aenv (Array dim e)
              -> OpenAcc (aenv, Array dim e)
                         (Array dim' e')
              -> OpenAcc aenv (Array dim' e')

  -- | Like 'Let', but for a computation producing a pair of arrays. Both
  -- components are bound, the second one becoming the newest variable.
  Let2        :: OpenAcc aenv (Array dim1 e1,
                               Array dim2 e2)
              -> OpenAcc ((aenv, Array dim1 e1),
                                 Array dim2 e2)
                         (Array dim' e')
              -> OpenAcc aenv (Array dim' e')

  -- | Reference an array variable of the environment by its index.
  Avar        :: (Ix dim, Elem e)
              => Idx aenv (Array dim e)
              -> OpenAcc aenv (Array dim e)

  -- | Embed a concrete array value.
  Use         :: Array dim e
              -> OpenAcc aenv (Array dim e)

  -- | Turn the value of a scalar expression into a 'Scalar' array.
  Unit        :: Elem e
              => Exp aenv e
              -> OpenAcc aenv (Scalar e)

  -- | Give an array the shape computed by the expression, keeping its
  -- element type.
  -- NOTE(review): presumably the element counts must agree; no such
  -- check is visible here.
  Reshape     :: Ix dim
              => Exp aenv dim
              -> OpenAcc aenv (Array dim' e)
              -> OpenAcc aenv (Array dim e)

  -- | Extend an array of shape @sl@ to shape @dim@. The extension is
  -- directed by a slice descriptor and a slice-index expression.
  Replicate   :: (Ix dim, Elem slix)
              => SliceIndex (ElemRepr slix)
                            (ElemRepr sl)
                            co'
                            (ElemRepr dim)
              -> Exp aenv slix
              -> OpenAcc aenv (Array sl e)
              -> OpenAcc aenv (Array dim e)

  -- | Extract a slice of shape @sl@ from an array of shape @dim@. The slice
  -- is chosen by a slice descriptor and a slice-index expression.
  Index       :: (Ix sl, Elem slix)
              => SliceIndex (ElemRepr slix)
                            (ElemRepr sl)
                            co'
                            (ElemRepr dim)
              -> OpenAcc aenv (Array dim e)
              -> Exp aenv slix
              -> OpenAcc aenv (Array sl e)

  -- | Apply a scalar function to every element; the shape type is kept.
  Map         :: Elem e'
              => Fun aenv (e -> e')
              -> OpenAcc aenv (Array dim e)
              -> OpenAcc aenv (Array dim e')

  -- | Combine elements of two arrays of the same shape type with a binary
  -- scalar function.
  ZipWith     :: Elem e3
              => Fun aenv (e1 -> e2 -> e3)
              -> OpenAcc aenv (Array dim e1)
              -> OpenAcc aenv (Array dim e2)
              -> OpenAcc aenv (Array dim e3)

  -- | Reduce a whole array to a 'Scalar', using a binary operator and a
  -- starting value.
  Fold        :: Fun aenv (e -> e -> e)
              -> Exp aenv e
              -> OpenAcc aenv (Array dim e)
              -> OpenAcc aenv (Scalar e)

  -- | Segmented reduction of a vector. The 'Segments' argument describes
  -- the segments; presumably the result holds one value per segment
  -- (TODO confirm).
  FoldSeg     :: Fun aenv (e -> e -> e)
              -> Exp aenv e
              -> OpenAcc aenv (Vector e)
              -> OpenAcc aenv Segments
              -> OpenAcc aenv (Vector e)

  -- | Left-to-right scan of a vector. It yields the scanned vector together
  -- with a 'Scalar' (presumably the overall reduction; TODO confirm).
  Scanl       :: Fun aenv (e -> e -> e)
              -> Exp aenv e
              -> OpenAcc aenv (Vector e)
              -> OpenAcc aenv (Vector e, Scalar e)

  -- | Right-to-left counterpart of 'Scanl'; it has the same result shape.
  Scanr       :: Fun aenv (e -> e -> e)
              -> Exp aenv e
              -> OpenAcc aenv (Vector e)
              -> OpenAcc aenv (Vector e, Scalar e)

  -- | Forward permutation. Arguments, in order:
  --
  --   * a combining function (presumably for values sent to the same
  --     target index);
  --   * a default array of the target shape;
  --   * a function from source index to target index;
  --   * the source array.
  Permute     :: Fun aenv (e -> e -> e)
              -> OpenAcc aenv (Array dim' e)
              -> Fun aenv (dim -> dim')
              -> OpenAcc aenv (Array dim e)
              -> OpenAcc aenv (Array dim' e)

  -- | Backward permutation: build an array of the given target shape.
  -- Each element is read from the source array at the index that the
  -- function computes from the target index.
  Backpermute :: Ix dim'
              => Exp aenv dim'
              -> Fun aenv (dim' -> dim)
              -> OpenAcc aenv (Array dim e)
              -> OpenAcc aenv (Array dim' e)

  -- | Apply a function to the stencil neighbourhood (see the 'Stencil'
  -- class) of each element. The 'Boundary' value presumably governs reads
  -- outside the array.
  Stencil     :: (Elem e, Elem e', Stencil dim e stencil)
              => Fun aenv (stencil -> e')
              -> Boundary (ElemRepr e)
              -> OpenAcc aenv (Array dim e)
              -> OpenAcc aenv (Array dim e')

  -- | Binary form of 'Stencil'. It takes one neighbourhood from each of
  -- two arrays of the same shape type, each with its own 'Boundary'.
  Stencil2    :: (Elem e1, Elem e2, Elem e',
                  Stencil dim e1 stencil1,
                  Stencil dim e2 stencil2)
              => Fun aenv (stencil1 ->
                           stencil2 -> e')
              -> Boundary (ElemRepr e1)
              -> OpenAcc aenv (Array dim e1)
              -> Boundary (ElemRepr e2)
              -> OpenAcc aenv (Array dim e2)
              -> OpenAcc aenv (Array dim e')

-- | Closed array computations: no array variables are in scope.
type Acc a = OpenAcc () a
-- | Stencil patterns.
--
-- Arguments of 'stencilAccess':
--
--   * @rf@, a function reading an element at an index of type @dim@;
--   * a focal index.
--
-- It gathers the neighbourhood of the focal index into the tuple type
-- @stencil@, whose shape fixes the size of the neighbourhood.
class IsTuple stencil => Stencil dim e stencil where
  stencilAccess :: (dim -> e) -> dim -> stencil
-- | 3-point 1-D stencil: the elements at offsets -1, 0 and +1 from the
-- focal index, in that order.
instance Stencil DIM1 a (a, a, a) where
  stencilAccess rf ix = (rf (ix - 1), rf ix, rf (ix + 1))
-- | 5-point 1-D stencil: the elements at offsets -2 .. +2 from the focal
-- index, in ascending order.
instance Stencil DIM1 a (a, a, a, a, a) where
  stencilAccess rf ix = (rf (ix - 2), rf (ix - 1), rf ix, rf (ix + 1), rf (ix + 2))
-- | 7-point 1-D stencil: the elements at offsets -3 .. +3 from the focal
-- index, in ascending order.
instance Stencil DIM1 a (a, a, a, a, a, a, a) where
  stencilAccess rf ix = (rf (ix - 3), rf (ix - 2), rf (ix - 1), rf ix,
                         rf (ix + 1), rf (ix + 2), rf (ix + 3))
-- | 9-point 1-D stencil: the elements at offsets -4 .. +4 from the focal
-- index, in ascending order.
instance Stencil DIM1 a (a, a, a, a, a, a, a, a, a) where
  stencilAccess rf ix = (rf (ix - 4), rf (ix - 3), rf (ix - 2), rf (ix - 1), rf ix,
                         rf (ix + 1), rf (ix + 2), rf (ix + 3), rf (ix + 4))
-- | 3-row 2-D stencil. Each row is a 1-D stencil along @x@, taken from
-- rows @y - 1@, @y@ and @y + 1@ in that order.
instance (Stencil DIM1 a row1,
          Stencil DIM1 a row2,
          Stencil DIM1 a row3) => Stencil DIM2 a (row1, row2, row3) where
  stencilAccess rf (x, y) = (stencilAccess (rf' (y - 1)) x,
                             stencilAccess (rf'  y     ) x,
                             stencilAccess (rf' (y + 1)) x)
    where
      -- Fix the row coordinate, leaving a 1-D read function over columns.
      -- Fresh parameter names avoid shadowing the outer x and y.
      rf' row col = rf (col, row)
-- | 5-row 2-D stencil. Each row is a 1-D stencil along @x@, taken from
-- rows @y - 2@ .. @y + 2@ in ascending order.
instance (Stencil DIM1 a row1,
          Stencil DIM1 a row2,
          Stencil DIM1 a row3,
          Stencil DIM1 a row4,
          Stencil DIM1 a row5) => Stencil DIM2 a (row1, row2, row3, row4, row5) where
  stencilAccess rf (x, y) = (stencilAccess (rf' (y - 2)) x,
                             stencilAccess (rf' (y - 1)) x,
                             stencilAccess (rf'  y     ) x,
                             stencilAccess (rf' (y + 1)) x,
                             stencilAccess (rf' (y + 2)) x)
    where
      -- Fix the row coordinate, leaving a 1-D read function over columns.
      rf' row col = rf (col, row)
-- | 7-row 2-D stencil. Each row is a 1-D stencil along @x@, taken from
-- rows @y - 3@ .. @y + 3@ in ascending order.
instance (Stencil DIM1 a row1,
          Stencil DIM1 a row2,
          Stencil DIM1 a row3,
          Stencil DIM1 a row4,
          Stencil DIM1 a row5,
          Stencil DIM1 a row6,
          Stencil DIM1 a row7) => Stencil DIM2 a (row1, row2, row3, row4, row5, row6, row7) where
  stencilAccess rf (x, y) = (stencilAccess (rf' (y - 3)) x,
                             stencilAccess (rf' (y - 2)) x,
                             stencilAccess (rf' (y - 1)) x,
                             stencilAccess (rf'  y     ) x,
                             stencilAccess (rf' (y + 1)) x,
                             stencilAccess (rf' (y + 2)) x,
                             stencilAccess (rf' (y + 3)) x)
    where
      -- Fix the row coordinate, leaving a 1-D read function over columns.
      rf' row col = rf (col, row)
-- | 9-row 2-D stencil. Each row is a 1-D stencil along @x@, taken from
-- rows @y - 4@ .. @y + 4@ in ascending order.
instance (Stencil DIM1 a row1,
          Stencil DIM1 a row2,
          Stencil DIM1 a row3,
          Stencil DIM1 a row4,
          Stencil DIM1 a row5,
          Stencil DIM1 a row6,
          Stencil DIM1 a row7,
          Stencil DIM1 a row8,
          Stencil DIM1 a row9)
  => Stencil DIM2 a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where
  stencilAccess rf (x, y) = (stencilAccess (rf' (y - 4)) x,
                             stencilAccess (rf' (y - 3)) x,
                             stencilAccess (rf' (y - 2)) x,
                             stencilAccess (rf' (y - 1)) x,
                             stencilAccess (rf'  y     ) x,
                             stencilAccess (rf' (y + 1)) x,
                             stencilAccess (rf' (y + 2)) x,
                             stencilAccess (rf' (y + 3)) x,
                             stencilAccess (rf' (y + 4)) x)
    where
      -- Fix the row coordinate, leaving a 1-D read function over columns.
      rf' row col = rf (col, row)
-- | 3-plane 3-D stencil. Each plane is a 2-D stencil over @(x, y)@, taken
-- from planes @z - 1@, @z@ and @z + 1@ in that order.
instance (Stencil DIM2 a row1,
          Stencil DIM2 a row2,
          Stencil DIM2 a row3) => Stencil DIM3 a (row1, row2, row3) where
  stencilAccess rf (x, y, z) = (stencilAccess (rf' (z - 1)) (x, y),
                                stencilAccess (rf'  z     ) (x, y),
                                stencilAccess (rf' (z + 1)) (x, y))
    where
      -- Fix the plane coordinate, leaving a 2-D read function.
      -- Fresh parameter names avoid shadowing the outer x, y and z.
      rf' plane (col, row) = rf (col, row, plane)
-- | 5-plane 3-D stencil. Each plane is a 2-D stencil over @(x, y)@, taken
-- from planes @z - 2@ .. @z + 2@ in ascending order.
instance (Stencil DIM2 a row1,
          Stencil DIM2 a row2,
          Stencil DIM2 a row3,
          Stencil DIM2 a row4,
          Stencil DIM2 a row5) => Stencil DIM3 a (row1, row2, row3, row4, row5) where
  stencilAccess rf (x, y, z) = (stencilAccess (rf' (z - 2)) (x, y),
                                stencilAccess (rf' (z - 1)) (x, y),
                                stencilAccess (rf'  z     ) (x, y),
                                stencilAccess (rf' (z + 1)) (x, y),
                                stencilAccess (rf' (z + 2)) (x, y))
    where
      -- Fix the plane coordinate, leaving a 2-D read function.
      rf' plane (col, row) = rf (col, row, plane)
-- | 7-plane 3-D stencil. Each plane is a 2-D stencil over @(x, y)@, taken
-- from planes @z - 3@ .. @z + 3@ in ascending order.
instance (Stencil DIM2 a row1,
          Stencil DIM2 a row2,
          Stencil DIM2 a row3,
          Stencil DIM2 a row4,
          Stencil DIM2 a row5,
          Stencil DIM2 a row6,
          Stencil DIM2 a row7) => Stencil DIM3 a (row1, row2, row3, row4, row5, row6, row7) where
  stencilAccess rf (x, y, z) = (stencilAccess (rf' (z - 3)) (x, y),
                                stencilAccess (rf' (z - 2)) (x, y),
                                stencilAccess (rf' (z - 1)) (x, y),
                                stencilAccess (rf'  z     ) (x, y),
                                stencilAccess (rf' (z + 1)) (x, y),
                                stencilAccess (rf' (z + 2)) (x, y),
                                stencilAccess (rf' (z + 3)) (x, y))
    where
      -- Fix the plane coordinate, leaving a 2-D read function.
      rf' plane (col, row) = rf (col, row, plane)
-- | 9-plane 3-D stencil. Each plane is a 2-D stencil over @(x, y)@, taken
-- from planes @z - 4@ .. @z + 4@ in ascending order.
instance (Stencil DIM2 a row1,
          Stencil DIM2 a row2,
          Stencil DIM2 a row3,
          Stencil DIM2 a row4,
          Stencil DIM2 a row5,
          Stencil DIM2 a row6,
          Stencil DIM2 a row7,
          Stencil DIM2 a row8,
          Stencil DIM2 a row9)
  => Stencil DIM3 a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where
  stencilAccess rf (x, y, z) = (stencilAccess (rf' (z - 4)) (x, y),
                                stencilAccess (rf' (z - 3)) (x, y),
                                stencilAccess (rf' (z - 2)) (x, y),
                                stencilAccess (rf' (z - 1)) (x, y),
                                stencilAccess (rf'  z     ) (x, y),
                                stencilAccess (rf' (z + 1)) (x, y),
                                stencilAccess (rf' (z + 2)) (x, y),
                                stencilAccess (rf' (z + 3)) (x, y),
                                stencilAccess (rf' (z + 4)) (x, y))
    where
      -- Fix the plane coordinate, leaving a 2-D read function.
      rf' plane (col, row) = rf (col, row, plane)
-- | Scalar functions over a scalar environment @env@ and an array
-- environment @aenv@. A function is a chain of 'Lam' binders ending in a
-- 'Body'.
data OpenFun env aenv t where
  -- | The function body: a scalar expression.
  Body :: OpenExp env aenv t -> OpenFun env aenv t
  -- | Abstract over one argument of type @a@. The argument becomes the
  -- newest entry of the scalar environment, stored as @ElemRepr a@.
  Lam  :: Elem a
       => OpenFun (env, ElemRepr a) aenv t -> OpenFun env aenv (a -> t)

-- | Scalar functions with an empty scalar environment.
type Fun aenv t = OpenFun () aenv t
-- | Scalar expressions.
--
-- @env@ is the scalar environment, whose entries are stored in
-- 'ElemRepr' form; @aenv@ is the array environment; @t@ is the result type.
data OpenExp env aenv t where

  -- | Scalar variable, referenced by its index into the scalar environment.
  Var         :: Elem t
              => Idx env (ElemRepr t)
              -> OpenExp env aenv t

  -- | Constant, given in its 'ElemRepr' representation.
  Const       :: Elem t
              => ElemRepr t
              -> OpenExp env aenv t

  -- | Build a tuple from one expression per component.
  Tuple       :: (Elem t, IsTuple t)
              => Tuple (OpenExp env aenv) (TupleRepr t)
              -> OpenExp env aenv t

  -- | Project the component selected by a 'TupleIdx' out of a tuple.
  Prj         :: (Elem t, IsTuple t)
              => TupleIdx (TupleRepr t) e
              -> OpenExp env aenv t
              -> OpenExp env aenv e

  -- | Conditional: a 'Bool' condition and two branches of the same type.
  Cond        :: OpenExp env aenv Bool
              -> OpenExp env aenv t
              -> OpenExp env aenv t
              -> OpenExp env aenv t

  -- | Primitive constant (see 'PrimConst').
  PrimConst   :: Elem t
              => PrimConst t -> OpenExp env aenv t

  -- | Apply a primitive function (see 'PrimFun'). Multi-argument
  -- primitives take their arguments as a single tuple expression.
  PrimApp     :: (Elem a, Elem r)
              => PrimFun (a -> r)
              -> OpenExp env aenv a
              -> OpenExp env aenv r

  -- | Read one element of an array computation at the given index.
  IndexScalar :: OpenAcc aenv (Array dim t)
              -> OpenExp env aenv dim
              -> OpenExp env aenv t

  -- | The shape of an array computation.
  Shape       :: Elem dim
              => OpenAcc aenv (Array dim e)
              -> OpenExp env aenv dim

-- | Scalar expressions with an empty scalar environment.
type Exp aenv t = OpenExp () aenv t
-- | Primitive scalar constants.
--
-- Each constructor carries the type descriptor of the value it denotes.
data PrimConst ty where
  -- | Lower bound of a bounded type (presumably 'minBound').
  PrimMinBound :: BoundedType a -> PrimConst a
  -- | Upper bound of a bounded type (presumably 'maxBound').
  PrimMaxBound :: BoundedType a -> PrimConst a
  -- | The constant pi, at a floating-point type.
  PrimPi       :: FloatingType a -> PrimConst a
-- | Primitive scalar functions, indexed by their signature.
--
-- Binary primitives take their operands as a pair (@(a, a) -> r@). Most
-- constructors carry the type descriptor of the operands they act on.
-- The meanings given in the comments below follow the constructor names
-- and are not otherwise checked here.
data PrimFun sig where

  -- Operations on numeric types ('NumType'):
  -- add, subtract, multiply, negate, abs, signum.
  PrimAdd  :: NumType a -> PrimFun ((a, a) -> a)
  PrimSub  :: NumType a -> PrimFun ((a, a) -> a)
  PrimMul  :: NumType a -> PrimFun ((a, a) -> a)
  PrimNeg  :: NumType a -> PrimFun (a -> a)
  PrimAbs  :: NumType a -> PrimFun (a -> a)
  PrimSig  :: NumType a -> PrimFun (a -> a)

  -- Operations on integral types ('IntegralType'):
  -- quot, rem, div, mod.
  PrimQuot :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimRem  :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimIDiv :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimMod  :: IntegralType a -> PrimFun ((a, a) -> a)

  -- Bitwise operations on integral types. Shifts and rotations take the
  -- amount as an 'Int'.
  PrimBAnd     :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimBOr      :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimBXor     :: IntegralType a -> PrimFun ((a, a) -> a)
  PrimBNot     :: IntegralType a -> PrimFun (a -> a)
  PrimBShiftL  :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBShiftR  :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a)

  -- Operations on floating-point types ('FloatingType'):
  -- division and reciprocal.
  PrimFDiv        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimRecip       :: FloatingType a -> PrimFun (a -> a)

  -- Trigonometric and hyperbolic functions.
  PrimSin         :: FloatingType a -> PrimFun (a -> a)
  PrimCos         :: FloatingType a -> PrimFun (a -> a)
  PrimTan         :: FloatingType a -> PrimFun (a -> a)
  PrimAsin        :: FloatingType a -> PrimFun (a -> a)
  PrimAcos        :: FloatingType a -> PrimFun (a -> a)
  PrimAtan        :: FloatingType a -> PrimFun (a -> a)
  PrimAsinh       :: FloatingType a -> PrimFun (a -> a)
  PrimAcosh       :: FloatingType a -> PrimFun (a -> a)
  PrimAtanh       :: FloatingType a -> PrimFun (a -> a)

  -- Exponential, square root and logarithms.
  PrimExpFloating :: FloatingType a -> PrimFun (a -> a)
  PrimSqrt        :: FloatingType a -> PrimFun (a -> a)
  PrimLog         :: FloatingType a -> PrimFun (a -> a)

  -- Binary power, logBase and atan2.
  PrimFPow        :: FloatingType a -> PrimFun ((a,a) -> a)
  PrimLogBase     :: FloatingType a -> PrimFun ((a,a) -> a)
  PrimAtan2       :: FloatingType a -> PrimFun ((a,a) -> a)

  -- Comparisons on scalar types ('ScalarType'), returning 'Bool'.
  PrimLt   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimGt   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimEq   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimNEq  :: ScalarType a -> PrimFun ((a, a) -> Bool)

  -- Maximum and minimum of two scalars.
  PrimMax  :: ScalarType a -> PrimFun ((a, a) -> a )
  PrimMin  :: ScalarType a -> PrimFun ((a, a) -> a )

  -- Logical operators on 'Bool'.
  PrimLAnd :: PrimFun ((Bool, Bool) -> Bool)
  PrimLOr  :: PrimFun ((Bool, Bool) -> Bool)
  PrimLNot :: PrimFun (Bool -> Bool)

  -- Conversions between fixed types ('Char', 'Int', 'Float', 'Bool').
  -- NOTE(review): these are monomorphic; other numeric types have no
  -- conversions here.
  PrimOrd           :: PrimFun (Char -> Int)
  PrimChr           :: PrimFun (Int  -> Char)
  PrimRoundFloatInt :: PrimFun (Float -> Int)
  PrimTruncFloatInt :: PrimFun (Float -> Int)
  PrimIntFloat      :: PrimFun (Int -> Float)
  PrimBoolToInt     :: PrimFun (Bool -> Int)