-- Shape analysis of array computations: the exported functions inspect an
-- AST term and return, as plain 'Int's, the number of dimensions of the
-- array (or pair of arrays) that the term produces.
module Data.Array.Accelerate.Analysis.Shape (
-- type synonyms describing the analysis functions
AccDim, AccDim2,
-- analyses specialised to 'OpenAcc' terms
accDim, accDim2,
-- analyses on 'PreOpenAcc' terms, taking the analysis for sub-terms as an argument
preAccDim, preAccDim2
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Sugar
-- | An analysis over computations of representation @acc@ that yield a single
-- array: for any environment, shape and element type it returns an 'Int'
-- (the number of dimensions, as computed by 'accDim' and 'preAccDim').
type AccDim acc =
  forall env dim elt. acc env (Array dim elt) -> Int
-- | An analysis over computations of representation @acc@ that yield a pair
-- of arrays: it returns one 'Int' per component of the pair (the number of
-- dimensions of each, as computed by 'accDim2' and 'preAccDim2').
type AccDim2 acc =
  forall env dimA eltA dimB eltB.
    acc env (Array dimA eltA, Array dimB eltB) -> (Int, Int)
-- | Number of dimensions of the array produced by an 'OpenAcc' computation.
--
-- Unwraps the 'OpenAcc' constructor and defers to 'preAccDim', passing
-- itself as the analysis to use on sub-computations.
accDim :: AccDim OpenAcc
accDim term =
  case term of
    OpenAcc pacc -> preAccDim accDim pacc
-- | Number of dimensions of the array produced by a 'PreOpenAcc' term.
--
-- The first argument @k@ is the analysis applied to sub-computations; it
-- lets this function work for any representation @acc@ of the recursive
-- positions ('accDim' passes itself).
--
-- Each constructor falls into one of three groups:
--
-- * the result has the same number of dimensions as one of the
--   sub-computations, so the answer is obtained by calling @k@ on it;
--
-- * the number of dimensions is read off the result shape type @sh@ by
--   counting the components of its tuple representation with 'ndim' (the
--   @undefined :: sh@ argument is only a type proxy);
--
-- * the number of dimensions is a fixed offset from that of a
--   sub-computation ('Unit', 'Fold', 'Fold1').
preAccDim :: forall acc aenv sh e. AccDim acc -> PreOpenAcc acc aenv (Array sh e) -> Int
preAccDim k pacc =
  case pacc of
    Let  _ acc           -> k acc
    Let2 _ acc           -> k acc

    -- For 'Avar' and 'Apply' the shape type is examined only after matching
    -- on the 'ArraysR' representation of the result type; presumably that
    -- match is what brings the class evidence needed by 'eltType' into
    -- scope -- TODO confirm against the 'ArraysR' definition.
    Avar _               ->
      case arrays :: ArraysR (Array sh e) of
        ArraysRarray -> ndim (eltType (undefined :: sh))
    Apply _ _            ->
      case arrays :: ArraysR (Array sh e) of
        ArraysRarray -> ndim (eltType (undefined :: sh))

    Acond _ acc _        -> k acc
    Use (Array _ _)      -> ndim (eltType (undefined :: sh))
    Unit _               -> 0
    Generate _ _         -> ndim (eltType (undefined :: sh))
    Reshape _ _          -> ndim (eltType (undefined :: sh))
    Replicate _ _ _      -> ndim (eltType (undefined :: sh))
    Index _ _ _          -> ndim (eltType (undefined :: sh))
    Map _ acc            -> k acc
    ZipWith _ _ acc      -> k acc

    -- A fold reduces along one dimension, so the result has one dimension
    -- fewer than the folded array.  (Previously written @k acc 1@, which
    -- applies an 'Int' to an argument and does not type check.)
    Fold _ _ acc         -> k acc - 1
    Fold1 _ acc          -> k acc - 1

    -- NOTE(review): both segmented folds bind the *last* constructor field.
    -- Confirm in the AST that this field is the folded array rather than the
    -- segment descriptor; if it is the segment descriptor the answer would
    -- be the descriptor's dimensionality instead of the array's.
    FoldSeg _ _ _ acc    -> k acc
    Fold1Seg _ _ acc     -> k acc

    Scanl _ _ acc        -> k acc
    Scanl1 _ acc         -> k acc
    Scanr _ _ acc        -> k acc
    Scanr1 _ acc         -> k acc
    Permute _ acc _ _    -> k acc
    Backpermute _ _ _    -> ndim (eltType (undefined :: sh))
    Stencil _ _ acc      -> k acc
    Stencil2 _ _ acc _ _ -> k acc
-- | Number of dimensions of each of the two arrays produced by an 'OpenAcc'
-- computation that yields a pair of arrays.
--
-- Unwraps the 'OpenAcc' constructor and defers to 'preAccDim2', supplying
-- 'accDim' for single-array sub-computations and itself for pair-valued
-- ones.
accDim2 :: AccDim2 OpenAcc
accDim2 term =
  case term of
    OpenAcc pacc -> preAccDim2 accDim accDim2 pacc
-- | Number of dimensions of each component of the pair of arrays produced by
-- a 'PreOpenAcc' term.
--
-- The first argument analyses sub-computations that yield a single array,
-- the second analyses sub-computations that yield a pair of arrays
-- ('accDim2' passes 'accDim' and itself respectively).
preAccDim2 :: forall acc aenv sh1 e1 sh2 e2.
              AccDim  acc
           -> AccDim2 acc
           -> PreOpenAcc acc aenv (Array sh1 e1, Array sh2 e2)
           -> (Int, Int)
preAccDim2 single paired term =
  case term of
    Let  _ body        -> paired body
    Let2 _ body        -> paired body
    PairArrays lhs rhs -> (single lhs, single rhs)
    Avar _             -> fromResultType arrays
    Apply _ _          -> fromResultType arrays
    Acond _ branch _   -> paired branch
    -- The second component of a primed scan is reported as zero-dimensional.
    Scanl' _ _ input   -> (single input, 0)
    Scanr' _ _ input   -> (single input, 0)
  where
    -- Read both dimension counts off the result shape types, after matching
    -- on the 'ArraysR' representation of the pair.  The @undefined@ values
    -- serve only as type proxies for 'eltType'.
    fromResultType :: ArraysR (Array sh1 e1, Array sh2 e2) -> (Int, Int)
    fromResultType repr =
      case repr of
        ArraysRpair ArraysRarray ArraysRarray ->
          ( ndim (eltType (undefined :: sh1))
          , ndim (eltType (undefined :: sh2))
          )
        -- Fallback kept from the original code, which describes it as dead.
        _ -> error "GHC is too dumb to realise that this is dead code"
-- | Count the 'SingleTuple' leaves of a 'TupleType': the unit tuple
-- contributes 0, a single component contributes 1, and a pair contributes the
-- sum of its two halves.  Applied to the tuple representation of a shape type
-- this is used as the number of dimensions.
ndim :: TupleType a -> Int
ndim tup =
  case tup of
    UnitTuple       -> 0
    SingleTuple _   -> 1
    PairTuple l r   -> ndim l + ndim r