module Data.Array.Accelerate.Analysis.Type (
accType, accType2, expType, sizeOf
) where
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.AST
import qualified Foreign.Storable as F
-- | Reify the element type of the array produced by an array computation,
-- by inspecting the constructor at the root of its AST.
--
-- The first group of alternatives recurses into one of the constructor's
-- array arguments, whose element type must coincide with the result's.
-- The second group reads the type off directly: 'Use' from the embedded
-- array via 'arrayType', the others from the static element type @e@ via
-- 'elemType'.
--
-- NOTE(review): @elemType (undefined :: e)@ is repeated per alternative
-- rather than shared in a @where@ binding; it likely depends on a class
-- constraint that only the GADT match brings into scope, so hoisting it
-- may not typecheck. The 'undefined' serves purely as a type proxy and is
-- assumed never to be forced by 'elemType'.
accType :: forall aenv dim e.
           OpenAcc aenv (Array dim e) -> TupleType (ElemRepr e)
accType comp = case comp of
    -- element type inherited from a sub-computation
    Let _ inner         -> accType inner
    Let2 _ inner        -> accType inner
    Reshape _ src       -> accType src
    Replicate _ _ src   -> accType src
    Index _ src _       -> accType src
    Fold _ _ src        -> accType src
    FoldSeg _ _ src _   -> accType src
    Permute _ _ _ src   -> accType src
    Backpermute _ _ src -> accType src
    -- element type read off directly
    Use arr             -> arrayType arr
    Avar _              -> elemType (undefined :: e)
    Unit _              -> elemType (undefined :: e)
    Map _ _             -> elemType (undefined :: e)
    ZipWith _ _ _       -> elemType (undefined :: e)
-- | Reify the element types of both components of an array computation
-- that yields a pair of arrays.
--
-- Only 'Scanl' and 'Scanr' are matched. The first component's type is taken
-- from the array-computation argument with 'accType'. The second is taken
-- from the expression argument with 'expType'. Any other constructor (if the
-- AST admits one at this type) is a pattern-match failure.
accType2 :: OpenAcc aenv (Array dim1 e1, Array dim2 e2)
         -> (TupleType (ElemRepr e1), TupleType (ElemRepr e2))
accType2 comp = case comp of
    Scanl _ x src -> (accType src, expType x)
    Scanr _ x src -> (accType src, expType x)
-- | Reify the result type of a scalar expression from its AST.
--
-- The recursive alternatives are as follows:
--
-- * a projection defers to 'tupleIdxType' on its index;
-- * a 'Cond' recurses into its second field;
-- * an 'IndexScalar' takes the element type of its array argument via
--   'accType'.
--
-- Every other constructor reads the type off directly from the static
-- result type @t@ with 'elemType'. The 'undefined' there is only a type
-- proxy; it is assumed never to be forced.
expType :: forall aenv env t. OpenExp aenv env t -> TupleType (ElemRepr t)
expType ex = case ex of
    Prj ix _          -> tupleIdxType ix
    Cond _ branch _   -> expType branch
    IndexScalar arr _ -> accType arr
    Var _             -> elemType (undefined :: t)
    Const _           -> elemType (undefined :: t)
    Tuple _           -> elemType (undefined :: t)
    PrimConst _       -> elemType (undefined :: t)
    PrimApp _ _       -> elemType (undefined :: t)
    Shape _           -> elemType (undefined :: t)
-- | Element type of the tuple component selected by a projection index.
--
-- Each 'SuccTupIdx' layer is peeled off recursively. On reaching
-- 'ZeroTupIdx', the component type @e@ is read off with 'elemType', using
-- 'undefined' purely as a type proxy.
tupleIdxType :: forall t e. TupleIdx t e -> TupleType (ElemRepr e)
tupleIdxType ix = case ix of
    SuccTupIdx rest -> tupleIdxType rest
    ZeroTupIdx      -> elemType (undefined :: e)
-- | Number of bytes occupied by one value of the element type described by
-- a 'TupleType'.
--
-- * The unit tuple contributes zero bytes.
-- * A pair is the plain sum of its components; no alignment padding is
--   inserted.
-- * A scalar uses 'F.sizeOf' at its Haskell type. The @Storable@ instance
--   presumably comes from matching the dictionary returned by
--   'integralDict', 'floatingDict' or 'nonNumDict'. These guards are
--   expected to always succeed.
sizeOf :: TupleType a -> Int
sizeOf UnitTuple       = 0
sizeOf (PairTuple l r) = sizeOf l + sizeOf r
sizeOf (SingleTuple s) = case s of
    NumScalarType (IntegralNumType ty)
      | IntegralDict <- integralDict ty -> F.sizeOf (witness ty)
    NumScalarType (FloatingNumType ty)
      | FloatingDict <- floatingDict ty -> F.sizeOf (witness ty)
    NonNumScalarType ty
      | NonNumDict <- nonNumDict ty     -> F.sizeOf (witness ty)
  where
    -- Type proxy: a bottom value of the type indexed by the given
    -- representation. Safe because 'F.sizeOf' never inspects its argument.
    witness :: f b -> b
    witness _ = undefined