module Data.Array.Accelerate.Interpreter (
run
) where
import Control.Monad
import Data.Bits
import Data.Char (chr, ord)
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Representation
import Data.Array.Accelerate.Array.Delayed
import Data.Array.Accelerate.AST
import qualified Data.Array.Accelerate.Smart as Sugar
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
run :: Sugar.Arrays a => Sugar.Acc a -> a
run = Sugar.toArrays . force . evalAcc . Sugar.convertAcc
data Val env where
Empty :: Val ()
Push :: Val env -> t -> Val (env, t)
prj :: Idx env t -> Val env -> t
prj ZeroIdx (Push _ v) = v
prj (SuccIdx idx) (Push val _) = prj idx val
prj _ _ =
error "Data.Array.Accelerate.Interpreter: prj: inconsistent valuation"
evalOpenAcc :: Delayable a => OpenAcc aenv a -> Val aenv -> Delayed a
evalOpenAcc (Let acc1 acc2) aenv
= let !arr1 = force $ evalOpenAcc acc1 aenv
in evalOpenAcc acc2 (aenv `Push` arr1)
evalOpenAcc (Let2 acc1 acc2) aenv
= let (!arr1, !arr2) = force $ evalOpenAcc acc1 aenv
in evalOpenAcc acc2 (aenv `Push` arr1 `Push` arr2)
evalOpenAcc (Avar idx) aenv = delay $ prj idx aenv
evalOpenAcc (Use arr) _aenv = delay arr
evalOpenAcc (Unit e) aenv = unitOp (evalExp e aenv)
evalOpenAcc (Reshape e acc) aenv
= reshapeOp (evalExp e aenv) (evalOpenAcc acc aenv)
evalOpenAcc (Replicate sliceIndex slix acc) aenv
= replicateOp sliceIndex (evalExp slix aenv) (evalOpenAcc acc aenv)
evalOpenAcc (Index sliceIndex acc slix) aenv
= indexOp sliceIndex (evalOpenAcc acc aenv) (evalExp slix aenv)
evalOpenAcc (Map f acc) aenv = mapOp (evalFun f aenv) (evalOpenAcc acc aenv)
evalOpenAcc (ZipWith f acc1 acc2) aenv
= zipWithOp (evalFun f aenv) (evalOpenAcc acc1 aenv) (evalOpenAcc acc2 aenv)
evalOpenAcc (Fold f e acc) aenv
= foldOp (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)
evalOpenAcc (Scan f e acc) aenv
= scanOp (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)
evalOpenAcc (Permute f dftAcc p acc) aenv
= permuteOp (evalFun f aenv) (evalOpenAcc dftAcc aenv)
(evalFun p aenv) (evalOpenAcc acc aenv)
evalOpenAcc (Backpermute e p acc) aenv
= backpermuteOp (evalExp e aenv) (evalFun p aenv) (evalOpenAcc acc aenv)
evalAcc :: Delayable a => Acc a -> Delayed a
evalAcc acc = evalOpenAcc acc Empty
unitOp :: ArrayElem e => e -> Delayed (Scalar e)
unitOp e = DelayedArray {shapeDA = (), repfDA = const e}
reshapeOp :: Ix dim => dim -> Delayed (Array dim' e) -> Delayed (Array dim e)
reshapeOp newShape darr@(DelayedArray {shapeDA = oldShape})
| size newShape == size oldShape
= let Array _ adata = force darr
in
delay $ Array newShape adata
| otherwise
= error "Data.Array.Accelerate.Interpreter.reshape: shape mismatch"
replicateOp :: Ix dim
=> SliceIndex slix sl co dim
-> slix
-> Delayed (Array sl e)
-> Delayed (Array dim e)
replicateOp sliceIndex slix (DelayedArray sh pf)
= DelayedArray sh' (pf . pf')
where
(sh', pf') = extend sliceIndex slix sh
extend :: SliceIndex slix sl co dim
-> slix
-> sl
-> (dim, dim -> sl)
extend SliceNil () () = ((), const ())
extend (SliceAll sliceIndex) (slix, ()) (sl, sz)
= let (dim', pf') = extend sliceIndex slix sl
in
((dim', sz), \(ix, i) -> (pf' ix, i))
extend (SliceFixed sliceIndex) (slix, sz) sl
= let (dim', pf') = extend sliceIndex slix sl
in
((dim', sz), \(ix, _) -> pf' ix)
indexOp :: Ix sl
=> SliceIndex slix sl co dim
-> Delayed (Array dim e)
-> slix
-> Delayed (Array sl e)
indexOp sliceIndex (DelayedArray sh pf) slix
= DelayedArray sh' (pf . pf')
where
(sh', pf') = restrict sliceIndex slix sh
restrict :: SliceIndex slix sl co dim
-> slix
-> dim
-> (sl, sl -> dim)
restrict SliceNil () () = ((), const ())
restrict (SliceAll sliceIndex) (slix, ()) (sh, sz)
= let (sl', pf') = restrict sliceIndex slix sh
in
((sl', sz), \(ix, i) -> (pf' ix, i))
restrict (SliceFixed sliceIndex) (slix, i) (sh, sz)
| i < sz
= let (sl', pf') = restrict sliceIndex slix sh
in
(sl', \ix -> (pf' ix, i))
| otherwise = error "Index out of bounds"
mapOp :: ArrayElem e'
=> (e -> e')
-> Delayed (Array dim e)
-> Delayed (Array dim e')
mapOp f (DelayedArray sh rf) = DelayedArray sh (f . rf)
zipWithOp :: ArrayElem e3
=> (e1 -> e2 -> e3)
-> Delayed (Array dim e1)
-> Delayed (Array dim e2)
-> Delayed (Array dim e3)
zipWithOp f (DelayedArray sh1 rf1) (DelayedArray sh2 rf2)
= DelayedArray (sh1 `intersect` sh2) (\ix -> f (rf1 ix) (rf2 ix))
foldOp :: (e -> e -> e)
-> e
-> Delayed (Array dim e)
-> Delayed (Scalar e)
foldOp f e (DelayedArray sh rf)
= unitOp $ iter sh rf f e
scanOp :: (e -> e -> e)
-> e
-> Delayed (Vector e)
-> Delayed (Vector e, Scalar e)
scanOp f e (DelayedArray sh rf)
= DelayedPair (delay $ adata `seq` Array sh adata) (unitOp final)
where
n = size sh
(adata, final) = runArrayData $ do
arr <- newArrayData n
final <- traverse arr 0 e
return (arr, final)
traverse arr i v
| i >= n = return v
| otherwise = do
writeArrayData arr i v
traverse arr (i + 1) (f v (rf ((), i)))
permuteOp :: (e -> e -> e)
-> Delayed (Array dim' e)
-> (dim -> dim')
-> Delayed (Array dim e)
-> Delayed (Array dim' e)
permuteOp f (DelayedArray dftsSh dftsPf) p (DelayedArray sh pf)
= delay $ adata `seq` Array dftsSh adata
where
(adata, _)
= runArrayData $ do
arr <- newArrayData (size dftsSh)
let write ix = writeArrayData arr (index dftsSh ix) (dftsPf ix)
iter dftsSh write (>>) (return ())
let update ix = do
let target = p ix
unless (target == ignore) $ do
let i = index dftsSh target
e <- readArrayData arr i
writeArrayData arr i (pf ix `f` e)
iter sh update (>>) (return ())
return (arr, undefined)
backpermuteOp :: Ix dim'
=> dim'
-> (dim' -> dim)
-> Delayed (Array dim e)
-> Delayed (Array dim' e)
backpermuteOp sh' p (DelayedArray _sh rf)
= DelayedArray sh' (rf . p)
evalOpenFun :: OpenFun env aenv t -> Val env -> Val aenv -> t
evalOpenFun (Body e) env aenv = evalOpenExp e env aenv
evalOpenFun (Lam f) env aenv = \x -> evalOpenFun f (env `Push` x) aenv
evalFun :: Fun aenv t -> Val aenv -> t
evalFun f aenv = evalOpenFun f Empty aenv
evalOpenExp :: OpenExp env aenv a -> Val env -> Val aenv -> a
evalOpenExp (Var idx) env _ = prj idx env
evalOpenExp (Const c) _ _ = Sugar.fromElem c
evalOpenExp (Pair ds dt e1 e2) env aenv
= evalPair ds dt (evalOpenExp e1 env aenv) (evalOpenExp e2 env aenv)
evalOpenExp (Fst ds dt e) env aenv
= evalFst ds dt (evalOpenExp e env aenv)
evalOpenExp (Snd ds dt e) env aenv
= evalSnd ds dt (evalOpenExp e env aenv)
evalOpenExp (Cond c t e) env aenv
= if Sugar.toElem (evalOpenExp c env aenv)
then evalOpenExp t env aenv
else evalOpenExp e env aenv
evalOpenExp (PrimConst c) _ _ = Sugar.fromElem $ evalPrimConst c
evalOpenExp (PrimApp p arg) env aenv
= Sugar.fromElem $ evalPrim p (Sugar.toElem (evalOpenExp arg env aenv))
evalOpenExp (IndexScalar acc ix) env aenv
= let ix' = evalOpenExp ix env aenv
in
case evalOpenAcc acc aenv of
DelayedArray sh pf -> index sh ix' `seq` pf ix'
evalOpenExp (Shape acc) _ aenv
= let Array sh _ = force $ evalOpenAcc acc aenv
in sh
evalExp :: Exp aenv t -> Val aenv -> t
evalExp e aenv = evalOpenExp e Empty aenv
evalPrimConst :: PrimConst a -> a
evalPrimConst (PrimMinBound ty) = evalMinBound ty
evalPrimConst (PrimMaxBound ty) = evalMaxBound ty
evalPrimConst (PrimPi ty) = evalPi ty
evalPrim :: PrimFun p -> p
evalPrim (PrimAdd ty) = evalAdd ty
evalPrim (PrimSub ty) = evalSub ty
evalPrim (PrimMul ty) = evalMul ty
evalPrim (PrimNeg ty) = evalNeg ty
evalPrim (PrimAbs ty) = evalAbs ty
evalPrim (PrimSig ty) = evalSig ty
evalPrim (PrimQuot ty) = evalQuot ty
evalPrim (PrimRem ty) = evalRem ty
evalPrim (PrimIDiv ty) = evalIDiv ty
evalPrim (PrimMod ty) = evalMod ty
evalPrim (PrimBAnd ty) = evalBAnd ty
evalPrim (PrimBOr ty) = evalBOr ty
evalPrim (PrimBXor ty) = evalBXor ty
evalPrim (PrimBNot ty) = evalBNot ty
evalPrim (PrimFDiv ty) = evalFDiv ty
evalPrim (PrimRecip ty) = evalRecip ty
evalPrim (PrimLt ty) = evalLt ty
evalPrim (PrimGt ty) = evalGt ty
evalPrim (PrimLtEq ty) = evalLtEq ty
evalPrim (PrimGtEq ty) = evalGtEq ty
evalPrim (PrimEq ty) = evalEq ty
evalPrim (PrimNEq ty) = evalNEq ty
evalPrim (PrimMax ty) = evalMax ty
evalPrim (PrimMin ty) = evalMin ty
evalPrim PrimLAnd = evalLAnd
evalPrim PrimLOr = evalLOr
evalPrim PrimLNot = evalLNot
evalPrim PrimOrd = evalOrd
evalPrim PrimChr = evalChr
evalPrim PrimRoundFloatInt = evalRoundFloatInt
evalPrim PrimTruncFloatInt = evalTruncFloatInt
evalPrim PrimIntFloat = evalIntFloat
evalPrim PrimBoolToInt = evalBoolToInt
evalPair :: forall s t. (Sugar.Elem s, Sugar.Elem t)
=> s
-> t
-> Sugar.ElemRepr s
-> Sugar.ElemRepr t
-> Sugar.ElemRepr (s, t)
evalPair _ _ x y = Sugar.fromElem (Sugar.toElem x :: s, Sugar.toElem y :: t)
evalFst :: forall s t. (Sugar.Elem s, Sugar.Elem t)
=> s
-> t
-> Sugar.ElemRepr (s, t)
-> Sugar.ElemRepr s
evalFst _ _ xy = let (x, !_) = Sugar.toElem xy :: (s, t)
in Sugar.fromElem x
evalSnd :: forall s t. (Sugar.Elem s, Sugar.Elem t)
=> s
-> t
-> Sugar.ElemRepr (s, t)
-> Sugar.ElemRepr t
evalSnd _ _ xy = let (!_, y) = Sugar.toElem xy :: (s, t)
in Sugar.fromElem y
evalLAnd :: (Bool, Bool) -> Bool
evalLAnd (!x, !y) = x && y
evalLOr :: (Bool, Bool) -> Bool
evalLOr (!x, !y) = x || y
evalLNot :: Bool -> Bool
evalLNot x = not x
evalOrd :: Char -> Int
evalOrd = ord
evalChr :: Int -> Char
evalChr = chr
evalRoundFloatInt :: Float -> Int
evalRoundFloatInt = round
evalTruncFloatInt :: Float -> Int
evalTruncFloatInt = truncate
evalIntFloat :: Int -> Float
evalIntFloat = fromIntegral
evalBoolToInt :: Bool -> Int
evalBoolToInt = fromEnum
evalMinBound :: BoundedType a -> a
evalMinBound (IntegralBoundedType ty)
| IntegralDict <- integralDict ty = minBound
evalMinBound (NonNumBoundedType ty)
| NonNumDict <- nonNumDict ty = minBound
evalMaxBound :: BoundedType a -> a
evalMaxBound (IntegralBoundedType ty)
| IntegralDict <- integralDict ty = maxBound
evalMaxBound (NonNumBoundedType ty)
| NonNumDict <- nonNumDict ty = maxBound
evalPi :: FloatingType a -> a
evalPi ty | FloatingDict <- floatingDict ty = pi
evalAdd :: NumType a -> ((a, a) -> a)
evalAdd (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (+)
evalAdd (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (+)
evalSub :: NumType a -> ((a, a) -> a)
evalSub (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry ()
evalSub (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry ()
evalMul :: NumType a -> ((a, a) -> a)
evalMul (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (*)
evalMul (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (*)
evalNeg :: NumType a -> (a -> a)
evalNeg (IntegralNumType ty) | IntegralDict <- integralDict ty = negate
evalNeg (FloatingNumType ty) | FloatingDict <- floatingDict ty = negate
evalAbs :: NumType a -> (a -> a)
evalAbs (IntegralNumType ty) | IntegralDict <- integralDict ty = abs
evalAbs (FloatingNumType ty) | FloatingDict <- floatingDict ty = abs
evalSig :: NumType a -> (a -> a)
evalSig (IntegralNumType ty) | IntegralDict <- integralDict ty = signum
evalSig (FloatingNumType ty) | FloatingDict <- floatingDict ty = signum
evalQuot :: IntegralType a -> ((a, a) -> a)
evalQuot ty | IntegralDict <- integralDict ty = uncurry quot
evalRem :: IntegralType a -> ((a, a) -> a)
evalRem ty | IntegralDict <- integralDict ty = uncurry rem
evalIDiv :: IntegralType a -> ((a, a) -> a)
evalIDiv ty | IntegralDict <- integralDict ty = uncurry div
evalMod :: IntegralType a -> ((a, a) -> a)
evalMod ty | IntegralDict <- integralDict ty = uncurry mod
evalBAnd :: IntegralType a -> ((a, a) -> a)
evalBAnd ty | IntegralDict <- integralDict ty = uncurry (.&.)
evalBOr :: IntegralType a -> ((a, a) -> a)
evalBOr ty | IntegralDict <- integralDict ty = uncurry (.|.)
evalBXor :: IntegralType a -> ((a, a) -> a)
evalBXor ty | IntegralDict <- integralDict ty = uncurry xor
evalBNot :: IntegralType a -> (a -> a)
evalBNot ty | IntegralDict <- integralDict ty = complement
evalFDiv :: FloatingType a -> ((a, a) -> a)
evalFDiv ty | FloatingDict <- floatingDict ty = uncurry (/)
evalRecip :: FloatingType a -> (a -> a)
evalRecip ty | FloatingDict <- floatingDict ty = recip
evalLt :: ScalarType a -> ((a, a) -> Bool)
evalLt (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry (<)
evalLt (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry (<)
evalLt (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry (<)
evalGt :: ScalarType a -> ((a, a) -> Bool)
evalGt (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry (>)
evalGt (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry (>)
evalGt (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry (>)
evalLtEq :: ScalarType a -> ((a, a) -> Bool)
evalLtEq (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry (<=)
evalLtEq (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry (<=)
evalLtEq (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry (<=)
evalGtEq :: ScalarType a -> ((a, a) -> Bool)
evalGtEq (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry (>=)
evalGtEq (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry (>=)
evalGtEq (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry (>=)
evalEq :: ScalarType a -> ((a, a) -> Bool)
evalEq (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry (==)
evalEq (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry (==)
evalEq (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry (==)
evalNEq :: ScalarType a -> ((a, a) -> Bool)
evalNEq (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry (/=)
evalNEq (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry (/=)
evalNEq (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry (/=)
evalMax :: ScalarType a -> ((a, a) -> a)
evalMax (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry max
evalMax (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry max
evalMax (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry max
evalMin :: ScalarType a -> ((a, a) -> a)
evalMin (NumScalarType (IntegralNumType ty))
| IntegralDict <- integralDict ty = uncurry min
evalMin (NumScalarType (FloatingNumType ty))
| FloatingDict <- floatingDict ty = uncurry min
evalMin (NonNumScalarType ty)
| NonNumDict <- nonNumDict ty = uncurry min