{-|
  Copyright   :  (C) 2012-2016, University of Twente
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij

  Rewriting combinators and traversals
-}

{-# LANGUAGE ScopedTypeVariables #-}

module Clash.Rewrite.Combinators where

import           Control.DeepSeq                         (deepseq)
import           Control.Monad                           ((<=<), (>=>))
import qualified Control.Monad.Writer                    as Writer
import qualified Data.Monoid                             as Monoid
import           Unbound.Generics.LocallyNameless
  (Embed, Fresh, bind, embed, rec, unbind, unembed, unrec)
import           Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind)

import           Clash.Core.Term                         (Pat, Term (..))
import           Clash.Core.Util                         (patIds)
import           Clash.Core.Var                          (Id)
import           Clash.Rewrite.Types

-- | Apply a transformation on the subtrees of a term.
--
-- Only the /direct/ subterms are visited; the given transformation decides
-- for itself whether to recurse further. Leaf terms ('Var', 'Data',
-- 'Literal', 'Prim') are returned unchanged. For every subterm the context
-- is extended with an entry describing the position of that subterm.
--
-- When the first argument is 'True', binders are opened with 'unbind'
-- (which runs in the 'Fresh' monad); when it is 'False' they are opened
-- with 'unsafeUnbind'.
allR
  :: forall m
   . (Monad m, Fresh m)
  => Bool
  -- ^ Freshen variable references in abstracted terms
  -> Transform m
  -- ^ The transformation to apply to the subtrees
  -> Transform m
-- Leaves: there are no subterms to transform.
allR _ _ _ (Var t x)   = return (Var t x)
allR _ _ _ (Data dc)   = return (Data dc)
allR _ _ _ (Literal l) = return (Literal l)
allR _ _ _ (Prim nm t) = return (Prim nm t)

allR rf trans c (Lam b) = do
  (v,e) <- if rf then unbind b else return (unsafeUnbind b)
  e'    <- trans (LamBody v:c) e
  return . Lam $ bind v e'

allR rf trans c (TyLam b) = do
  (tv, e) <- if rf then unbind b else return (unsafeUnbind b)
  e'      <- trans (TyLamBody tv:c) e
  return . TyLam $ bind tv e'

-- The function position is transformed before the argument position.
allR _ trans c (App e1 e2) = do
  e1' <- trans (AppFun:c) e1
  e2' <- trans (AppArg:c) e2
  return $ App e1' e2'

-- Only the term is transformed; the type argument is left as is.
allR _ trans c (TyApp e ty) = do
  e' <- trans (TyAppC:c) e
  return $ TyApp e' ty

-- Only the term is transformed; both types are left as is.
allR _ trans c (Cast e ty1 ty2) = do
  e' <- trans (CastBody:c) e
  return $ Cast e' ty1 ty2

