module Data.Array.Accelerate.Trafo.Simplify (
Simplify(..),
) where
import Prelude hiding ( exp, iterate )
import Data.List ( nubBy )
import Data.Maybe
import Data.Monoid
import Data.Typeable
import Control.Applicative hiding ( Const )
import Data.Array.Accelerate.AST hiding ( prj )
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Trafo.Base
import Data.Array.Accelerate.Trafo.Algebra
import Data.Array.Accelerate.Trafo.Shrink
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Analysis.Shape
import Data.Array.Accelerate.Array.Sugar ( Elt, Shape, Slice, toElt, fromElt, (:.)(..) )
import Data.Array.Accelerate.Pretty.Print
import qualified Data.Array.Accelerate.Debug as Stats
#include "accelerate.h"
-- | Things that can be rewritten by the scalar simplifier in this module.
class Simplify f where
  -- | Return a simplified version of the argument.
  simplify :: f -> f
-- | Closed scalar functions are simplified with 'simplifyFun'.
instance Kit acc => Simplify (PreFun acc aenv f) where
  simplify fun = simplifyFun fun
-- | Closed scalar expressions are simplified with 'simplifyExp'.
instance Kit acc => Simplify (PreExp acc aenv e) where
  simplify expr = simplifyExp expr
-- | Common-subexpression elimination for a single @let@.
--
-- If 'lookupExp' finds the bound expression in the environment, the binding
-- is dropped: the body is inlined with a reference to the existing variable
-- ('inline' body ('Var' ix)) and the \"CSE\" rule is recorded. Otherwise the
-- result is 'Nothing'.
localCSE :: (Kit acc, Elt a, Elt b)
         => Gamma acc env env aenv
         -> PreOpenExp acc env aenv a
         -> PreOpenExp acc (env,a) aenv b
         -> Maybe (PreOpenExp acc env aenv b)
localCSE env bnd body =
  case lookupExp env bnd of
    Nothing -> Nothing
    Just ix -> Stats.ruleFired "CSE" (Just (inline body (Var ix)))
-- | Placeholder rewrite for @let@ bindings (judging by its name, intended to
-- recover loops). It currently ignores all of its arguments and always
-- returns 'Nothing'.
recoverLoops
    :: (Kit acc, Elt b)
    => Gamma acc env env aenv
    -> PreOpenExp acc env aenv a
    -> PreOpenExp acc (env,a) aenv b
    -> Maybe (PreOpenExp acc env aenv b)
recoverLoops _env _bnd _body = Nothing
-- | Simplify a scalar expression under an environment of in-scope let-bound
-- expressions.
--
-- Returns whether any rewrite fired, together with the resulting term. The
-- rewrites applied here are:
--
--   * dropping a @let@ whose binding is already available ('localCSE');
--   * the 'recoverLoops' hook;
--   * projection out of a literal tuple, a constant, or through a @let@;
--   * @IndexHead@ and @IndexTail@ of an explicit @IndexCons@, and rebuilding
--     an @IndexCons@ of @IndexNil@ and @IndexHead sh@ as @sh@;
--   * conditionals with a literal predicate or with matching branches;
--   * removing duplicate operands from nested 'Intersect's.
--
-- Primitive applications are additionally passed through 'evalPrimApp'.
simplifyOpenExp
    :: forall acc env aenv e. Kit acc
    => Gamma acc env env aenv
    -> PreOpenExp acc env aenv e
    -> (Bool, PreOpenExp acc env aenv e)
