{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.CodeGen.Exp
-- Copyright   : [2015..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.CodeGen.Exp
  where

import Control.Applicative                                          hiding ( Const )
import Control.Monad
import Prelude                                                      hiding ( exp, any )
import qualified Data.IntMap                                        as IM

import Data.Array.Accelerate.AST                                    hiding ( Val(..), prj )
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar                            hiding ( Foreign, toTuple, shape, intersect, union )
import Data.Array.Accelerate.Array.Representation                   ( SliceIndex(..) )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Sugar                  as A

import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad                     ( CodeGen )
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Foreign
import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop            as L
import qualified Data.Array.Accelerate.LLVM.CodeGen.Arithmetic      as A


-- Scalar expressions
-- ==================

-- | Convert a closed unary scalar function into its code generating form. The
-- function must consist of exactly one lambda abstraction around a body; the
-- argument is bound as the only entry of the scalar environment. Any other
-- shape is reported as an internal error.
--
{-# INLINEABLE llvmOfFun1 #-}
llvmOfFun1
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b)
    -> Gamma aenv
    -> IRFun1 arch aenv (a -> b)
llvmOfFun1 arch (Lam (Body body)) aenv = IRFun1 $ \x -> llvmOfOpenExp arch body (Empty `Push` x) aenv
llvmOfFun1 _    _                 _    = $internalError "llvmOfFun1" "impossible evaluation"

-- | Convert a closed binary scalar function into its code generating form. The
-- function must consist of exactly two lambda abstractions around a body; the
-- arguments are pushed onto the scalar environment in order (first argument
-- first). Any other shape is reported as an internal error.
--
{-# INLINEABLE llvmOfFun2 #-}
llvmOfFun2
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b -> c)
    -> Gamma aenv
    -> IRFun2 arch aenv (a -> b -> c)
llvmOfFun2 arch (Lam (Lam (Body body))) aenv = IRFun2 $ \x y -> llvmOfOpenExp arch body (Empty `Push` x `Push` y) aenv
llvmOfFun2 _    _                       _    = $internalError "llvmOfFun2" "impossible evaluation"


-- | Convert an open scalar expression into a sequence of LLVM IR instructions.
-- Code is generated in depth first order, and uses a monad to collect the
-- sequence of instructions used to construct basic blocks.
--
{-# INLINEABLE llvmOfOpenExp #-}
llvmOfOpenExp
    :: forall arch env aenv _t. Foreign arch
    => arch
    -> DelayedOpenExp env aenv _t
    -> Val env
    -> Gamma aenv
    -> IROpenExp arch env aenv _t
llvmOfOpenExp arch top env aenv = cvtE top
  where
    -- Array arguments of scalar expressions must be manifest array variables;
    -- anything else is an internal error.
    cvtM :: DelayedOpenAcc aenv (Array sh e) -> IRManifest arch aenv (Array sh e)
    cvtM (Manifest (Avar ix)) = IRManifest ix
    cvtM _                    = $internalError "llvmOfOpenExp" "expected manifest array variable"

    -- Convert an open unary function by extending the current scalar
    -- environment with the function argument.
    cvtF1 :: DelayedOpenFun env aenv (a -> b) -> IROpenFun1 arch env aenv (a -> b)
    cvtF1 (Lam (Body body)) = IRFun1 $ \x -> llvmOfOpenExp arch body (env `Push` x) aenv
    cvtF1 _                 = $internalError "cvtF1" "impossible evaluation"

    -- The main dispatch: one case per scalar AST constructor.
    cvtE :: forall t. DelayedOpenExp env aenv t -> IROpenExp arch env aenv t
    cvtE exp =
      case exp of
        Let bnd body                -> do
          x <- cvtE bnd
          llvmOfOpenExp arch body (env `Push` x) aenv
        Var ix                      -> return $ prj ix env
        Const c                     -> return $ IR (constant (eltType (undefined::t)) c)
        PrimConst c                 -> return $ IR (constant (eltType (undefined::t)) (fromElt (primConst c)))
        PrimApp f x                 -> primFun f x
        IndexNil                    -> return indexNil
        IndexAny                    -> return indexAny
        IndexCons sh sz             -> indexCons <$> cvtE sh <*> cvtE sz
        IndexHead ix                -> indexHead <$> cvtE ix
        IndexTail ix                -> indexTail <$> cvtE ix
        Prj ix tup                  -> prjT ix <$> cvtE tup
        Tuple tup                   -> cvtT tup
        Foreign asm f x             -> foreignE asm f =<< cvtE x
        Cond c t e                  -> A.ifThenElse (cvtE c) (cvtE t) (cvtE e)
        IndexSlice slice slix sh    -> indexSlice slice <$> cvtE slix <*> cvtE sh
        IndexFull slice slix sh     -> indexFull slice  <$> cvtE slix <*> cvtE sh
        ToIndex sh ix               -> join $ intOfIndex <$> cvtE sh <*> cvtE ix
        FromIndex sh ix             -> join $ indexOfInt <$> cvtE sh <*> cvtE ix
        Index acc ix                -> index (cvtM acc)       =<< cvtE ix
        LinearIndex acc ix          -> linearIndex (cvtM acc) =<< cvtE ix
        ShapeSize sh                -> shapeSize              =<< cvtE sh
        Shape acc                   -> return $ shape (cvtM acc)
        Intersect sh1 sh2           -> join $ intersect <$> cvtE sh1 <*> cvtE sh2
        Union sh1 sh2               -> join $ union     <$> cvtE sh1 <*> cvtE sh2
        While c f x                 -> while (cvtF1 c) (cvtF1 f) (cvtE x)

    -- The empty index, as a constant.
    indexNil :: IR Z
    indexNil = IR (constant (eltType Z) (fromElt Z))

    -- The 'Any' index, as a constant.
    indexAny :: forall sh. Shape sh => IR (Any sh)
    indexAny = let any = Any :: Any sh
               in  IR (constant (eltType any) (fromElt any))

    -- Project a slice out of a full shape: dimensions marked 'SliceAll' are
    -- kept, dimensions marked 'SliceFixed' are dropped.
    indexSlice :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) -> IR slix -> IR sh -> IR sl
    indexSlice slice (IR slix) (IR sh) = IR $ restrict slice slix sh
      where
        restrict :: SliceIndex slix sl co sh -> Operands slix -> Operands sh -> Operands sl
        restrict SliceNil              OP_Unit               OP_Unit
          = OP_Unit
        restrict (SliceAll sliceIdx)   (OP_Pair slx OP_Unit) (OP_Pair sl sz)
          = let sl' = restrict sliceIdx slx sl
            in  OP_Pair sl' sz
        restrict (SliceFixed sliceIdx) (OP_Pair slx _i)      (OP_Pair sl _sz)
          = restrict sliceIdx slx sl

    -- Rebuild a full shape from a slice: dimensions marked 'SliceAll' are taken
    -- from the slice shape, dimensions marked 'SliceFixed' are taken from the
    -- slice index.
    indexFull :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) -> IR slix -> IR sl -> IR sh
    indexFull slice (IR slix) (IR sh) = IR $ extend slice slix sh
      where
        extend :: SliceIndex slix sl co sh -> Operands slix -> Operands sl -> Operands sh
        extend SliceNil              OP_Unit               OP_Unit
          = OP_Unit
        extend (SliceAll sliceIdx)   (OP_Pair slx OP_Unit) (OP_Pair sl sz)
          = let sh' = extend sliceIdx slx sl
            in  OP_Pair sh' sz
        extend (SliceFixed sliceIdx) (OP_Pair slx sz)      sl
          = let sh' = extend sliceIdx slx sl
            in  OP_Pair sh' sz

    -- Project a single component out of a tuple by walking the tuple index
    -- alongside the reified tuple type and the operands.
    prjT :: forall t e. (Elt t, Elt e) => TupleIdx (TupleRepr t) e -> IR t -> IR e
    prjT tix (IR ops) = IR $ go tix (eltType (undefined::t)) ops
      where
        go :: TupleIdx v e -> TupleType t' -> Operands t' -> Operands (EltRepr e)
        go ZeroTupIdx (PairTuple _ t) (OP_Pair _ v)
          | Just Refl <- matchTupleType t (eltType (undefined :: e))
          = v
        go (SuccTupIdx ix) (PairTuple t _) (OP_Pair tup _) = go ix t tup
        go _ _ _                                           = $internalError "prjT" "inconsistent valuation"

    -- Generate code for each component of a tuple and pair up the results.
    cvtT :: forall t. (Elt t, IsTuple t) => Tuple (DelayedOpenExp env aenv) (TupleRepr t) -> CodeGen (IR t)
    cvtT tup = IR <$> go (eltType (undefined::t)) tup
      where
        go :: TupleType t' -> Tuple (DelayedOpenExp env aenv) tup -> CodeGen (Operands t')
        go UnitTuple NilTup
          = return OP_Unit

        go (PairTuple ta tb) (SnocTup a (b :: DelayedOpenExp env aenv b))
          -- We must assert that the reified type 'tb' of 'b' is actually
          -- equivalent to the type of 'b'. This can not fail, but is necessary
          -- because 'tb' observes the representation type of surface type 'b'.
          | Just Refl <- matchTupleType tb (eltType (undefined::b))
          = do
              a'    <- go ta a
              IR b' <- cvtE b
              return $ OP_Pair a' b'

        go _ _ = $internalError "cvtT"
               $ unlines [ "impossible evaluation"
                         , "  possible solution: ensure that the 'EltRepr' and 'ProdRepr' instances of your data type are consistent." ]

    -- Read an array element at a linear (row-major) offset.
    linearIndex :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR Int -> CodeGen (IR e)
    linearIndex (IRManifest v) ix =
      readArray (irArray (aprj v aenv)) ix

    -- Read an array element at a multidimensional index, by first converting
    -- the index to a linear offset with respect to the array shape.
    index :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh -> CodeGen (IR e)
    index (IRManifest v) ix =
      let arr = irArray (aprj v aenv)
      in  readArray arr =<< intOfIndex (irArrayShape arr) ix

    -- The shape of a manifest array.
    shape :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh
    shape (IRManifest v) = irArrayShape (irArray (aprj v aenv))

    -- The number of elements in a shape: the product of all extents, with the
    -- empty shape having size one.
    shapeSize :: forall sh. Shape sh => IR sh -> CodeGen (IR Int)
    shapeSize (IR extent) = go (eltType (undefined::sh)) extent
      where
        go :: TupleType t -> Operands t -> CodeGen (IR Int)
        go UnitTuple OP_Unit
          = return $ IR (constant (eltType (undefined :: Int)) 1)

        go (PairTuple tsh t) (OP_Pair sh sz)
          | Just Refl <- matchTupleType t (eltType (undefined::Int))
          = do
              a <- go tsh sh
              A.mul numType a (IR sz)

        go (SingleTuple t) (op' t -> i)
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = return $ ir t i

        go _ _
          = $internalError "shapeSize" "expected shape with Int components"

    -- Intersection of two shapes: the component-wise minimum.
    intersect :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    intersect (IR extent1) (IR extent2) = IR <$> go (eltType (undefined::sh)) extent1 extent2
      where
        go :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        go UnitTuple OP_Unit OP_Unit
          = return OP_Unit

        go (SingleTuple t) sh1 sh2
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)     -- TLM: GHC hang if this is omitted
          = do
              IR x <- A.min t (IR sh1) (IR sh2)
              return x

        -- NOTE: the innermost extent is processed before the outer ones, so
        -- instructions are emitted in that order.
        go (PairTuple tsh tsz) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2)
          = do
              sz' <- go tsz sz1 sz2
              sh' <- go tsh sh1 sh2
              return $ OP_Pair sh' sz'

        go _ _ _
          = $internalError "intersect" "expected shape with Int components"

    -- Union of two shapes: the component-wise maximum.
    union :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    union (IR extent1) (IR extent2) = IR <$> go (eltType (undefined::sh)) extent1 extent2
      where
        go :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        go UnitTuple OP_Unit OP_Unit
          = return OP_Unit

        go (SingleTuple t) sh1 sh2
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)     -- TLM: GHC hang if this is omitted
          = do
              IR x <- A.max t (IR sh1) (IR sh2)
              return x

        -- NOTE: the innermost extent is processed before the outer ones, so
        -- instructions are emitted in that order.
        go (PairTuple tsh tsz) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2)
          = do
              sz' <- go tsz sz1 sz2
              sh' <- go tsh sh1 sh2
              return $ OP_Pair sh' sz'

        go _ _ _
          = $internalError "union" "expected shape with Int components"

    -- Iterate the function 'f' from the seed 'x' for as long as the predicate
    -- 'p' holds. The seed is evaluated once, before entering the loop.
    while :: Elt a
          => IROpenFun1 arch env aenv (a -> Bool)
          -> IROpenFun1 arch env aenv (a -> a)
          -> IROpenExp  arch env aenv a
          -> IROpenExp  arch env aenv a
    while p f x =
      L.while (app1 p) (app1 f) =<< x

    -- Apply a foreign function. If the target provides an implementation for
    -- the foreign function, use it; otherwise generate code for the pure
    -- Accelerate fallback, which is closed and so uses empty environments.
    foreignE :: (Elt a, Elt b, Foreign arch, A.Foreign asm)
             => asm           (a -> b)
             -> DelayedFun () (a -> b)
             -> IR a
             -> IRExp arch () b
    foreignE asm no x =
      case foreignExp arch asm of
        Just f                       -> app1 f x
        Nothing | Lam (Body b) <- no -> llvmOfOpenExp arch b (Empty `Push` x) IM.empty
        -- The fallback was not a single lambda around a body; report it the
        -- same way as every other impossible case in this module.
        _                            -> $internalError "foreignE" "expected the fallback implementation to be a unary function"

    -- Generate code for the application of a primitive function.
    primFun :: Elt r
            => PrimFun (a -> r)
            -> DelayedOpenExp env aenv a
            -> CodeGen (IR r)
    primFun f x =
      let
          -- The Accelerate language and its code generator are hyper-strict.
          -- However, we must not eagerly evaluate the arguments to logical
          -- operations (&&*) and (||*) so that they can short-circuit. Since we
          -- only have unary functions, this is a little tricky for us.
          --
          -- 'inl' and 'inr' attempt to destruct the incoming AST so that we can
          -- evaluate the left or right components of a pair individually. It
          -- should be noted that there are other cases which can evaluate to
          -- pairs; 'Const', 'Let' and 'Var', for example, but these cases
          -- are (probably) not applicable in this context.
          --
          inl :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv a
          inl (Tuple (SnocTup (SnocTup NilTup a) _)) = cvtE a
          inl t                                      = cvtE $ Prj (SuccTupIdx ZeroTupIdx) t

          inr :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv b
          inr (Tuple (SnocTup _ b)) = cvtE b
          inr t                     = cvtE $ Prj ZeroTupIdx t
      in
      case f of
        PrimAdd t                 -> A.uncurry (A.add t)     =<< cvtE x
        PrimSub t                 -> A.uncurry (A.sub t)     =<< cvtE x
        PrimMul t                 -> A.uncurry (A.mul t)     =<< cvtE x
        PrimNeg t                 -> A.negate t              =<< cvtE x
        PrimAbs t                 -> A.abs t                 =<< cvtE x
        PrimSig t                 -> A.signum t              =<< cvtE x
        PrimQuot t                -> A.uncurry (A.quot t)    =<< cvtE x
        PrimRem t                 -> A.uncurry (A.rem t)     =<< cvtE x
        PrimQuotRem t             -> A.uncurry (A.quotRem t) =<< cvtE x
        PrimIDiv t                -> A.uncurry (A.idiv t)    =<< cvtE x
        PrimMod t                 -> A.uncurry (A.mod t)     =<< cvtE x
        PrimDivMod t              -> A.uncurry (A.divMod t)  =<< cvtE x
        PrimBAnd t                -> A.uncurry (A.band t)    =<< cvtE x
        PrimBOr t                 -> A.uncurry (A.bor t)     =<< cvtE x
        PrimBXor t                -> A.uncurry (A.xor t)     =<< cvtE x
        PrimBNot t                -> A.complement t          =<< cvtE x
        PrimBShiftL t             -> A.uncurry (A.shiftL t)  =<< cvtE x
        PrimBShiftR t             -> A.uncurry (A.shiftR t)  =<< cvtE x
        PrimBRotateL t            -> A.uncurry (A.rotateL t) =<< cvtE x
        PrimBRotateR t            -> A.uncurry (A.rotateR t) =<< cvtE x
        PrimPopCount t            -> A.popCount t            =<< cvtE x
        PrimCountLeadingZeros t   -> A.countLeadingZeros t   =<< cvtE x
        PrimCountTrailingZeros t  -> A.countTrailingZeros t  =<< cvtE x
        PrimFDiv t                -> A.uncurry (A.fdiv t)    =<< cvtE x
        PrimRecip t               -> A.recip t               =<< cvtE x
        PrimSin t                 -> A.sin t                 =<< cvtE x
        PrimCos t                 -> A.cos t                 =<< cvtE x
        PrimTan t                 -> A.tan t                 =<< cvtE x
        PrimSinh t                -> A.sinh t                =<< cvtE x
        PrimCosh t                -> A.cosh t                =<< cvtE x
        PrimTanh t                -> A.tanh t                =<< cvtE x
        PrimAsin t                -> A.asin t                =<< cvtE x
        PrimAcos t                -> A.acos t                =<< cvtE x
        PrimAtan t                -> A.atan t                =<< cvtE x
        PrimAsinh t               -> A.asinh t               =<< cvtE x
        PrimAcosh t               -> A.acosh t               =<< cvtE x
        PrimAtanh t               -> A.atanh t               =<< cvtE x
        PrimAtan2 t               -> A.uncurry (A.atan2 t)   =<< cvtE x
        PrimExpFloating t         -> A.exp t                 =<< cvtE x
        PrimFPow t                -> A.uncurry (A.fpow t)    =<< cvtE x
        PrimSqrt t                -> A.sqrt t                =<< cvtE x
        PrimLog t                 -> A.log t                 =<< cvtE x
        PrimLogBase t             -> A.uncurry (A.logBase t) =<< cvtE x
        PrimTruncate ta tb        -> A.truncate ta tb        =<< cvtE x
        PrimRound ta tb           -> A.round ta tb           =<< cvtE x
        PrimFloor ta tb           -> A.floor ta tb           =<< cvtE x
        PrimCeiling ta tb         -> A.ceiling ta tb         =<< cvtE x
        PrimIsNaN t               -> A.isNaN t               =<< cvtE x
        PrimIsInfinite t          -> A.isInfinite t          =<< cvtE x
        PrimLt t                  -> A.uncurry (A.lt t)      =<< cvtE x
        PrimGt t                  -> A.uncurry (A.gt t)      =<< cvtE x
        PrimLtEq t                -> A.uncurry (A.lte t)     =<< cvtE x
        PrimGtEq t                -> A.uncurry (A.gte t)     =<< cvtE x
        PrimEq t                  -> A.uncurry (A.eq t)      =<< cvtE x
        PrimNEq t                 -> A.uncurry (A.neq t)     =<< cvtE x
        PrimMax t                 -> A.uncurry (A.max t)     =<< cvtE x
        PrimMin t                 -> A.uncurry (A.min t)     =<< cvtE x
        PrimLAnd                  -> A.land (inl x) (inr x)  -- short circuit
        PrimLOr                   -> A.lor  (inl x) (inr x)  -- short circuit
        PrimLNot                  -> A.lnot                  =<< cvtE x
        PrimOrd                   -> A.ord                   =<< cvtE x
        PrimChr                   -> A.chr                   =<< cvtE x
        PrimBoolToInt             -> A.boolToInt             =<< cvtE x
        PrimFromIntegral ta tb    -> A.fromIntegral ta tb    =<< cvtE x
        PrimToFloating ta tb      -> A.toFloating ta tb      =<< cvtE x
        PrimCoerce ta tb          -> A.coerce ta tb          =<< cvtE x
          -- no missing patterns, whoo!