-- The body is transformed first, then every binding in order. Both the body
-- and the bindings see the complete list of binders in their context.
allR rf trans c (Letrec b) = do
  (xesR,e) <- if rf then unbind b else return (unsafeUnbind b)
  let xes   = unrec xesR
  let bndrs = map fst xes
  e'   <- trans (LetBody bndrs:c) e
  xes' <- mapM (rewriteBind bndrs) xes
  return . Letrec $ bind (rec xes') e'
  where
    -- Transform the right-hand side of a single let-binding; the binder
    -- itself is kept.
    rewriteBind :: [Id] -> (Id,Embed Term) -> m (Id,Embed Term)
    rewriteBind bndrs (b', e) = do
      e' <- trans (LetBinding b' bndrs:c) (unembed e)
      return (b',embed e')

-- The scrutinee is transformed first, then every alternative in order.
allR rf trans c (Case scrut ty alts) = do
  scrut' <- trans (CaseScrut:c) scrut
  alts'  <- if rf
               then mapM (fmap (uncurry bind) . rewriteAlt <=< unbind) alts
               else mapM (fmap (uncurry bind) . rewriteAlt . unsafeUnbind) alts
  return $ Case scrut' ty alts'
  where
    -- Transform the right-hand side of a single alternative; the pattern is
    -- kept, and the variables it binds (as reported by 'patIds') are
    -- recorded in the context.
    rewriteAlt :: (Pat, Term) -> m (Pat, Term)
    rewriteAlt (p,e) = do
      let (tvs,ids) = patIds p
      e' <- trans (CaseAlt tvs ids:c) e
      return (p,e')

infixr 6 >->
-- | Apply two transformations in succession.
--
-- Both transformations receive the same context; the result of the first is
-- the input of the second.
(>->) :: Monad m => Transform m -> Transform m -> Transform m
(>->) r1 r2 c = r1 c >=> r2 c

infixr 6 >-!->
-- | Apply two transformations in succession, and perform a deepseq in
-- between.
--
-- The result of the first transformation is passed to 'deepseq' before the
-- second transformation is run on it.
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
(>-!->) r1 r2 c e = do
  e' <- r1 c e
  deepseq e' (r2 c e')

{- Note [topdown repeatR]
~~~~~~~~~~~~~~~~~~~~~~~~~
In a topdown traversal we need to repeat the transformation r because if r
replaces a parent node with one of its children we should still apply r to
that child, before continuing with its children.

Example: topdownR (inlineBinders (\_ _ -> return True)) on:

> letrec
>   x = 1
> in letrec
>      y = 2
>    in f x y

inlineBinders would inline x and return:

> letrec
>   y = 2
> in f 1 y

Then we must repeat the transformation to let it also inline y.
-}

-- | Apply a transformation in a topdown traversal.
--
-- The transformation is repeated on a node (see 'repeatR') before the
-- traversal continues with the subterms of the result.
topdownR :: Rewrite m -> Rewrite m
-- See Note [topdown repeatR]
topdownR r = repeatR r >-> allR True (topdownR r)

-- | Apply a transformation in a topdown traversal. Doesn't freshen bound
-- variables.
unsafeTopdownR :: Rewrite m -> Rewrite m
-- See Note [topdown repeatR]
unsafeTopdownR r = repeatR r >-> allR False (unsafeTopdownR r)

-- | Apply a transformation in a bottomup traversal.
--
-- All subterms are traversed first; the transformation is then applied once
-- to the rebuilt node.
bottomupR :: Fresh m => Transform m -> Transform m
bottomupR r = allR True (bottomupR r) >-> r

-- | Apply a transformation in a bottomup traversal. Doesn't freshen bound
-- variables.
unsafeBottomupR :: Fresh m => Transform m -> Transform m
unsafeBottomupR r = allR False (unsafeBottomupR r) >-> r

infixr 5 !->
-- | Only apply the second transformation if the first one succeeds.
--
-- \"Succeeds\" means that the 'Monoid.Any' value written by the first
-- transformation (observed with 'Writer.listen') is 'True'. Otherwise the
-- result of the first transformation is returned as is.
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
(!->) r1 r2 c expr = do
  (expr',changed) <- Writer.listen $ r1 c expr
  if Monoid.getAny changed
    then r2 c expr'
    else return expr'

infixr 5 >-!
-- | Only apply the second transformation if the first one fails.
--
-- \"Fails\" means that the 'Monoid.Any' value written by the first
-- transformation (observed with 'Writer.listen') is 'False'. The second
-- transformation is then run on the result of the first one.
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
(>-!) r1 r2 c expr = do
  (expr',changed) <- Writer.listen $ r1 c expr
  if Monoid.getAny changed
    then return expr'
    else r2 c expr'

-- | Keep applying a transformation until it fails.
--
-- Failure is determined as in '!->': the transformation is re-applied for as
-- long as the previous application succeeded.
repeatR :: Rewrite m -> Rewrite m
repeatR r = r !-> repeatR r

-- | Only apply the transformation when the predicate evaluates to true for
-- the given context and term; otherwise return the term unchanged.
whenR
  :: Monad m
  => ([CoreContext] -> Term -> m Bool)
  -- ^ The predicate deciding whether to apply the transformation
  -> Transform m
  -- ^ The transformation to apply when the predicate holds
  -> Transform m
whenR f r1 ctx expr = do
  b <- f ctx expr
  if b
    then r1 ctx expr
    else return expr

-- | Only traverse downwards when the assertion evaluates to true.
--
-- When the assertion holds for a node, its subterms are traversed (with
-- freshening of bound variables) before the transformation is applied to the
-- node. When it does not hold, the transformation is applied to the node
-- directly, without descending into its subterms.
bottomupWhenR
  :: Fresh m
  => ([CoreContext] -> Term -> m Bool)
  -- ^ The assertion deciding whether to descend into the subterms
  -> Transform m
  -- ^ The transformation to apply
  -> Transform m
bottomupWhenR f r ctx expr = do
  b <- f ctx expr
  if b
    then (allR True (bottomupWhenR f r) >-> r) ctx expr
    else r ctx expr