simplifyOpenExp env = first getAny . cvtE
  where
    -- Simplify a term; the 'Any' records whether some rewrite fired.
    cvtE :: PreOpenExp acc env aenv t -> (Any, PreOpenExp acc env aenv t)
    cvtE exp = case exp of
      Let bnd body
        | Just reduct <- localCSE env (snd bnd') (snd body')     -> yes . snd $ cvtE reduct
        | Just reduct <- recoverLoops env (snd bnd') (snd body') -> yes . snd $ cvtE reduct
        | otherwise                                              -> Let <$> bnd' <*> body'
        where
          bnd'  = cvtE bnd
          -- The body is simplified in an environment extended with the
          -- (already simplified) binding.
          env'  = PushExp env (snd bnd')
          body' = cvtE' (incExp env') body

      Var ix              -> pure $ Var ix
      Const c             -> pure $ Const c
      Tuple tup           -> Tuple <$> cvtT tup
      Prj ix t            -> prj ix (cvtE t)
      IndexNil            -> pure IndexNil
      IndexAny            -> pure IndexAny
      IndexCons sh sz     -> indexCons (cvtE sh) (cvtE sz)
      IndexHead sh        -> indexHead (cvtE sh)
      IndexTail sh        -> indexTail (cvtE sh)
      IndexSlice x ix sh  -> IndexSlice x <$> cvtE ix <*> cvtE sh
      IndexFull x ix sl   -> IndexFull x <$> cvtE ix <*> cvtE sl
      ToIndex sh ix       -> ToIndex <$> cvtE sh <*> cvtE ix
      FromIndex sh ix     -> FromIndex <$> cvtE sh <*> cvtE ix
      Cond p t e          -> cond (cvtE p) (cvtE t) (cvtE e)
      PrimConst c         -> pure $ PrimConst c
      -- NOTE(review): 'evalPrimApp' yields a bare term, so any rewrite it
      -- performs is not reflected in the change flag; confirm this is intended.
      PrimApp f x         -> evalPrimApp env f <$> cvtE x
      Index a sh          -> Index a <$> cvtE sh
      LinearIndex a i     -> LinearIndex a <$> cvtE i
      Shape a             -> pure $ Shape a
      ShapeSize sh        -> ShapeSize <$> cvtE sh
      Intersect s t       -> cvtE s `intersect` cvtE t
      Foreign ff f e      -> Foreign ff <$> first Any (simplifyOpenFun EmptyExp f) <*> cvtE e
      While p f x         -> While <$> cvtF env p <*> cvtF env f <*> cvtE x

    -- Simplify each component of a tuple.
    cvtT :: Tuple (PreOpenExp acc env aenv) t -> (Any, Tuple (PreOpenExp acc env aenv) t)
    cvtT NilTup        = pure NilTup
    cvtT (SnocTup t e) = SnocTup <$> cvtT t <*> cvtE e

    -- Simplify an expression under a different environment.
    cvtE' :: Gamma acc env' env' aenv -> PreOpenExp acc env' aenv e' -> (Any, PreOpenExp acc env' aenv e')
    cvtE' env' = first Any . simplifyOpenExp env'

    -- Simplify a function under a different environment.
    cvtF :: Gamma acc env' env' aenv -> PreOpenFun acc env' aenv f -> (Any, PreOpenFun acc env' aenv f)
    cvtF env' = first Any . simplifyOpenFun env'

    -- Flatten nested intersections into a left-nested chain, dropping any
    -- operand that 'match'es an earlier one. A change is reported when the
    -- rebuilt term does not match the input intersection.
    intersect :: Shape t
              => (Any, PreOpenExp acc env aenv t)
              -> (Any, PreOpenExp acc env aenv t)
              -> (Any, PreOpenExp acc env aenv t)
    intersect (c1, sh1) (c2, sh2)
      | Nothing <- match sh sh' = Stats.ruleFired "intersect" (yes sh')
      | otherwise               = (c1 <> c2, sh')
      where
        sh  = Intersect sh1 sh2
        -- 'leaves' never returns an empty list, so 'foldl1' is safe here.
        sh' = foldl1 Intersect
            $ nubBy (\x y -> isJust (match x y))
            $ leaves sh1 ++ leaves sh2

        leaves :: Shape t => PreOpenExp acc env aenv t -> [PreOpenExp acc env aenv t]
        leaves (Intersect x y) = leaves x ++ leaves y
        leaves rest            = [rest]

    -- Select a branch when the predicate is a literal, or drop the
    -- conditional when both branches match.
    cond :: forall t. Elt t
         => (Any, PreOpenExp acc env aenv Bool)
         -> (Any, PreOpenExp acc env aenv t)
         -> (Any, PreOpenExp acc env aenv t)
         -> (Any, PreOpenExp acc env aenv t)
    cond p@(_,p') t@(_,t') e@(_,e')
      | Const ((),True)  <- p'   = Stats.knownBranch "True" (yes t')
      | Const ((),False) <- p'   = Stats.knownBranch "False" (yes e')
      | Just REFL <- match t' e' = Stats.knownBranch "redundant" (yes e')
      | otherwise                = Cond <$> p <*> t <*> e

    -- Tuple projection: resolved directly on a literal tuple or constant, and
    -- pushed into the body of a @let@.
    prj :: forall s t. (Elt s, Elt t, IsTuple t)
        => TupleIdx (TupleRepr t) s
        -> (Any, PreOpenExp acc env aenv t)
        -> (Any, PreOpenExp acc env aenv s)
    prj ix exp@(_,exp')
      | Tuple t <- exp' = Stats.inline "prj/Tuple" . yes $ prjT ix t
      | Const c <- exp' = Stats.inline "prj/Const" . yes $ prjC ix (fromTuple (toElt c :: t))
      -- The term is always restructured here (the projection moves under the
      -- let), so report a change even if re-simplifying finds nothing more;
      -- this mirrors how the Let/CSE case reports its rewrite.
      | Let a b <- exp' = Stats.ruleFired "prj/Let" . yes . snd $ cvtE (Let a (Prj ix b))
      | otherwise       = Prj ix <$> exp
      where
        prjT :: TupleIdx tup s -> Tuple (PreOpenExp acc env aenv) tup -> PreOpenExp acc env aenv s
        prjT ZeroTupIdx       (SnocTup _ e) = e
        prjT (SuccTupIdx idx) (SnocTup t _) = prjT idx t
        -- Only reachable if the index and the tuple disagree in length, which
        -- their types should already rule out.
        prjT _                _             = error "simplifyOpenExp.prjT: tuple index does not match tuple"

        prjC :: TupleIdx tup s -> tup -> PreOpenExp acc env aenv s
        prjC ZeroTupIdx       (_, v)   = Const (fromElt v)
        prjC (SuccTupIdx idx) (tup, _) = prjC idx tup

    -- Rebuild an IndexCons of IndexNil and (IndexHead sh) as sh itself, when
    -- sh already has the type of the result.
    indexCons :: (Slice sl, Elt sz)
              => (Any, PreOpenExp acc env aenv sl)
              -> (Any, PreOpenExp acc env aenv sz)
              -> (Any, PreOpenExp acc env aenv (sl :. sz))
    indexCons (_,sl') (_,sz')
      | Just REFL <- match sl' IndexNil
      , IndexHead sh <- sz'
      -- NOTE(review): this tests the dimension of the head expression sz',
      -- not of sh; the 'gcast' below is what ensures sh has the target type.
      , expDim sz' == 1
      , Just sh' <- gcast sh
      = yes sh'
    indexCons sl sz
      = IndexCons <$> sl <*> sz

    -- Take the head of an explicitly constructed index directly.
    indexHead :: (Slice sl, Elt sz) => (Any, PreOpenExp acc env aenv (sl :. sz)) -> (Any, PreOpenExp acc env aenv sz)
    indexHead (_, IndexCons _ sz) = yes sz
    indexHead sh                  = IndexHead <$> sh

    -- Take the tail of an explicitly constructed index directly.
    indexTail :: (Slice sl, Elt sz) => (Any, PreOpenExp acc env aenv (sl :. sz)) -> (Any, PreOpenExp acc env aenv sl)
    indexTail (_, IndexCons sl _) = yes sl
    indexTail sh                  = IndexTail <$> sh

    -- Map over the first component of a pair.
    first :: (a -> a') -> (a,b) -> (a',b)
    first f (x,y) = (f x, y)

    -- Mark a term as changed.
    yes :: x -> (Any, x)
    yes x = (Any True, x)
-- | Simplify the body of a scalar function.
--
-- For each lambda, the environment is adjusted with 'incExp' and extended
-- with @'Var' 'ZeroIdx'@ (presumably the variable that lambda binds) before
-- descending. The body is handled by 'simplifyOpenExp', whose change flag is
-- returned.
simplifyOpenFun
    :: Kit acc
    => Gamma acc env env aenv
    -> PreOpenFun acc env aenv f
    -> (Bool, PreOpenFun acc env aenv f)
simplifyOpenFun env fun =
  case fun of
    Body e -> Body <$> simplifyOpenExp env e
    Lam f  -> Lam  <$> simplifyOpenFun (PushExp (incExp env) (Var ZeroIdx)) f
-- | Simplify a closed scalar expression.
--
-- It starts from the empty environment and drives 'simplifyOpenExp' through
-- 'iterate'. The pretty-printer is only used for trace output.
simplifyExp :: Kit acc => PreExp acc aenv t -> PreExp acc aenv t
simplifyExp expr = iterate render (simplifyOpenExp EmptyExp) expr
  where
    render x = show (prettyPreExp prettyAcc 0 0 noParens x)
-- | Simplify a closed scalar function.
--
-- It starts from the empty environment and drives 'simplifyOpenFun' through
-- 'iterate'. The pretty-printer is only used for trace output.
simplifyFun :: Kit acc => PreFun acc aenv f -> PreFun acc aenv f
simplifyFun fun = iterate render (simplifyOpenFun EmptyExp) fun
  where
    render x = show (prettyPreFun prettyAcc 0 x)
-- | Drive a simplification pass towards a fixed point.
--
-- The pass @f@ is applied once up front. After that, 'shrink'' and @f@ are
-- applied alternately, and the loop:
--
--   * returns the shrunk term as soon as shrinking reports no change;
--   * returns the re-simplified term as soon as the pass reports no change;
--   * otherwise gives up after @lIMIT@ rounds, under an internal check that
--     the term is then a fixed point of @f@.
--
-- @ppr@ is only used to render terms for the 'Stats.dump_simpl_iterations'
-- trace.
iterate
    :: forall f a. (Match f, Shrink (f a))
    => (f a -> String)
    -> (f a -> (Bool, f a))
    -> f a
    -> f a
iterate ppr f = go 0 . announce . pass
  where
    -- maximum number of shrink/simplify rounds after the initial pass
    lIMIT = 1

    pass = Stats.simplifierDone . f
    announce (_, term) = msg term term

    go :: Int -> f a -> f a
    go !n !cur
      | n >= lIMIT     = INTERNAL_CHECK(warning) "iterate" "iteration limit reached" (cur ==^ f cur) cur
      | not shrunk     = shrunkTerm
      | not simplified = simplifiedTerm
      | otherwise      = go (n+1) simplifiedTerm
      where
        (shrunk,     shrunkTerm)     = traceIfChanged $ shrink' cur
        (simplified, simplifiedTerm) = traceIfChanged $ pass shrunkTerm

    u ==^ (_, v) = isJust (match u v)

    traceIfChanged r@(changed, term)
      | changed   = msg term r
      | otherwise = r

    msg :: f a -> x -> x
    msg term next = Stats.tracePure Stats.dump_simpl_iterations (unlines [ "simplifier done", ppr term ]) next