{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Exp
where
import Control.Applicative hiding ( Const )
import Control.Monad
import Data.Proxy
import Data.Typeable
import Text.Printf
import Prelude hiding ( exp, any )
import qualified Data.IntMap as IM
import Data.Array.Accelerate.AST hiding ( Val(..), prj )
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar hiding ( Foreign, toTuple, shape, intersect, union )
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Sugar as A
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Operand
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.CodeGen.Type hiding ( typeOf )
import Data.Array.Accelerate.LLVM.Foreign
import qualified Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop as L
{-# INLINEABLE llvmOfFun1 #-}
-- | Generate code for a closed unary scalar function.
--
-- The function must be a single lambda around a body; while generating the
-- body, the lambda's argument is the only entry of the scalar environment.
-- Any other shape is reported as an internal error.
--
llvmOfFun1
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b)
    -> Gamma aenv
    -> IRFun1 arch aenv (a -> b)
llvmOfFun1 arch fun aenv =
  case fun of
    Lam (Body body) -> IRFun1 (\x -> llvmOfOpenExp arch body (Push Empty x) aenv)
    _               -> $internalError "llvmOfFun1" "impossible evaluation"
{-# INLINEABLE llvmOfFun2 #-}
-- | Generate code for a closed binary scalar function.
--
-- The function must be exactly two nested lambdas around a body. The first
-- argument is pushed onto the empty scalar environment, then the second.
-- Any other shape is reported as an internal error.
--
llvmOfFun2
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b -> c)
    -> Gamma aenv
    -> IRFun2 arch aenv (a -> b -> c)
llvmOfFun2 arch fun aenv =
  case fun of
    Lam (Lam (Body body)) -> IRFun2 (\x y -> llvmOfOpenExp arch body (Push (Push Empty x) y) aenv)
    _                     -> $internalError "llvmOfFun2" "impossible evaluation"
{-# INLINEABLE llvmOfOpenExp #-}
-- | Generate LLVM code for an open scalar expression.
--
-- Scalar variables are resolved in the environment 'Val' @env@ and array
-- variables in @aenv@ ('Gamma'). The local 'cvtE' translates the expression
-- one constructor at a time, delegating to the helpers in the @where@ clause.
--
llvmOfOpenExp
    :: forall arch env aenv _t. Foreign arch
    => arch
    -> DelayedOpenExp env aenv _t
    -> Val env
    -> Gamma aenv
    -> IROpenExp arch env aenv _t
llvmOfOpenExp arch top env aenv = cvtE top
  where

    -- Array terms embedded in a scalar expression must be manifest array
    -- variables; any other form is an internal error.
    --
    cvtM :: DelayedOpenAcc aenv (Array sh e) -> IRManifest arch aenv (Array sh e)
    cvtM (Manifest (Avar ix)) = IRManifest ix
    cvtM _                    = $internalError "llvmOfOpenExp" "expected manifest array variable"

    -- Unary scalar functions (used for the condition and step of 'While').
    -- The argument is pushed onto the current scalar environment.
    --
    cvtF1 :: DelayedOpenFun env aenv (a -> b) -> IROpenFun1 arch env aenv (a -> b)
    cvtF1 (Lam (Body body)) = IRFun1 $ \x -> llvmOfOpenExp arch body (env `Push` x) aenv
    cvtF1 _                 = $internalError "cvtF1" "impossible evaluation"

    -- Scalar expressions, one constructor at a time.
    --
    cvtE :: forall t. DelayedOpenExp env aenv t -> IROpenExp arch env aenv t
    cvtE exp =
      case exp of
        -- Generate the bound expression, then generate the body in the
        -- environment extended with its value.
        Let bnd body                -> do
          x <- cvtE bnd
          llvmOfOpenExp arch body (env `Push` x) aenv
        Var ix                      -> return $ prj ix env
        Const c                     -> return $ IR (constant (eltType (undefined::t)) c)
        PrimConst c                 -> return $ IR (constant (eltType (undefined::t)) (fromElt (primConst c)))
        PrimApp f x                 -> primFun f x
        Undef                       -> return undefE
        IndexNil                    -> return indexNil
        IndexAny                    -> return indexAny
        IndexCons sh sz             -> indexCons <$> cvtE sh <*> cvtE sz
        IndexHead ix                -> indexHead <$> cvtE ix
        IndexTail ix                -> indexTail <$> cvtE ix
        Prj ix tup                  -> prjT ix =<< cvtE tup
        Tuple tup                   -> cvtT tup
        Foreign asm f x             -> foreignE asm f =<< cvtE x
        Cond c t e                  -> A.ifThenElse (cvtE c) (cvtE t) (cvtE e)
        IndexSlice slice slix sh    -> indexSlice slice <$> cvtE slix <*> cvtE sh
        IndexFull slice slix sh     -> indexFull slice <$> cvtE slix <*> cvtE sh
        ToIndex sh ix               -> join $ intOfIndex <$> cvtE sh <*> cvtE ix
        FromIndex sh ix             -> join $ indexOfInt <$> cvtE sh <*> cvtE ix
        Index acc ix                -> index (cvtM acc) =<< cvtE ix
        LinearIndex acc ix          -> linearIndex (cvtM acc) =<< cvtE ix
        ShapeSize sh                -> shapeSize =<< cvtE sh
        Shape acc                   -> return $ shape (cvtM acc)
        Intersect sh1 sh2           -> join $ intersect <$> cvtE sh1 <*> cvtE sh2
        Union sh1 sh2               -> join $ union <$> cvtE sh1 <*> cvtE sh2
        While c f x                 -> while (cvtF1 c) (cvtF1 f) (cvtE x)
        Coerce x                    -> coerce =<< cvtE x

    -- The empty index 'Z', as a constant.
    --
    indexNil :: IR Z
    indexNil = IR (constant (eltType Z) (fromElt Z))

    -- The wildcard index 'Any', as a constant.
    --
    indexAny :: forall sh. Shape sh => IR (Any sh)
    indexAny = let any = Any :: Any sh
               in  IR (constant (eltType any) (fromElt any))

    -- A value of type @t@ whose every scalar leaf is built with 'undef'.
    --
    undefE :: forall t. Elt t => IR t
    undefE = IR $ go (eltType (undefined::t))
      where
        go :: TupleType s -> Operands s
        go TypeRunit       = OP_Unit
        go (TypeRscalar t) = ir' t (undef t)
        go (TypeRpair a b) = OP_Pair (go a) (go b)

    -- Restrict a full shape to the shape of a slice. Dimensions marked
    -- 'SliceAll' keep their extent; dimensions marked 'SliceFixed' are
    -- dropped (their fixed component in @slix@ is ignored).
    --
    indexSlice :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
               -> IR slix
               -> IR sh
               -> IR sl
    indexSlice slice (IR slix) (IR sh) = IR $ restrict slice slix sh
      where
        restrict :: SliceIndex slix sl co sh -> Operands slix -> Operands sh -> Operands sl
        restrict SliceNil              OP_Unit               OP_Unit          = OP_Unit
        restrict (SliceAll sliceIdx)   (OP_Pair slx OP_Unit) (OP_Pair sl sz)  =
          let sl' = restrict sliceIdx slx sl
          in  OP_Pair sl' sz
        restrict (SliceFixed sliceIdx) (OP_Pair slx _i)      (OP_Pair sl _sz) =
          restrict sliceIdx slx sl

    -- Extend a slice-shaped value back to a full index. 'SliceAll'
    -- dimensions take their component from the slice argument (named @sh@
    -- here, although it has type @IR sl@); 'SliceFixed' dimensions take the
    -- component fixed in @slix@.
    --
    indexFull :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
              -> IR slix
              -> IR sl
              -> IR sh
    indexFull slice (IR slix) (IR sh) = IR $ extend slice slix sh
      where
        extend :: SliceIndex slix sl co sh -> Operands slix -> Operands sl -> Operands sh
        extend SliceNil              OP_Unit               OP_Unit         = OP_Unit
        extend (SliceAll sliceIdx)   (OP_Pair slx OP_Unit) (OP_Pair sl sz) =
          let sh' = extend sliceIdx slx sl
          in  OP_Pair sh' sz
        extend (SliceFixed sliceIdx) (OP_Pair slx sz)      sl              =
          let sh' = extend sliceIdx slx sl
          in  OP_Pair sh' sz

    -- Project one component out of a tuple value. If the tuple is
    -- represented as an LLVM vector, the component is read with
    -- 'ExtractElement'; otherwise the nested-pair representation is walked,
    -- with 'ZeroTupIdx' selecting the last (rightmost) component.
    --
    prjT :: forall t e. (Elt t, IsTuple t, Elt e) => TupleIdx (TupleRepr t) e -> IR t -> CodeGen (IR e)
    prjT tix (IR tup) =
      case eltType (undefined::t) of
        TypeRscalar (VectorScalarType v) -> goV tix v tup
        t                                -> goT tix t tup
      where
        goT :: TupleIdx s e -> TupleType t' -> Operands t' -> CodeGen (IR e)
        goT (SuccTupIdx ix) (TypeRpair t _) (OP_Pair x _) = goT ix t x
        goT ZeroTupIdx      (TypeRpair _ t) (OP_Pair _ x)
          | Just Refl <- matchTupleType t (eltType (undefined::e))
          = return $ IR x
        goT _ _ _
          = $internalError "prjT/tup" "inconsistent valuation"

        -- Vector case: only valid when the tuple's product representation
        -- matches that of the vector, and @e@ is the vector's element type.
        goV :: forall (v :: * -> *) a. TupleIdx (ProdRepr t) e -> VectorType (v a) -> Operands (v a) -> CodeGen (IR e)
        goV vix v (op' v -> vec)
          | Just Refl <- matchProdR (prod Proxy (undefined::t)) (vecProdR v)
          , Just Refl <- matchVecT (eltType (undefined::e)) (vecElemT v)
          = instr $ ExtractElement v vix vec
        goV _ _ _
          = $internalError "prjT/vec" "inconsistent valuation"

        -- Turn equality of representations into equality of the types
        -- themselves; 'gcast' supplies the final step.
        matchVecT :: TupleType (EltRepr e) -> TupleType a -> Maybe (e :~: a)
        matchVecT e v
          | Just Refl <- matchTupleType e v = gcast Refl
          | otherwise                       = Nothing

        -- Element type of a vector, as a scalar tuple type.
        vecElemT :: VectorType (v a) -> TupleType a
        vecElemT (Vector2Type a)  = TypeRscalar (SingleScalarType a)
        vecElemT (Vector3Type a)  = TypeRscalar (SingleScalarType a)
        vecElemT (Vector4Type a)  = TypeRscalar (SingleScalarType a)
        vecElemT (Vector8Type a)  = TypeRscalar (SingleScalarType a)
        vecElemT (Vector16Type a) = TypeRscalar (SingleScalarType a)

        -- Structural equality of two product representations, compared
        -- component by component.
        matchProdR :: ProdR Elt a -> ProdR Elt b -> Maybe (a :~: b)
        matchProdR ProdRunit ProdRunit = Just Refl
        matchProdR pa@(ProdRsnoc a) pb@(ProdRsnoc b)
          | Just Refl <- matchProdR a b
          , Just Refl <- matchTop pa pb
          = Just Refl
          where
            -- Compare the last components; the arguments only fix the types.
            matchTop :: forall ta tb a b. (Elt a, Elt b) => ProdR Elt (ta,a) -> ProdR Elt (tb,b) -> Maybe (a :~: b)
            matchTop _ _
              | Just Refl <- matchTupleType (eltType (undefined::a)) (eltType (undefined::b)) = gcast Refl
              | otherwise                                                                   = Nothing
        matchProdR _ _
          = Nothing

        -- Product representation of the V2/V3/V4/V8/V16 type whose element
        -- type is that of the given vector.
        vecProdR :: VectorType v -> ProdR Elt (ProdRepr v)
        vecProdR (Vector2Type e)  | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V2 a)
        vecProdR (Vector3Type e)  | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V3 a)
        vecProdR (Vector4Type e)  | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V4 a)
        vecProdR (Vector8Type e)  | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V8 a)
        vecProdR (Vector16Type e) | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V16 a)

    -- Build a tuple value. If its representation is an LLVM vector, the
    -- components are inserted one by one with 'InsertElement'; otherwise the
    -- nested-pair representation is built directly.
    --
    cvtT :: forall t. (Elt t, IsTuple t) => Tuple (DelayedOpenExp env aenv) (TupleRepr t) -> CodeGen (IR t)
    cvtT tup =
      case eltType (undefined::t) of
        TypeRscalar (VectorScalarType v) -> IR <$> goV v tup
        t                                -> IR <$> goT t tup
      where
        goT :: TupleType t' -> Tuple (DelayedOpenExp env aenv) tup -> CodeGen (Operands t')
        goT TypeRunit NilTup
          = return OP_Unit
        goT (TypeRpair ta tb) (SnocTup a (b :: DelayedOpenExp env aenv b))
          | Just Refl <- matchTupleType tb (eltType (undefined::b))
          = do
              a'    <- goT ta a
              IR b' <- cvtE b
              return $ OP_Pair a' b'
        goT _ _
          = $internalError "cvtT/tup"
          $ unlines [ "impossible evaluation"
                    , " possible solution: ensure that the 'EltRepr' and 'ProdRepr' instances of your data type are consistent." ]

        goV :: forall (v :: * -> *) a. VectorType (v a) -> Tuple (DelayedOpenExp env aenv) (TupleRepr t) -> CodeGen (Operands (v a))
        goV v ts = ir' v . snd <$> pack ts
          where
            -- Returns the next free lane together with the partially filled
            -- vector. The recursion reaches 'NilTup' (lane 0) first, so the
            -- leftmost tuple component ends up in lane 0.
            pack :: Tuple (DelayedOpenExp env aenv) tup -> CodeGen (Int32, Operand (v a))
            pack NilTup
              = return (0, undef (VectorScalarType v))
            pack (SnocTup t x)
              | Just Refl <- matchExpType x
              = do
                  x'       <- cvtE x
                  (i, vec) <- pack t
                  vec'     <- instr' $ InsertElement i vec (op a x')
                  return (i+1, vec')
              where
                -- Every component must have the vector's element type.
                matchExpType :: forall s. Elt s => DelayedOpenExp env aenv s -> Maybe (s :~: a)
                matchExpType _
                  | Just Refl <- matchTupleType (eltType (undefined::s)) (TypeRscalar (SingleScalarType a)) = gcast Refl
                  | otherwise                                                                            = Nothing
            pack _
              = $internalError "cvtT/vec" "impossible evaluation"

            -- Element type of the vector being built.
            a :: SingleType a
            a = case v of
                  Vector2Type t  -> t
                  Vector3Type t  -> t
                  Vector4Type t  -> t
                  Vector8Type t  -> t
                  Vector16Type t -> t

    -- Read an element of a manifest array at a linear index.
    --
    linearIndex :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR Int -> CodeGen (IR e)
    linearIndex (IRManifest v) ix =
      readArray (irArray (aprj v aenv)) ix

    -- Read an element of a manifest array at a multidimensional index; the
    -- index is linearised against the array's own shape.
    --
    index :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh -> CodeGen (IR e)
    index (IRManifest v) ix =
      let arr = irArray (aprj v aenv)
      in  readArray arr =<< intOfIndex (irArrayShape arr) ix

    -- The extent of a manifest array.
    --
    shape :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh
    shape (IRManifest v) = irArrayShape (irArray (aprj v aenv))

    -- Number of elements in a shape: the product of its Int components
    -- (1 for 'Z').
    --
    shapeSize :: forall sh. Shape sh => IR sh -> CodeGen (IR Int)
    shapeSize (IR extent) = go (eltType (undefined::sh)) extent
      where
        go :: TupleType t -> Operands t -> CodeGen (IR Int)
        go TypeRunit OP_Unit
          = return $ IR (constant (eltType (undefined :: Int)) 1)
        go (TypeRpair tsh t) (OP_Pair sh sz)
          | Just Refl <- matchTupleType t (eltType (undefined::Int))
          = do
              a <- go tsh sh
              b <- A.mul numType a (IR sz)
              return b
        go (TypeRscalar t) (op' t -> i)
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = return $ ir t i
        go _ _
          = $internalError "shapeSize" "expected shape with Int components"

    -- Componentwise minimum of two shapes.
    --
    intersect :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    intersect (IR extent1) (IR extent2) = IR <$> go (eltType (undefined::sh)) extent1 extent2
      where
        go :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        go TypeRunit OP_Unit OP_Unit
          = return OP_Unit
        go (TypeRscalar t) sh1 sh2
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = do IR x <- A.min (singleType :: SingleType Int) (IR sh1) (IR sh2)
               return x
        go (TypeRpair tsh tsz) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2)
          = do
              sz' <- go tsz sz1 sz2
              sh' <- go tsh sh1 sh2
              return $ OP_Pair sh' sz'
        go _ _ _
          = $internalError "intersect" "expected shape with Int components"

    -- Componentwise maximum of two shapes.
    --
    union :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    union (IR extent1) (IR extent2) = IR <$> go (eltType (undefined::sh)) extent1 extent2
      where
        go :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        go TypeRunit OP_Unit OP_Unit
          = return OP_Unit
        go (TypeRscalar t) sh1 sh2
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = do IR x <- A.max (singleType :: SingleType Int) (IR sh1) (IR sh2)
               return x
        go (TypeRpair tsh tsz) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2)
          = do
              sz' <- go tsz sz1 sz2
              sh' <- go tsh sh1 sh2
              return $ OP_Pair sh' sz'
        go _ _ _
          = $internalError "union" "expected shape with Int components"

    -- A while loop: generate the initial value, then hand the condition and
    -- step functions to 'L.while'.
    --
    while :: Elt a
          => IROpenFun1 arch env aenv (a -> Bool)
          -> IROpenFun1 arch env aenv (a -> a)
          -> IROpenExp arch env aenv a
          -> IROpenExp arch env aenv a
    while p f x =
      L.while (app1 p) (app1 f) =<< x

    -- Foreign scalar functions. If the backend provides an implementation for
    -- @asm@ ('foreignExp' returns 'Just'), call it; otherwise generate the
    -- pure fallback function in an empty environment.
    --
    -- NOTE(review): the last alternative looks unreachable in practice (a
    -- fallback of type @a -> b@ should always be @Lam (Body _)@), and its
    -- message does not describe the failure; consider '$internalError' with
    -- a descriptive message — confirm before changing.
    --
    foreignE :: (Elt a, Elt b, Foreign arch, A.Foreign asm)
             => asm (a -> b)
             -> DelayedFun () (a -> b)
             -> IR a
             -> IRExp arch () b
    foreignE asm no x =
      case foreignExp arch asm of
        Just f                       -> app1 f x
        Nothing | Lam (Body b) <- no -> llvmOfOpenExp arch b (Empty `Push` x) IM.empty
        _                            -> error "when a grid's misaligned with another behind / that's a moiré..."

    -- Reinterpret a value of type @a@ as type @b@. Identical scalar types
    -- pass through unchanged, differing scalar types use 'BitCast', pairs are
    -- coerced componentwise, and a @((), x)@ wrapper around a single scalar
    -- may be added or removed.
    --
    coerce :: forall a b. (Elt a, Elt b) => IR a -> CodeGen (IR b)
    coerce (IR as) = IR <$> go (eltType (undefined::a)) (eltType (undefined::b)) as
      where
        go :: TupleType s -> TupleType t -> Operands s -> CodeGen (Operands t)
        go TypeRunit         TypeRunit         OP_Unit         = return OP_Unit
        go (TypeRpair s1 s2) (TypeRpair t1 t2) (OP_Pair x1 x2) = OP_Pair <$> go s1 t1 x1 <*> go s2 t2 x2
        go (TypeRscalar s)   (TypeRscalar t)   x
          | Just Refl <- matchScalarType s t = return x
          | otherwise                        = ir' t <$> instr' (BitCast t (op' s x))
        go (TypeRpair TypeRunit s) t@TypeRscalar{} (OP_Pair OP_Unit x) = go s t x
        go s@TypeRscalar{} (TypeRpair TypeRunit t) x                   = OP_Pair OP_Unit <$> go s t x
        go _ _ _
          = error $ printf "could not coerce type `%s' to `%s'"
                      (show (typeOf (undefined::a)))
                      (show (typeOf (undefined::b)))

    -- Primitive functions. Most primitives generate their (tupled) argument
    -- and pass it to the matching operation in the Arithmetic module.
    --
    primFun :: Elt r
            => PrimFun (a -> r)
            -> DelayedOpenExp env aenv a
            -> CodeGen (IR r)
    primFun f x =
      let
          -- Code generators for the two components of a pair argument. If
          -- the argument is a literal 'Tuple', each component is generated
          -- directly instead of going through 'Prj'.
          inl :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv a
          inl (Tuple (SnocTup (SnocTup NilTup a) _)) = cvtE a
          inl t                                      = cvtE $ Prj (SuccTupIdx ZeroTupIdx) t

          inr :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv b
          inr (Tuple (SnocTup _ b)) = cvtE b
          inr t                     = cvtE $ Prj ZeroTupIdx t
      in
      case f of
        PrimAdd t                -> A.uncurry (A.add t)            =<< cvtE x
        PrimSub t                -> A.uncurry (A.sub t)            =<< cvtE x
        PrimMul t                -> A.uncurry (A.mul t)            =<< cvtE x
        PrimNeg t                -> A.negate t                     =<< cvtE x
        PrimAbs t                -> A.abs t                        =<< cvtE x
        PrimSig t                -> A.signum t                     =<< cvtE x
        PrimQuot t               -> A.uncurry (A.quot t)           =<< cvtE x
        PrimRem t                -> A.uncurry (A.rem t)            =<< cvtE x
        PrimQuotRem t            -> A.uncurry (A.quotRem t)        =<< cvtE x
        PrimIDiv t               -> A.uncurry (A.idiv t)           =<< cvtE x
        PrimMod t                -> A.uncurry (A.mod t)            =<< cvtE x
        PrimDivMod t             -> A.uncurry (A.divMod t)         =<< cvtE x
        PrimBAnd t               -> A.uncurry (A.band t)           =<< cvtE x
        PrimBOr t                -> A.uncurry (A.bor t)            =<< cvtE x
        PrimBXor t               -> A.uncurry (A.xor t)            =<< cvtE x
        PrimBNot t               -> A.complement t                 =<< cvtE x
        PrimBShiftL t            -> A.uncurry (A.shiftL t)         =<< cvtE x
        PrimBShiftR t            -> A.uncurry (A.shiftR t)         =<< cvtE x
        PrimBRotateL t           -> A.uncurry (A.rotateL t)        =<< cvtE x
        PrimBRotateR t           -> A.uncurry (A.rotateR t)        =<< cvtE x
        PrimPopCount t           -> A.popCount t                   =<< cvtE x
        PrimCountLeadingZeros t  -> A.countLeadingZeros t          =<< cvtE x
        PrimCountTrailingZeros t -> A.countTrailingZeros t         =<< cvtE x
        PrimFDiv t               -> A.uncurry (A.fdiv t)           =<< cvtE x
        PrimRecip t              -> A.recip t                      =<< cvtE x
        PrimSin t                -> A.sin t                        =<< cvtE x
        PrimCos t                -> A.cos t                        =<< cvtE x
        PrimTan t                -> A.tan t                        =<< cvtE x
        PrimSinh t               -> A.sinh t                       =<< cvtE x
        PrimCosh t               -> A.cosh t                       =<< cvtE x
        PrimTanh t               -> A.tanh t                       =<< cvtE x
        PrimAsin t               -> A.asin t                       =<< cvtE x
        PrimAcos t               -> A.acos t                       =<< cvtE x
        PrimAtan t               -> A.atan t                       =<< cvtE x
        PrimAsinh t              -> A.asinh t                      =<< cvtE x
        PrimAcosh t              -> A.acosh t                      =<< cvtE x
        PrimAtanh t              -> A.atanh t                      =<< cvtE x
        PrimAtan2 t              -> A.uncurry (A.atan2 t)          =<< cvtE x
        PrimExpFloating t        -> A.exp t                        =<< cvtE x
        PrimFPow t               -> A.uncurry (A.fpow t)           =<< cvtE x
        PrimSqrt t               -> A.sqrt t                       =<< cvtE x
        PrimLog t                -> A.log t                        =<< cvtE x
        PrimLogBase t            -> A.uncurry (A.logBase t)        =<< cvtE x
        PrimTruncate ta tb       -> A.truncate ta tb               =<< cvtE x
        PrimRound ta tb          -> A.round ta tb                  =<< cvtE x
        PrimFloor ta tb          -> A.floor ta tb                  =<< cvtE x
        PrimCeiling ta tb        -> A.ceiling ta tb                =<< cvtE x
        PrimIsNaN t              -> A.isNaN t                      =<< cvtE x
        PrimIsInfinite t         -> A.isInfinite t                 =<< cvtE x
        PrimLt t                 -> A.uncurry (A.lt t)             =<< cvtE x
        PrimGt t                 -> A.uncurry (A.gt t)             =<< cvtE x
        PrimLtEq t               -> A.uncurry (A.lte t)            =<< cvtE x
        PrimGtEq t               -> A.uncurry (A.gte t)            =<< cvtE x
        PrimEq t                 -> A.uncurry (A.eq t)             =<< cvtE x
        PrimNEq t                -> A.uncurry (A.neq t)            =<< cvtE x
        PrimMax t                -> A.uncurry (A.max t)            =<< cvtE x
        PrimMin t                -> A.uncurry (A.min t)            =<< cvtE x
        -- The logical connectives receive the two operands as separate,
        -- not-yet-run code generators (presumably so that A.land / A.lor
        -- can short-circuit; NOTE(review): confirm in the Arithmetic module).
        PrimLAnd                 -> A.land (inl x) (inr x)
        PrimLOr                  -> A.lor  (inl x) (inr x)
        PrimLNot                 -> A.lnot                         =<< cvtE x
        PrimOrd                  -> A.ord                          =<< cvtE x
        PrimChr                  -> A.chr                          =<< cvtE x
        PrimBoolToInt            -> A.boolToInt                    =<< cvtE x
        PrimFromIntegral ta tb   -> A.fromIntegral ta tb           =<< cvtE x
        PrimToFloating ta tb     -> A.toFloating ta tb             =<< cvtE x
-- | The innermost (rightmost) component of an index.
--
indexHead :: IR (sh :. sz) -> IR sz
indexHead (IR ops) =
  case ops of
    OP_Pair _ innermost -> IR innermost
-- | Everything except the innermost (rightmost) component of an index.
--
indexTail :: IR (sh :. sz) -> IR sh
indexTail (IR ops) =
  case ops of
    OP_Pair outer _ -> IR outer
-- | Attach a new innermost (rightmost) component to an index.
--
indexCons :: IR sh -> IR sz -> IR (sh :. sz)
indexCons (IR outer) (IR innermost) = IR $ OP_Pair outer innermost
-- | Linearise a multidimensional index against an extent, in row-major order.
--
-- For an extent @Z :. n_1 :. ... :. n_k@ and an index @Z :. i_1 :. ... :. i_k@
-- the result is @(...(i_1 * n_2 + i_2) * n_3 + ...) * n_k + i_k@. The
-- outermost extent @n_1@ is never consulted, and no bounds checks are made.
-- Shapes whose components are not 'Int' are an internal error.
--
intOfIndex :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR Int)
intOfIndex (IR extent) (IR ix0) = linearise (eltType (undefined::sh)) extent ix0
  where
    linearise :: TupleType t -> Operands t -> Operands t -> CodeGen (IR Int)
    -- The empty shape 'Z' has a single position, at offset 0.
    linearise TypeRunit OP_Unit OP_Unit =
      return $ IR (constant (eltType (undefined :: Int)) 0)
    linearise (TypeRpair touter tdim) (OP_Pair outerExt dimExt) (OP_Pair outerIx dimIx)
      | Just Refl <- matchTupleType tdim (eltType (undefined::Int))
      = case matchTupleType touter (eltType (undefined::Z)) of
          -- Outermost dimension: its component is the offset itself.
          Just Refl -> return (IR dimIx)
          Nothing   -> do
            inner  <- linearise touter outerExt outerIx
            scaled <- A.mul numType inner (IR dimExt)
            A.add numType scaled (IR dimIx)
    linearise (TypeRscalar t) _ (op' t -> i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return $ ir t i
    linearise _ _ _
      = $internalError "intOfIndex" "expected shape with Int components"
-- | Convert a linear (row-major) offset into a multidimensional index of the
-- given extent; the inverse of 'intOfIndex' for in-range offsets.
--
-- Components are peeled off from the innermost (rightmost) dimension
-- outwards. Each dimension's component is the remainder of the running offset
-- by that dimension's size, and the quotient is passed on to the outer
-- dimensions. The offset is assumed to be in range (@0 <= i < shapeSize
-- extent@); this is not checked. Shapes whose components are not 'Int' are an
-- internal error.
--
indexOfInt :: forall sh. Shape sh => IR sh -> IR Int -> CodeGen (IR sh)
indexOfInt (IR extent) index = IR <$> cvt (eltType (undefined::sh)) extent index
  where
    cvt :: TupleType t -> Operands t -> IR Int -> CodeGen (Operands t)
    cvt TypeRunit OP_Unit _
      = return OP_Unit

    cvt (TypeRpair tsh tsz) (OP_Pair sh sz) i
      | Just Refl <- matchTupleType tsz (eltType (undefined::Int))
      = case matchTupleType tsh (eltType (undefined::Z)) of
          -- Outermost dimension. If the offset is in range (i < sz), the
          -- component is i itself, so no 'rem' is needed. The quotient would
          -- only be handed to the empty outer shape, which ignores it, so no
          -- 'quot' is emitted either (previously a dead division was
          -- generated here).
          Just Refl -> do
            IR r <- return i
            sh'  <- cvt tsh sh i
            return $ OP_Pair sh' r

          -- Inner dimension: keep the remainder and pass the quotient on.
          Nothing   -> do
            i'   <- A.quot integralType i (IR sz)
            IR r <- A.rem  integralType i (IR sz)
            sh'  <- cvt tsh sh i'
            return $ OP_Pair sh' r

    cvt (TypeRscalar t) _ (IR i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return i

    cvt _ _ _
      = $internalError "indexOfInt" "expected shape with Int components"