module Data.Array.Accelerate.Pretty.Traverse
where
import Data.Array.Accelerate.Array.Sugar hiding ((!))
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Type
cat :: Show s => String -> s -> String
cat t s = t ++ show s
data Labels = Labels { accFormat :: String
, expFormat :: String
, funFormat :: String
, tupleFormat :: String
, arrayFormat :: String
, boundaryFormat :: String
, primFunFormat :: String
}
travAcc :: forall m b aenv a. Monad m => Labels -> (String -> String -> [m b] -> m b)
-> (String -> String -> m b) -> OpenAcc aenv a -> m b
travAcc f c l (OpenAcc openAcc) = travAcc' openAcc
where
combine = c (accFormat f)
leaf = l (accFormat f)
travAcc' :: PreOpenAcc OpenAcc aenv a -> m b
travAcc' (Alet acc1 acc2) = combine "Alet" [travAcc f c l acc1, travAcc f c l acc2]
travAcc' (Avar idx) = leaf ("AVar " `cat` idxToInt idx)
travAcc' (Apply afun acc) = combine "Apply" [travAfun f c l afun, travAcc f c l acc]
travAcc' (Aforeign ff afun acc) = combine ("Aforeign " ++ strForeign ff) [travAfun f c l afun, travAcc f c l acc]
travAcc' (Acond e acc1 acc2) = combine "Acond" [travExp f c l e, travAcc f c l acc1, travAcc f c l acc2]
travAcc' (Atuple tup) = combine "Atuple" [ travAtuple f c l tup ]
travAcc' (Aprj idx a) = combine ("Aprj " `cat` tupleIdxToInt idx) [ travAcc f c l a ]
travAcc' (Use arr) = combine "Use" [ travArrays f c l (arrays (undefined::a)) arr ]
travAcc' (Unit e) = combine "Unit" [ travExp f c l e ]
travAcc' (Generate sh fun) = combine "Generate" [ travExp f c l sh, travFun f c l fun]
travAcc' (Transform sh pf vf acc) = combine "Transform" [ travExp f c l sh, travFun f c l pf, travFun f c l vf, travAcc f c l acc ]
travAcc' (Reshape sh acc) = combine "Reshape" [ travExp f c l sh, travAcc f c l acc ]
travAcc' (Replicate _ ix acc) = combine "Replicate" [ travExp f c l ix, travAcc f c l acc ]
travAcc' (Slice _ acc ix) = combine "Slice" [ travAcc f c l acc, travExp f c l ix ]
travAcc' (Map fun acc) = combine "Map" [ travFun f c l fun, travAcc f c l acc ]
travAcc' (ZipWith fun acc1 acc2) = combine "ZipWith" [ travFun f c l fun, travAcc f c l acc1, travAcc f c l acc2 ]
travAcc' (Fold fun e acc) = combine "Fold" [ travFun f c l fun, travExp f c l e, travAcc f c l acc]
travAcc' (Fold1 fun acc) = combine "Fold1" [ travFun f c l fun, travAcc f c l acc]
travAcc' (FoldSeg fun e acc1 acc2) = combine "FoldSeg" [ travFun f c l fun, travExp f c l e
, travAcc f c l acc1, travAcc f c l acc2 ]
travAcc' (Fold1Seg fun acc1 acc2) = combine "FoldSeg1" [ travFun f c l fun, travAcc f c l acc1, travAcc f c l acc2 ]
travAcc' (Scanl fun e acc) = combine "Scanl" [ travFun f c l fun, travExp f c l e, travAcc f c l acc ]
travAcc' (Scanl' fun e acc) = combine "Scanl'" [ travFun f c l fun, travExp f c l e, travAcc f c l acc ]
travAcc' (Scanl1 fun acc) = combine "Scanl1" [ travFun f c l fun, travAcc f c l acc ]
travAcc' (Scanr fun e acc) = combine "Scanr" [ travFun f c l fun, travExp f c l e, travAcc f c l acc ]
travAcc' (Scanr' fun e acc) = combine "Scanr'" [ travFun f c l fun, travExp f c l e, travAcc f c l acc ]
travAcc' (Scanr1 fun acc) = combine "Scanr1" [ travFun f c l fun, travAcc f c l acc ]
travAcc' (Permute fun dfts p acc) = combine "Permute" [ travFun f c l fun, travAcc f c l dfts
, travFun f c l p, travAcc f c l acc]
travAcc' (Backpermute sh p acc) = combine "Backpermute" [ travExp f c l sh, travFun f c l p, travAcc f c l acc]
travAcc' (Stencil sten bndy acc) = combine "Stencil" [ travFun f c l sten, travBoundary f l acc bndy
, travAcc f c l acc]
travAcc' (Stencil2 sten bndy1 acc1 bndy2 acc2) = combine "Stencil2" [ travFun f c l sten, travBoundary f l acc1 bndy1
, travAcc f c l acc1, travBoundary f l acc2 bndy2
, travAcc f c l acc2]
travExp :: forall m env aenv a b . Monad m => Labels
-> (String -> String -> [m b] -> m b)
-> (String -> String -> m b)
-> OpenExp env aenv a -> m b
travExp f c l expr = travExp' expr
where
combine = c (expFormat f)
leaf = l (expFormat f)
travExp' :: OpenExp env aenv a -> m b
travExp' (Let e1 e2) = combine "Let" [travExp f c l e1, travExp f c l e2]
travExp' (Var idx) = leaf ("Var " `cat` idxToInt idx)
travExp' (Const v) = leaf ("Const " `cat` (toElt v :: a))
travExp' (Tuple tup) = combine "Tuple" [ travTuple f c l tup ]
travExp' (Prj idx e) = combine ("Prj " `cat` tupleIdxToInt idx) [ travExp f c l e ]
travExp' (IndexNil) = leaf "IndexNil"
travExp' (IndexCons t h) = combine "IndexCons" [ travExp f c l t, travExp f c l h]
travExp' (IndexHead ix) = combine "IndexHead" [ travExp f c l ix ]
travExp' (IndexTail ix) = combine "IndexTail" [ travExp f c l ix ]
travExp' (IndexAny) = leaf "IndexAny"
travExp' (IndexSlice _ slix sh) = combine "IndexSlice" [ travExp f c l slix, travExp f c l sh ]
travExp' (IndexFull _ slix sl) = combine "IndexFull" [ travExp f c l slix, travExp f c l sl ]
travExp' (ToIndex sh ix) = combine "ToIndex" [ travExp f c l sh, travExp f c l ix ]
travExp' (FromIndex sh ix) = combine "FromIndex" [ travExp f c l sh, travExp f c l ix ]
travExp' (Cond cond thn els) = combine "Cond" [travExp f c l cond, travExp f c l thn, travExp f c l els]
travExp' (Iterate _ fun x) = combine "Iterate" [ travFun f c l (Lam (Body fun)), travExp f c l x ]
travExp' (PrimConst a) = leaf ("PrimConst " `cat` labelForConst a)
travExp' (PrimApp p a) = combine "PrimApp" [ l (primFunFormat f) (labelForPrimFun p), travExp f c l a ]
travExp' (Index idx i) = combine "Index" [ travAcc f c l idx, travExp f c l i]
travExp' (LinearIndex idx i) = combine "LinearIndex" [ travAcc f c l idx, travExp f c l i]
travExp' (Shape idx) = combine "Shape" [ travAcc f c l idx ]
travExp' (ShapeSize e) = combine "ShapeSize" [ travExp f c l e ]
travExp' (Intersect sh1 sh2) = combine "Intersect" [ travExp f c l sh1, travExp f c l sh2 ]
travExp' (Foreign ff fun e) = combine ("Foreign " ++ strForeign ff) [ travFun f c l fun, travExp f c l e ]
travAfun :: forall m b aenv fun. Monad m => Labels -> (String -> String -> [m b] -> m b)
-> (String -> String -> m b) -> OpenAfun aenv fun -> m b
travAfun f c l openAfun = travAfun' openAfun
where
combine = c (funFormat f)
travAfun' :: OpenAfun aenv fun -> m b
travAfun' (Abody body) = combine "Abody" [ travAcc f c l body ]
travAfun' (Alam fun) = combine "Alam" [ travAfun f c l fun ]
travArrays
:: forall m a b. Monad m
=> Labels
-> (String -> String -> [m b] -> m b)
-> (String -> String -> m b)
-> ArraysR a
-> a
-> m b
travArrays f c l = trav
where
combine = c (accFormat f)
leaf = l (accFormat f)
trav :: ArraysR a' -> a' -> m b
trav ArraysRunit () = leaf "ArraysRunit"
trav ArraysRarray a = combine "ArraysRarray" [ travArray f l a ]
trav (ArraysRpair r1 r2) (a1, a2) = combine "ArraysRpair" [trav r1 a1, trav r2 a2]
travArray :: forall dim a m b. Monad m => Labels -> (String -> String -> m b) -> Array dim a -> m b
travArray f l (Array sh _) = l (arrayFormat f) ("Array" `cat` (toElt sh :: dim))
travBoundary :: forall aenv dim e m b. (Monad m, Elt e) => Labels -> (String -> String -> m b)
-> OpenAcc aenv (Array dim e)
-> Boundary (EltRepr e) -> m b
travBoundary f l boundary = travBoundary' boundary
where
leaf = l (boundaryFormat f)
travBoundary' :: OpenAcc aenv (Array dim e) -> Boundary (EltRepr e) -> m b
travBoundary' _ Clamp = leaf "Clamp"
travBoundary' _ Mirror = leaf "Mirror"
travBoundary' _ Wrap = leaf "Wrap"
travBoundary' _ (Constant e) = leaf ("Constant " `cat` (toElt e :: e))
travAtuple
:: forall m b aenv t. Monad m
=> Labels
-> (String -> String -> [m b] -> m b)
-> (String -> String -> m b)
-> Atuple (OpenAcc aenv) t
-> m b
travAtuple f c l = trav
where
leaf = l (tupleFormat f)
combine = c (tupleFormat f)
trav :: Atuple (OpenAcc aenv) t' -> m b
trav NilAtup = leaf "NilAtup"
trav (SnocAtup tup a) = combine "SnocAtup" [ trav tup, travAcc f c l a ]
travFun :: forall m b env aenv fun.Monad m => Labels -> (String -> String -> [m b] -> m b)
-> (String -> String -> m b) -> OpenFun env aenv fun -> m b
travFun f c l openFun = travFun' openFun
where
combine = c (funFormat f)
travFun' :: OpenFun env aenv fun -> m b
travFun' (Body body) = combine "Body" [ travExp f c l body ]
travFun' (Lam fun) = combine "Lam" [ travFun f c l fun ]
travTuple :: forall m b env aenv t. Monad m => Labels -> (String -> String -> [m b] -> m b)
-> (String -> String -> m b) -> Tuple (OpenExp env aenv) t -> m b
travTuple f c l tuple = travTuple' tuple
where
leaf = l (tupleFormat f)
combine = c (tupleFormat f)
travTuple' :: Tuple (OpenExp env aenv) t -> m b
travTuple' NilTup = leaf "NilTup"
travTuple' (SnocTup tup e) = combine "SnocTup" [ travTuple f c l tup, travExp f c l e ]
labelForPrimFun :: PrimFun a -> String
labelForPrimFun (PrimAdd _) = "PrimAdd"
labelForPrimFun (PrimSub _) = "PrimSub"
labelForPrimFun (PrimMul _) = "PrimMul"
labelForPrimFun (PrimNeg _) = "PrimNeg"
labelForPrimFun (PrimAbs _) = "PrimAbs"
labelForPrimFun (PrimSig _) = "PrimSig"
labelForPrimFun (PrimQuot _) = "PrimQuot"
labelForPrimFun (PrimRem _) = "PrimRem"
labelForPrimFun (PrimIDiv _) = "PrimIDiv"
labelForPrimFun (PrimMod _) = "PrimMod"
labelForPrimFun (PrimBAnd _) = "PrimBAnd"
labelForPrimFun (PrimBOr _) = "PrimBOr"
labelForPrimFun (PrimBXor _) = "PrimBXor"
labelForPrimFun (PrimBNot _) = "PrimBNot"
labelForPrimFun (PrimBShiftL _) = "PrimBShiftL"
labelForPrimFun (PrimBShiftR _) = "PrimBShiftR"
labelForPrimFun (PrimBRotateL _) = "PrimBRotateL"
labelForPrimFun (PrimBRotateR _) = "PrimBRotateR"
labelForPrimFun (PrimFDiv _) = "PrimFDiv"
labelForPrimFun (PrimRecip _) = "PrimRecip"
labelForPrimFun (PrimSin _) = "PrimSin"
labelForPrimFun (PrimCos _) = "PrimCos"
labelForPrimFun (PrimTan _) = "PrimTan"
labelForPrimFun (PrimAsin _) = "PrimAsin"
labelForPrimFun (PrimAcos _) = "PrimAcos"
labelForPrimFun (PrimAtan _) = "PrimAtan"
labelForPrimFun (PrimAsinh _) = "PrimAsinh"
labelForPrimFun (PrimAcosh _) = "PrimAcosh"
labelForPrimFun (PrimAtanh _) = "PrimAtanh"
labelForPrimFun (PrimExpFloating _) = "PrimExpFloating"
labelForPrimFun (PrimSqrt _) = "PrimSqrt"
labelForPrimFun (PrimLog _) = "PrimLog"
labelForPrimFun (PrimFPow _) = "PrimFPow"
labelForPrimFun (PrimLogBase _) = "PrimLogBase"
labelForPrimFun (PrimTruncate _ _) = "PrimTruncate"
labelForPrimFun (PrimRound _ _) = "PrimRound"
labelForPrimFun (PrimFloor _ _) = "PrimFloor"
labelForPrimFun (PrimCeiling _ _) = "PrimCeiling"
labelForPrimFun (PrimAtan2 _) = "PrimAtan2"
labelForPrimFun (PrimLt _) = "PrimLt"
labelForPrimFun (PrimGt _) = "PrimGt"
labelForPrimFun (PrimLtEq _) = "PrimLtEq"
labelForPrimFun (PrimGtEq _) = "PrimGtEq"
labelForPrimFun (PrimEq _) = "PrimEq"
labelForPrimFun (PrimNEq _) = "PrimNEq"
labelForPrimFun (PrimMax _) = "PrimMax"
labelForPrimFun (PrimMin _) = "PrimMin"
labelForPrimFun PrimLAnd = "PrimLAnd"
labelForPrimFun PrimLOr = "PrimLOr"
labelForPrimFun PrimLNot = "PrimLNot"
labelForPrimFun PrimOrd = "PrimOrd"
labelForPrimFun PrimChr = "PrimChr"
labelForPrimFun PrimBoolToInt = "PrimBoolToInt"
labelForPrimFun (PrimFromIntegral _ _) = "PrimFromIntegral"
labelForConst :: PrimConst a -> String
labelForConst (PrimMinBound _) = "PrimMinBound"
labelForConst (PrimMaxBound _) = "PrimMaxBound"
labelForConst (PrimPi _) = "PrimPi"