module Data.Array.Accelerate.Analysis.Shape (
AccDim, accDim, preAccDim,
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Sugar
-- | The type of a function that reports the dimensionality (as an 'Int') of
-- the array produced by an @acc@ term.
--
-- It is rank-2 polymorphic: a single such function must work for every
-- array environment, shape and element type.
type AccDim acc
  = forall env shape elt. acc env (Array shape elt) -> Int
-- | Dimensionality of the array produced by an 'OpenAcc' term.
--
-- Strips the 'OpenAcc' wrapper and hands the underlying node to
-- 'preAccDim', passing 'accDim' itself as the step used for sub-terms.
accDim :: AccDim OpenAcc
accDim oacc =
  case oacc of
    OpenAcc node -> preAccDim accDim node
-- | Compute the dimensionality (number of shape components) of the array
-- produced by a single 'PreOpenAcc' node.
--
-- The first argument is the recursive step used for sub-computations
-- (for 'OpenAcc' this is 'accDim'). It is applied to the child whose rank
-- determines the rank of this node's result. Nodes whose rank is not
-- inherited from a child read it off the result type's shape index @sh@,
-- via 'ndim' applied to the element-type representation of @sh@.
preAccDim :: forall acc aenv sh e. AccDim acc -> PreOpenAcc acc aenv (Array sh e) -> Int
preAccDim k pacc =
  case pacc of
    Alet _ acc           -> k acc

    -- For these nodes the result is only known through its 'Arrays'
    -- representation. Matching on 'ArraysRarray' is presumably what brings
    -- the constraints needed by 'eltType' into scope. The fallback branches
    -- should be unreachable for an @Array sh e@ result.
    -- NOTE(review): confirm, and consider more descriptive error messages.
    Avar _               -> case arrays' (undefined :: Array sh e) of
                              ArraysRarray -> ndim (eltType (undefined::sh))
                              _            -> error "halt, fiend!"
    Apply _ _            -> case arrays' (undefined :: Array sh e) of
                              ArraysRarray -> ndim (eltType (undefined::sh))
                              _            -> error "umm, hello"
    Atuple _             -> case arrays' (undefined :: Array sh e) of
                              ArraysRarray -> ndim (eltType (undefined::sh))
                              _            -> error "can we keep him?"
    Aprj _ _             -> case arrays' (undefined :: Array sh e) of
                              ArraysRarray -> ndim (eltType (undefined::sh))
                              _            -> error "inconceivable!"

    Acond _ acc _        -> k acc
    Use ((),(Array _ _)) -> ndim (eltType (undefined::sh))
    -- A unit array holds a single scalar: rank 0.
    Unit _               -> 0
    Generate _ _         -> ndim (eltType (undefined::sh))
    Reshape _ _          -> ndim (eltType (undefined::sh))
    Replicate _ _ _      -> ndim (eltType (undefined::sh))
    Index _ _ _          -> ndim (eltType (undefined::sh))
    Map _ acc            -> k acc
    ZipWith _ _ acc      -> k acc

    -- A (non-segmented) reduction collapses one dimension of its argument
    -- (the innermost one, per Accelerate's 'fold' semantics), so the
    -- result has one dimension fewer than the input.
    Fold _ _ acc         -> k acc - 1
    Fold1 _ acc          -> k acc - 1

    -- Segmented reductions and scans preserve the rank of their argument.
    FoldSeg _ _ acc _    -> k acc
    Fold1Seg _ acc _     -> k acc
    Scanl _ _ acc        -> k acc
    Scanl1 _ acc         -> k acc
    Scanr _ _ acc        -> k acc
    Scanr1 _ acc         -> k acc
    Permute _ acc _ _    -> k acc
    Backpermute _ _ _    -> ndim (eltType (undefined::sh))
    Stencil _ _ acc      -> k acc
    Stencil2 _ _ acc _ _ -> k acc
-- | Count the scalar leaves of a tuple-type representation.
--
-- The unit type contributes nothing, a single scalar contributes one, and a
-- pair contributes the leaves of both of its components.
ndim :: TupleType a -> Int
ndim ty =
  case ty of
    UnitTuple       -> 0
    SingleTuple _   -> 1
    PairTuple lhs rhs -> ndim lhs + ndim rhs