-- | Extract the head of an index
--
indexHead :: IR (sh :. sz) -> IR sz
indexHead (IR (OP_Pair _ sz)) = IR sz

-- | Extract the tail of an index
--
indexTail :: IR (sh :. sz) -> IR sh
indexTail (IR (OP_Pair sh _)) = IR sh

-- | Construct an index from the head and tail
--
indexCons :: IR sh -> IR sz -> IR (sh :. sz)
indexCons (IR sh) (IR sz) = IR (OP_Pair sh sz)

-- | Convert a multidimensional array index into a linear index
--
intOfIndex :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR Int)
intOfIndex (IR extent) (IR index) = cvt (eltType (undefined::sh)) extent index
  where
    cvt :: TupleType t -> Operands t -> Operands t -> CodeGen (IR Int)
    cvt UnitTuple OP_Unit OP_Unit
      = return $ IR (constant (eltType (undefined :: Int)) 0)

    cvt (PairTuple tsh t) (OP_Pair sh sz) (OP_Pair ix i)
      | Just Refl <- matchTupleType t (eltType (undefined::Int))
      -- If we short-circuit the last dimension, we can avoid inserting
      -- a multiply by zero and add of the result.
      = case matchTupleType tsh (eltType (undefined::Z)) of
          Just Refl -> return (IR i)
          Nothing   -> do
            a <- cvt tsh sh ix
            b <- A.mul numType a (IR sz)
            A.add numType b (IR i)

    cvt (SingleTuple t) _ (op' t -> i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return $ ir t i

    cvt _ _ _
      = $internalError "intOfIndex" "expected shape with Int components"


-- | Convert a linear index into a multidimensional index
--
indexOfInt :: forall sh. Shape sh => IR sh -> IR Int -> CodeGen (IR sh)
indexOfInt (IR extent) index = IR <$> cvt (eltType (undefined::sh)) extent index
  where
    cvt :: TupleType t -> Operands t -> IR Int -> CodeGen (Operands t)
    cvt UnitTuple OP_Unit _
      = return OP_Unit

    cvt (PairTuple tsh tsz) (OP_Pair sh sz) i
      | Just Refl <- matchTupleType tsz (eltType (undefined::Int))
      = do
          i'   <- A.quot integralType i (IR sz)
          -- If we assume the index is in range, there is no point computing
          -- the remainder of the highest dimension since (i < sz) must hold
          IR r <- case matchTupleType tsh (eltType (undefined::Z)) of
                    Just Refl -> return i     -- TODO: in debug mode assert (i < sz)
                    Nothing   -> A.rem integralType i (IR sz)
          sh'  <- cvt tsh sh i'
          return $ OP_Pair sh' r

    cvt (SingleTuple t) _ (IR i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return i

    cvt _ _ _
      = $internalError "indexOfInt" "expected shape with Int components"