{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Exp
where
import Control.Applicative hiding ( Const )
import Control.Monad
import Prelude hiding ( exp, any )
import qualified Data.IntMap as IM
import Data.Array.Accelerate.AST hiding ( Val(..), prj )
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar hiding ( Foreign, toTuple, shape, intersect, union )
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Sugar as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad ( CodeGen )
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Foreign
import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop as L
import qualified Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
{-# INLINEABLE llvmOfFun1 #-}
-- | Lower a unary scalar function to an 'IRFun1'.
--
-- The result, when given the operand for the lambda-bound variable, generates
-- code for the function body via 'llvmOfOpenExp', using a scalar environment
-- that contains only that operand. A function that is not exactly one 'Lam'
-- wrapped around a 'Body' is reported as an internal error.
--
llvmOfFun1
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b)
    -> Gamma aenv
    -> IRFun1 arch aenv (a -> b)
llvmOfFun1 arch fun aenv =
  case fun of
    Lam (Body body) -> IRFun1 (\x -> llvmOfOpenExp arch body (Empty `Push` x) aenv)
    _               -> $internalError "llvmOfFun1" "impossible evaluation"
{-# INLINEABLE llvmOfFun2 #-}
-- | Lower a binary scalar function to an 'IRFun2'.
--
-- The result, when given operands for the two lambda-bound variables,
-- generates code for the function body via 'llvmOfOpenExp'. The operands are
-- pushed onto an empty scalar environment in argument order (first argument
-- first). A function that is not exactly two 'Lam's wrapped around a 'Body' is
-- reported as an internal error.
--
llvmOfFun2
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b -> c)
    -> Gamma aenv
    -> IRFun2 arch aenv (a -> b -> c)
llvmOfFun2 arch fun aenv =
  case fun of
    Lam (Lam (Body body)) -> IRFun2 (\x y -> llvmOfOpenExp arch body (Empty `Push` x `Push` y) aenv)
    _                     -> $internalError "llvmOfFun2" "impossible evaluation"
{-# INLINEABLE llvmOfOpenExp #-}
-- | Generate code for an open scalar expression.
--
-- The 'Val' argument supplies the operands of the scalar variables that are in
-- scope ('Var' is resolved with 'prj'), and the 'Gamma' argument is used to
-- look up manifest array variables (with 'aprj'). All the helpers live in the
-- @where@ clause because they close over @arch@, @env@ and @aenv@.
--
llvmOfOpenExp
    :: forall arch env aenv _t. Foreign arch
    => arch
    -> DelayedOpenExp env aenv _t
    -> Val env
    -> Gamma aenv
    -> IROpenExp arch env aenv _t
llvmOfOpenExp arch top env aenv = cvtE top
  where
    -- Array arguments of scalar expressions must be manifest array variables;
    -- anything else is an internal error.
    cvtM :: DelayedOpenAcc aenv (Array sh e) -> IRManifest arch aenv (Array sh e)
    cvtM acc =
      case acc of
        Manifest (Avar ix) -> IRManifest ix
        _                  -> $internalError "llvmOfOpenExp" "expected manifest array variable"

    -- Unary function in the current scalar environment: the body is compiled
    -- with the argument operand pushed onto 'env'.
    cvtF1 :: DelayedOpenFun env aenv (a -> b) -> IROpenFun1 arch env aenv (a -> b)
    cvtF1 fun =
      case fun of
        Lam (Body body) -> IRFun1 (\x -> llvmOfOpenExp arch body (env `Push` x) aenv)
        _               -> $internalError "cvtF1" "impossible evaluation"

    -- The main dispatch: one equation per expression constructor.
    cvtE :: forall t. DelayedOpenExp env aenv t -> IROpenExp arch env aenv t
    cvtE (Let bnd body)             = cvtE bnd >>= \x -> llvmOfOpenExp arch body (env `Push` x) aenv
    cvtE (Var ix)                   = return (prj ix env)
    cvtE (Const c)                  = return (IR (constant (eltType (undefined::t)) c))
    cvtE (PrimConst c)              = return (IR (constant (eltType (undefined::t)) (fromElt (primConst c))))
    cvtE (PrimApp f x)              = primFun f x
    cvtE IndexNil                   = return indexNil
    cvtE IndexAny                   = return indexAny
    cvtE (IndexCons sh sz)          = indexCons <$> cvtE sh <*> cvtE sz
    cvtE (IndexHead ix)             = fmap indexHead (cvtE ix)
    cvtE (IndexTail ix)             = fmap indexTail (cvtE ix)
    cvtE (Prj ix tup)               = fmap (prjT ix) (cvtE tup)
    cvtE (Tuple tup)                = cvtT tup
    cvtE (Foreign asm f x)          = cvtE x >>= foreignE asm f
    cvtE (Cond c t e)               = A.ifThenElse (cvtE c) (cvtE t) (cvtE e)
    cvtE (IndexSlice slice slix sh) = indexSlice slice <$> cvtE slix <*> cvtE sh
    cvtE (IndexFull slice slix sh)  = indexFull slice <$> cvtE slix <*> cvtE sh
    cvtE (ToIndex sh ix)            = do { sh' <- cvtE sh; ix' <- cvtE ix; intOfIndex sh' ix' }
    cvtE (FromIndex sh ix)          = do { sh' <- cvtE sh; ix' <- cvtE ix; indexOfInt sh' ix' }
    cvtE (Index acc ix)             = cvtE ix >>= index (cvtM acc)
    cvtE (LinearIndex acc ix)       = cvtE ix >>= linearIndex (cvtM acc)
    cvtE (ShapeSize sh)             = cvtE sh >>= shapeSize
    cvtE (Shape acc)                = return (shape (cvtM acc))
    cvtE (Intersect sh1 sh2)        = do { a <- cvtE sh1; b <- cvtE sh2; intersect a b }
    cvtE (Union sh1 sh2)            = do { a <- cvtE sh1; b <- cvtE sh2; union a b }
    cvtE (While c f x)              = while (cvtF1 c) (cvtF1 f) (cvtE x)

    -- The empty shape, as a constant.
    indexNil :: IR Z
    indexNil = IR (constant (eltType Z) (fromElt Z))

    -- The 'Any' shape marker, as a constant.
    indexAny :: forall sh. Shape sh => IR (Any sh)
    indexAny = IR (constant (eltType anySh) (fromElt anySh))
      where
        anySh :: Any sh
        anySh = Any

    -- Compute the shape of a slice from the slice index and the full shape:
    -- dimensions marked 'SliceAll' keep their extent, dimensions marked
    -- 'SliceFixed' are dropped.
    indexSlice :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
               -> IR slix
               -> IR sh
               -> IR sl
    indexSlice slice (IR slix) (IR sh) = IR (restrict slice slix sh)
      where
        restrict :: SliceIndex slix sl co sh -> Operands slix -> Operands sh -> Operands sl
        restrict SliceNil           OP_Unit                OP_Unit               = OP_Unit
        restrict (SliceAll rest)    (OP_Pair slx' OP_Unit) (OP_Pair sh' extent)  = OP_Pair (restrict rest slx' sh') extent
        restrict (SliceFixed rest)  (OP_Pair slx' _i)      (OP_Pair sh' _extent) = restrict rest slx' sh'

    -- Rebuild a full shape from the slice index and the slice shape:
    -- dimensions marked 'SliceAll' take their component from the slice shape,
    -- dimensions marked 'SliceFixed' take it from the slice index.
    indexFull :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
              -> IR slix
              -> IR sl
              -> IR sh
    indexFull slice (IR slix) (IR sh) = IR (extend slice slix sh)
      where
        extend :: SliceIndex slix sl co sh -> Operands slix -> Operands sl -> Operands sh
        extend SliceNil           OP_Unit                OP_Unit              = OP_Unit
        extend (SliceAll rest)    (OP_Pair slx' OP_Unit) (OP_Pair sl' extent) = OP_Pair (extend rest slx' sl') extent
        extend (SliceFixed rest)  (OP_Pair slx' fixed)   sl'                  = OP_Pair (extend rest slx' sl') fixed

    -- Project one component out of a tuple by walking the tuple index and the
    -- tuple's type/operand structure in lock-step.
    prjT :: forall t e. (Elt t, Elt e) => TupleIdx (TupleRepr t) e -> IR t -> IR e
    prjT tix (IR ops) = IR (walk tix (eltType (undefined::t)) ops)
      where
        walk :: TupleIdx v e -> TupleType t' -> Operands t' -> Operands (EltRepr e)
        walk ZeroTupIdx (PairTuple _ ty) (OP_Pair _ here)
          -- the type witness is what lets the component be returned at type 'e'
          | Just Refl <- matchTupleType ty (eltType (undefined :: e))
          = here
        walk (SuccTupIdx rest) (PairTuple ty _) (OP_Pair inner _) = walk rest ty inner
        walk _ _ _ = $internalError "prjT" "inconsistent valuation"

    -- Build a tuple: code for the components is generated left to right.
    cvtT :: forall t. (Elt t, IsTuple t) => Tuple (DelayedOpenExp env aenv) (TupleRepr t) -> CodeGen (IR t)
    cvtT tup = fmap IR (build (eltType (undefined::t)) tup)
      where
        build :: TupleType t' -> Tuple (DelayedOpenExp env aenv) tup -> CodeGen (Operands t')
        build UnitTuple NilTup
          = return OP_Unit
        build (PairTuple tyInit tyLast) (SnocTup rest (lastE :: DelayedOpenExp env aenv b))
          | Just Refl <- matchTupleType tyLast (eltType (undefined::b))
          = do rest'    <- build tyInit rest
               IR last' <- cvtE lastE
               return (OP_Pair rest' last')
        build _ _ = $internalError "cvtT" "impossible evaluation"

    -- Read a manifest array at a linear index.
    linearIndex :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR Int -> CodeGen (IR e)
    linearIndex (IRManifest v) ix = readArray arr ix
      where
        arr = irArray (aprj v aenv)

    -- Read a manifest array at a multidimensional index, which is first
    -- converted to a linear index using the array's shape.
    index :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh -> CodeGen (IR e)
    index (IRManifest v) ix = do
      let arr = irArray (aprj v aenv)
      i <- intOfIndex (irArrayShape arr) ix
      readArray arr i

    -- The shape of a manifest array.
    shape :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh
    shape (IRManifest v) = irArrayShape (irArray (aprj v aenv))

    -- Number of elements in a shape: the unit shape yields the constant 1, and
    -- each further dimension multiplies the size of the outer shape by its
    -- extent.
    shapeSize :: forall sh. Shape sh => IR sh -> CodeGen (IR Int)
    shapeSize (IR extent) = size (eltType (undefined::sh)) extent
      where
        size :: TupleType t -> Operands t -> CodeGen (IR Int)
        size UnitTuple OP_Unit
          = return (IR (constant (eltType (undefined :: Int)) 1))
        size (PairTuple tsh tsz) (OP_Pair sh sz)
          | Just Refl <- matchTupleType tsz (eltType (undefined::Int))
          = do outer <- size tsh sh
               A.mul numType outer (IR sz)
        size (SingleTuple t) (op' t -> i)
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = return (ir t i)
        size _ _
          = $internalError "shapeSize" "expected shape with Int components"

    -- Component-wise minimum of two shapes. Within a pair, code for the second
    -- component is emitted before code for the first.
    intersect :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    intersect (IR extent1) (IR extent2) = fmap IR (merge (eltType (undefined::sh)) extent1 extent2)
      where
        merge :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        merge UnitTuple OP_Unit OP_Unit
          = return OP_Unit
        merge (SingleTuple t) a b
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = do IR lo <- A.min t (IR a) (IR b)
               return lo
        merge (PairTuple tsh tsz) (OP_Pair shA szA) (OP_Pair shB szB)
          = do sz' <- merge tsz szA szB
               sh' <- merge tsh shA shB
               return (OP_Pair sh' sz')
        merge _ _ _
          = $internalError "intersect" "expected shape with Int components"

    -- Component-wise maximum of two shapes. Within a pair, code for the second
    -- component is emitted before code for the first.
    union :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    union (IR extent1) (IR extent2) = fmap IR (merge (eltType (undefined::sh)) extent1 extent2)
      where
        merge :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        merge UnitTuple OP_Unit OP_Unit
          = return OP_Unit
        merge (SingleTuple t) a b
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = do IR hi <- A.max t (IR a) (IR b)
               return hi
        merge (PairTuple tsh tsz) (OP_Pair shA szA) (OP_Pair shB szB)
          = do sz' <- merge tsz szA szB
               sh' <- merge tsh shA shB
               return (OP_Pair sh' sz')
        merge _ _ _
          = $internalError "union" "expected shape with Int components"

    -- Scalar while loop: code for the initial value is generated first, and
    -- the result is handed to 'L.while' together with the predicate and the
    -- loop body.
    while :: Elt a
          => IROpenFun1 arch env aenv (a -> Bool)
          -> IROpenFun1 arch env aenv (a -> a)
          -> IROpenExp arch env aenv a
          -> IROpenExp arch env aenv a
    while p f x = x >>= L.while (app1 p) (app1 f)

    -- Foreign scalar function: use the implementation returned by
    -- 'foreignExp' for this backend when there is one, otherwise generate code
    -- for the fallback function in an empty array environment.
    foreignE :: (Elt a, Elt b, Foreign arch, A.Foreign asm)
             => asm (a -> b)
             -> DelayedFun () (a -> b)
             -> IR a
             -> IRExp arch () b
    foreignE asm no x
      | Just f <- foreignExp arch asm = app1 f x
      | Lam (Body body) <- no         = llvmOfOpenExp arch body (Empty `Push` x) IM.empty
      | otherwise                     = error "when a grid's misaligned with another behind / that's a moiré..."

    -- Primitive function application. Most primitives generate code for the
    -- argument and then apply the corresponding arithmetic operation to it;
    -- binary primitives receive their arguments as a pair and are applied via
    -- 'A.uncurry'.
    primFun :: Elt r
            => PrimFun (a -> r)
            -> DelayedOpenExp env aenv a
            -> CodeGen (IR r)
    primFun f x =
      case f of
        PrimAdd t                 -> cvtE x >>= A.uncurry (A.add t)
        PrimSub t                 -> cvtE x >>= A.uncurry (A.sub t)
        PrimMul t                 -> cvtE x >>= A.uncurry (A.mul t)
        PrimNeg t                 -> cvtE x >>= A.negate t
        PrimAbs t                 -> cvtE x >>= A.abs t
        PrimSig t                 -> cvtE x >>= A.signum t
        PrimQuot t                -> cvtE x >>= A.uncurry (A.quot t)
        PrimRem t                 -> cvtE x >>= A.uncurry (A.rem t)
        PrimQuotRem t             -> cvtE x >>= A.uncurry (A.quotRem t)
        PrimIDiv t                -> cvtE x >>= A.uncurry (A.idiv t)
        PrimMod t                 -> cvtE x >>= A.uncurry (A.mod t)
        PrimDivMod t              -> cvtE x >>= A.uncurry (A.divMod t)
        PrimBAnd t                -> cvtE x >>= A.uncurry (A.band t)
        PrimBOr t                 -> cvtE x >>= A.uncurry (A.bor t)
        PrimBXor t                -> cvtE x >>= A.uncurry (A.xor t)
        PrimBNot t                -> cvtE x >>= A.complement t
        PrimBShiftL t             -> cvtE x >>= A.uncurry (A.shiftL t)
        PrimBShiftR t             -> cvtE x >>= A.uncurry (A.shiftR t)
        PrimBRotateL t            -> cvtE x >>= A.uncurry (A.rotateL t)
        PrimBRotateR t            -> cvtE x >>= A.uncurry (A.rotateR t)
        PrimPopCount t            -> cvtE x >>= A.popCount t
        PrimCountLeadingZeros t   -> cvtE x >>= A.countLeadingZeros t
        PrimCountTrailingZeros t  -> cvtE x >>= A.countTrailingZeros t
        PrimFDiv t                -> cvtE x >>= A.uncurry (A.fdiv t)
        PrimRecip t               -> cvtE x >>= A.recip t
        PrimSin t                 -> cvtE x >>= A.sin t
        PrimCos t                 -> cvtE x >>= A.cos t
        PrimTan t                 -> cvtE x >>= A.tan t
        PrimSinh t                -> cvtE x >>= A.sinh t
        PrimCosh t                -> cvtE x >>= A.cosh t
        PrimTanh t                -> cvtE x >>= A.tanh t
        PrimAsin t                -> cvtE x >>= A.asin t
        PrimAcos t                -> cvtE x >>= A.acos t
        PrimAtan t                -> cvtE x >>= A.atan t
        PrimAsinh t               -> cvtE x >>= A.asinh t
        PrimAcosh t               -> cvtE x >>= A.acosh t
        PrimAtanh t               -> cvtE x >>= A.atanh t
        PrimAtan2 t               -> cvtE x >>= A.uncurry (A.atan2 t)
        PrimExpFloating t         -> cvtE x >>= A.exp t
        PrimFPow t                -> cvtE x >>= A.uncurry (A.fpow t)
        PrimSqrt t                -> cvtE x >>= A.sqrt t
        PrimLog t                 -> cvtE x >>= A.log t
        PrimLogBase t             -> cvtE x >>= A.uncurry (A.logBase t)
        PrimTruncate ta tb        -> cvtE x >>= A.truncate ta tb
        PrimRound ta tb           -> cvtE x >>= A.round ta tb
        PrimFloor ta tb           -> cvtE x >>= A.floor ta tb
        PrimCeiling ta tb         -> cvtE x >>= A.ceiling ta tb
        PrimIsNaN t               -> cvtE x >>= A.isNaN t
        PrimLt t                  -> cvtE x >>= A.uncurry (A.lt t)
        PrimGt t                  -> cvtE x >>= A.uncurry (A.gt t)
        PrimLtEq t                -> cvtE x >>= A.uncurry (A.lte t)
        PrimGtEq t                -> cvtE x >>= A.uncurry (A.gte t)
        PrimEq t                  -> cvtE x >>= A.uncurry (A.eq t)
        PrimNEq t                 -> cvtE x >>= A.uncurry (A.neq t)
        PrimMax t                 -> cvtE x >>= A.uncurry (A.max t)
        PrimMin t                 -> cvtE x >>= A.uncurry (A.min t)
        -- The logical connectives take their operands as code generators
        -- rather than as already-generated values (presumably so they can
        -- short-circuit; see 'A.land' / 'A.lor' to confirm).
        PrimLAnd                  -> A.land (inl x) (inr x)
        PrimLOr                   -> A.lor (inl x) (inr x)
        PrimLNot                  -> cvtE x >>= A.lnot
        PrimOrd                   -> cvtE x >>= A.ord
        PrimChr                   -> cvtE x >>= A.chr
        PrimBoolToInt             -> cvtE x >>= A.boolToInt
        PrimFromIntegral ta tb    -> cvtE x >>= A.fromIntegral ta tb
        PrimToFloating ta tb      -> cvtE x >>= A.toFloating ta tb
        PrimCoerce ta tb          -> cvtE x >>= A.coerce ta tb
      where
        -- Code for the first component of a pair-valued expression. A literal
        -- tuple is taken apart directly so only that component is generated;
        -- any other expression goes through a projection.
        inl :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv a
        inl (Tuple (SnocTup (SnocTup NilTup a) _)) = cvtE a
        inl other                                  = cvtE (Prj (SuccTupIdx ZeroTupIdx) other)

        -- As 'inl', for the second component.
        inr :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv b
        inr (Tuple (SnocTup _ b)) = cvtE b
        inr other                 = cvtE (Prj ZeroTupIdx other)
-- | The innermost (right-most) component of a non-empty index.
--
indexHead :: IR (sh :. sz) -> IR sz
indexHead (IR ops) =
  case ops of
    OP_Pair _ sz -> IR sz
-- | Everything but the innermost (right-most) component of a non-empty index.
--
indexTail :: IR (sh :. sz) -> IR sh
indexTail (IR ops) =
  case ops of
    OP_Pair sh _ -> IR sh
-- | Extend an index with a new innermost (right-most) component.
--
indexCons :: IR sh -> IR sz -> IR (sh :. sz)
indexCons outer inner =
  case (outer, inner) of
    (IR sh, IR sz) -> IR (OP_Pair sh sz)
-- | Convert a multidimensional index into a linear index, given the extent of
-- the array (first argument) and the index (second argument).
--
-- For each dimension the linear index of the outer dimensions is multiplied by
-- this dimension's extent and the index component is added, so the right-most
-- component is the one added without scaling.
--
intOfIndex :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR Int)
intOfIndex (IR extent) (IR ixs) = linearise (eltType (undefined::sh)) extent ixs
  where
    linearise :: TupleType t -> Operands t -> Operands t -> CodeGen (IR Int)
    linearise UnitTuple OP_Unit OP_Unit
      = return (IR (constant (eltType (undefined :: Int)) 0))
    linearise (PairTuple tsh tsz) (OP_Pair sh sz) (OP_Pair ix i)
      | Just Refl <- matchTupleType tsz (eltType (undefined::Int))
      = case matchTupleType tsh (eltType (undefined::Z)) of
          -- Outer shape is Z: the linear index is just this component, so no
          -- multiply or add is emitted.
          Just Refl -> return (IR i)
          Nothing   -> do
            outer  <- linearise tsh sh ix
            scaled <- A.mul numType outer (IR sz)
            A.add numType scaled (IR i)
    linearise (SingleTuple t) _ (op' t -> i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return (ir t i)
    linearise _ _ _
      = $internalError "intOfIndex" "expected shape with Int components"
-- | Convert a linear index into a multidimensional index, given the extent of
-- the array (first argument) and the linear index (second argument).
--
-- Dimensions are peeled off from the right: the component is the remainder of
-- the running index by the dimension's extent, and the quotient is carried on
-- to the outer dimensions.
--
indexOfInt :: forall sh. Shape sh => IR sh -> IR Int -> CodeGen (IR sh)
indexOfInt (IR extent) linear = fmap IR (split (eltType (undefined::sh)) extent linear)
  where
    split :: TupleType t -> Operands t -> IR Int -> CodeGen (Operands t)
    split UnitTuple OP_Unit _
      = return OP_Unit
    split (PairTuple tsh tsz) (OP_Pair sh sz) i
      | Just Refl <- matchTupleType tsz (eltType (undefined::Int))
      = do
          q     <- A.quot integralType i (IR sz)
          -- When the outer shape is Z the running index is used as-is and no
          -- remainder is emitted; this is only the same value if the index is
          -- already smaller than the extent (i.e. the linear index is in
          -- range).
          IR r  <- case matchTupleType tsh (eltType (undefined::Z)) of
                     Just Refl -> return i
                     Nothing   -> A.rem integralType i (IR sz)
          outer <- split tsh sh q
          return (OP_Pair outer r)
    split (SingleTuple t) _ (IR i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return i
    split _ _ _
      = $internalError "indexOfInt" "expected shape with Int components"