{-# LANGUAGE LiberalTypeSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Combinators for building, sequencing and iterating rewrite functions,
-- together with products of forward passes, backward passes and dataflow
-- lattices.
module Compiler.Hoopl.Combinators
  ( SimpleFwdRewrite, SimpleFwdRewrite3, noFwdRewrite, thenFwdRw
  , shallowFwdRw3, shallowFwdRw, deepFwdRw3, deepFwdRw, iterFwdRw
  , SimpleBwdRewrite, SimpleBwdRewrite3, noBwdRewrite, thenBwdRw
  , shallowBwdRw3, shallowBwdRw, deepBwdRw3, deepBwdRw, iterBwdRw
  , pairFwd, pairBwd, pairLattice
  )
where

import Control.Monad
import Data.Function
import Data.Maybe

import Compiler.Hoopl.Collections
import Compiler.Hoopl.Dataflow
import Compiler.Hoopl.Graph (Graph, C, O)
import Compiler.Hoopl.Label

-- Short local names for the two rewrite types.
type FR m n f = FwdRewrite m n f
type BR m n f = BwdRewrite m n f

-- | Result of one forward rewrite attempt: 'Nothing' when the rewrite
-- declines, otherwise a 'FwdRew' (a graph plus a follow-on 'FwdRewrite').
type FwdRes m n f e x = Maybe (FwdRew m n f e x)

-- | A simple forward rewrite for a single node shape: given a node and its
-- incoming fact, optionally produce a graph of the same shape.
type SFRW m n f e x = n e x -> f -> m (Maybe (Graph n e x))

-- | A full forward rewrite for a single node shape.
type FRW m n f e x = n e x -> f -> m (FwdRes m n f e x)

-- | One simple forward rewrite per node shape.
type SimpleFwdRewrite3 m n f = ExTriple (SFRW m n f)

-- | One value per node shape, in the order @C O@ (entry), @O O@ (middle),
-- @O C@ (exit).
type ExTriple a = (a C O, a O O, a O C)

-- | A simple forward rewrite that works at every node shape.
type SimpleFwdRewrite m n f = forall e x . SFRW m n f e x

-- Per-shape transformers consumed by the wrappers below.
type LiftFRW m n f e x = SFRW m n f e x -> FRW m n f e x
type MapFRW  m n f e x = FRW m n f e x -> FRW m n f e x
type MapFRW2 m n f e x = FRW m n f e x -> FRW m n f e x -> FRW m n f e x

----------------------------------------------------------------
-- Component-wise operations on triples

-- | Feed the three components of a triple to a curried function.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 k (p, q, r) = k p q r

-- | Apply each function of a triple to the matching component of another.
apply :: (a -> b, d -> e, g -> h) -> (a, d, g) -> (b, e, h)
apply (k1, k2, k3) (p, q, r) = (k1 p, k2 q, k3 r)

-- | Combine two triples component-wise with a triple of binary functions.
applyBinary :: (a -> b -> c, d -> e -> f, g -> h -> i)
            -> (a, d, g) -> (b, e, h) -> (c, f, i)
applyBinary (k1, k2, k3) (p1, q1, r1) (p2, q2, r2) =
  (k1 p1 p2, k2 q1 q2, k3 r1 r2)

----------------------------------------------------------------
-- Forward wrappers: unpack per-shape functions, transform, repackage.

-- | Lift each simple rewrite with the matching lifter and package the
-- three results with 'mkFRewrite3'.
wrapSFRewrite3 :: ExTriple (LiftFRW m n f) -> SimpleFwdRewrite3 m n f -> FR m n f
wrapSFRewrite3 lifters simple = uncurry3 mkFRewrite3 (apply lifters simple)

-- | Unpack a 'FwdRewrite' with 'getFRewrite3', transform each per-shape
-- function, and repackage with 'mkFRewrite3'.
wrapFRewrite3 :: ExTriple (MapFRW m n f) -> FR m n f -> FR m n f
wrapFRewrite3 mappers rw =
  uncurry3 mkFRewrite3 (apply mappers (getFRewrite3 rw))

-- | Binary counterpart of 'wrapFRewrite3': combines the per-shape
-- functions of two rewrites.
wrapFRewrites23 :: ExTriple (MapFRW2 m n f) -> FR m n f -> FR m n f -> FR m n f
wrapFRewrites23 mappers rwA rwB =
  uncurry3 mkFRewrite3 (applyBinary mappers (getFRewrite3 rwA) (getFRewrite3 rwB))

-- Versions of the wrappers that take one shape-polymorphic function and
-- use it at all three shapes.

-- | 'wrapSFRewrite3' with the same lifter at every shape.
wrapSFRewrites' :: (forall e x . LiftFRW m n f e x) -> SimpleFwdRewrite3 m n f -> FR m n f
wrapSFRewrites' lifter = wrapSFRewrite3 (lifter, lifter, lifter)

-- | 'wrapFRewrite3' with the same transformer at every shape.
wrapFRewrites :: (forall e x . MapFRW m n f e x) -> FR m n f -> FR m n f
wrapFRewrites mapper = wrapFRewrite3 (mapper, mapper, mapper)

-- It's ugly that we can't use
--     wrapFRewrites' = mkFRewrite'
-- Would be nice to refactor here XXX ---NR

-- | 'wrapFRewrites23' with the same combiner at every shape.
wrapFRewrites2 :: (forall e x . MapFRW2 m n f e x) -> FR m n f -> FR m n f -> FR m n f
wrapFRewrites2 mapper = wrapFRewrites23 (mapper, mapper, mapper)

----------------------------------------------------------------

-- | Turn a triple of simple rewrites into a 'FwdRewrite'. Each produced
-- graph is paired with 'noFwdRewrite' as its follow-on rewrite.
shallowFwdRw3 :: forall m n f . Monad m => SimpleFwdRewrite3 m n f -> FwdRewrite m n f
shallowFwdRw3 simple = wrapSFRewrites' finish simple
  where
    -- Run the simple rewrite; wrap any graph it yields in a 'FwdRew'.
    finish rw node fact =
      liftM (fmap (\g -> FwdRew g noFwdRewrite)) (rw node fact)

-- | 'shallowFwdRw3' with the same simple rewrite at every shape.
shallowFwdRw :: Monad m => SimpleFwdRewrite m n f -> FwdRewrite m n f
shallowFwdRw rw = shallowFwdRw3 (rw, rw, rw)

-- | 'iterFwdRw' applied to the shallow rewrite built from the triple.
deepFwdRw3 :: Monad m => SimpleFwdRewrite3 m n f -> FwdRewrite m n f
deepFwdRw3 = iterFwdRw . shallowFwdRw3

-- | 'deepFwdRw3' with the same simple rewrite at every shape.
deepFwdRw :: Monad m => SimpleFwdRewrite m n f -> FwdRewrite m n f
deepFwdRw rw = deepFwdRw3 (rw, rw, rw)

-- N.B. In 'thenFwdRw' below, rw3, rw3', and rw3a are triples of functions
-- (whole rewrites), but rw and rw' are single per-shape functions.
-- | Sequential choice. Try @rw3@ first; if it declines, try @rw3'@.
-- When @rw3@ fires, its follow-on rewrite is followed by @rw3'@.
-- @ start comb1.tex
thenFwdRw :: Monad m => FwdRewrite m n f -> FwdRewrite m n f -> FwdRewrite m n f
-- @ end comb1.tex
thenFwdRw rw3 rw3' = wrapFRewrites2 thenrw rw3 rw3'
  where
    thenrw rw rw' n f = rw n f >>= fwdRes
      where
        fwdRes Nothing = rw' n f
        fwdRes (Just (FwdRew g rw3a)) =
          return $ Just $ FwdRew g (rw3a `thenFwdRw` rw3')

-- | A forward rewrite that always declines.
noFwdRewrite :: Monad m => FwdRewrite m n f
noFwdRewrite = mkFRewrite $ \ _ _ -> return Nothing

-- | Iterate a rewrite. Whenever @rw3@ fires, its follow-on rewrite is
-- followed by @iterFwdRw rw3@ again.
-- @ start iterf.tex
iterFwdRw :: Monad m => FwdRewrite m n f -> FwdRewrite m n f
-- @ end iterf.tex
iterFwdRw rw3 = wrapFRewrites iter rw3
  where
    iter rw n f = liftM (liftM fwdRes) (rw n f)
    fwdRes (FwdRew g rw3a) = FwdRew g (rw3a `thenFwdRw` iterFwdRw rw3)

----------------------------------------------------------------

-- | Result of one backward rewrite attempt: 'Nothing' when it declines.
type BwdRes m n f e x = Maybe (BwdRew m n f e x)

-- | A simple backward rewrite for a single node shape: given a node and
-- its outgoing fact(s), optionally produce a graph of the same shape.
type SBRW m n f e x = n e x -> Fact x f -> m (Maybe (Graph n e x))

-- | A full backward rewrite for a single node shape.
type BRW m n f e x = n e x -> Fact x f -> m (BwdRes m n f e x)

-- | One simple backward rewrite per node shape.
type SimpleBwdRewrite3 m n f = ExTriple ( SBRW m n f)

-- | A simple backward rewrite that works at every node shape.
type SimpleBwdRewrite m n f = forall e x . SBRW m n f e x

-- Per-shape transformers consumed by the wrappers below.
type LiftBRW m n f e x = SBRW m n f e x -> BRW m n f e x
type MapBRW  m n f e x = BRW m n f e x -> BRW m n f e x
type MapBRW2 m n f e x = BRW m n f e x -> BRW m n f e x -> BRW m n f e x

----------------------------------------------------------------

-- | Lift each simple backward rewrite with the matching lifter and
-- package the results with 'mkBRewrite3'.
wrapSBRewrite3 :: ExTriple (LiftBRW m n f) -> SimpleBwdRewrite3 m n f -> BwdRewrite m n f
wrapSBRewrite3 lifters rw = uncurry3 mkBRewrite3 $ apply lifters rw

-- | Unpack a 'BwdRewrite' with 'getBRewrite3', transform each per-shape
-- function, and repackage with 'mkBRewrite3'.
wrapBRewrite3 :: ExTriple (MapBRW m n f) -> BwdRewrite m n f -> BwdRewrite m n f
wrapBRewrite3 mappers rw = uncurry3 mkBRewrite3 $ apply mappers $ getBRewrite3 rw

-- | Binary counterpart of 'wrapBRewrite3'.
wrapBRewrites2 :: ExTriple (MapBRW2 m n f) -> BR m n f -> BR m n f -> BR m n f
wrapBRewrites2 mappers rw1 rw2 =
  uncurry3 mkBRewrite3 $ (applyBinary mappers `on` getBRewrite3) rw1 rw2

-- Versions of the wrappers that take one shape-polymorphic function and
-- use it at all three shapes.

wrapSBRewrites' :: (forall e x . LiftBRW m n f e x) -> SimpleBwdRewrite3 m n f -> BR m n f
wrapSBRewrites' lifter = wrapSBRewrite3 (lifter, lifter, lifter)

wrapBRewrites' :: (forall e x . MapBRW m n f e x) -> BwdRewrite m n f -> BwdRewrite m n f
wrapBRewrites' mapper = wrapBRewrite3 (mapper, mapper, mapper)

wrapBRewrites2' :: (forall e x . MapBRW2 m n f e x) -> BR m n f -> BR m n f -> BR m n f
wrapBRewrites2' mapper = wrapBRewrites2 (mapper, mapper, mapper)

----------------------------------------------------------------

-- | A backward rewrite that always declines.
noBwdRewrite :: Monad m => BwdRewrite m n f
noBwdRewrite = mkBRewrite $ \ _ _ -> return Nothing

-- | Turn a triple of simple backward rewrites into a 'BwdRewrite'. Each
-- produced graph is paired with 'noBwdRewrite' as its follow-on rewrite.
shallowBwdRw3 :: Monad m => SimpleBwdRewrite3 m n f -> BwdRewrite m n f
shallowBwdRw3 rw = wrapSBRewrites' lift rw
  where
    lift rw' n f = liftM (liftM (flip BwdRew noBwdRewrite)) (rw' n f)

-- | 'shallowBwdRw3' with the same simple rewrite at every shape.
shallowBwdRw :: Monad m => SimpleBwdRewrite m n f -> BwdRewrite m n f
shallowBwdRw f = shallowBwdRw3 (f, f, f)

-- | 'iterBwdRw' applied to the shallow rewrite built from the triple.
deepBwdRw3 :: Monad m => SimpleBwdRewrite3 m n f -> BwdRewrite m n f
deepBwdRw3 r = iterBwdRw (shallowBwdRw3 r)

-- | 'deepBwdRw3' with the same simple rewrite at every shape.
deepBwdRw :: Monad m => SimpleBwdRewrite m n f -> BwdRewrite m n f
deepBwdRw f = deepBwdRw3 (f, f, f)

-- | Sequential choice. Try @rw1@ first; if it declines, try @rw2@.
-- When @rw1@ fires, its follow-on rewrite is followed by @rw2@.
thenBwdRw :: Monad m => BwdRewrite m n f -> BwdRewrite m n f -> BwdRewrite m n f
thenBwdRw rw1 rw2 = wrapBRewrites2' thenrw rw1 rw2
  where
    -- rw and rw' are the per-shape functions of rw1 and rw2 respectively.
    thenrw rw rw' n f = do
      res1 <- rw n f
      case res1 of
        Nothing              -> rw' n f
        Just (BwdRew g rw1a) -> return $ Just $ BwdRew g (rw1a `thenBwdRw` rw2)

-- | Iterate a backward rewrite. Whenever @rw@ fires, its follow-on
-- rewrite is followed by @iterBwdRw rw@ again.
iterBwdRw :: Monad m => BwdRewrite m n f -> BwdRewrite m n f
iterBwdRw rw = wrapBRewrites' iter rw
  where
    iter rw' n f = liftM (liftM iterRewrite) (rw' n f)
    iterRewrite (BwdRew g rw2) = BwdRew g (rw2 `thenBwdRw` iterBwdRw rw)

-- | Run two forward passes side by side over the product lattice.
-- Pass 1's rewrite is tried first; if it declines, pass 2's is tried
-- (via 'thenFwdRw'). Each pass sees only its own component of the fact.
-- @ start pairf.tex
pairFwd :: Monad m => FwdPass m n f -> FwdPass m n f' -> FwdPass m n (f, f')
-- @ end pairf.tex
pairFwd pass1 pass2 = FwdPass lattice transfer rewrite
  where
    lattice = pairLattice (fp_lattice pass1) (fp_lattice pass2)
    transfer = mkFTransfer3 (tf tf1 tf2) (tf tm1 tm2) (tfb tl1 tl2)
      where
        tf t1 t2 n (f1, f2) = (t1 n f1, t2 n f2)
        -- Exit nodes yield one FactBase per pass. Take the union of their
        -- labels so that a successor reported only by pass 2 keeps pass 2's
        -- fact (previously such labels were dropped). A label absent from
        -- one FactBase gets that lattice's bottom. For labels present in
        -- both, the two operands of mapUnion hold identical pairs, so the
        -- union's bias does not matter.
        tfb t1 t2 n (f1, f2) =
          mapUnion (mapMapWithKey withfb2 fb1) (mapMapWithKey withfb1 fb2)
          where
            fb1 = t1 n f1
            fb2 = t2 n f2
            withfb2 l f = (f, fromMaybe bot2 $ lookupFact l fb2)
            withfb1 l f = (fromMaybe bot1 $ lookupFact l fb1, f)
        bot1 = fact_bot (fp_lattice pass1)
        bot2 = fact_bot (fp_lattice pass2)
        (tf1, tm1, tl1) = getFTransfer3 (fp_transfer pass1)
        (tf2, tm2, tl2) = getFTransfer3 (fp_transfer pass2)
    rewrite = liftRW (fp_rewrite pass1) fst `thenFwdRw` liftRW (fp_rewrite pass2) snd
      where
        -- Adapt a rewrite over one component to the paired fact, recursively
        -- adapting any follow-on rewrite it returns.
        liftRW rws proj = mkFRewrite3 (lift rwF) (lift rwM) (lift rwL)
          where
            lift rw n fact = liftM (liftM projRewrite) $ rw n (proj fact)
            projRewrite (FwdRew g rws') = FwdRew g $ liftRW rws' proj
            (rwF, rwM, rwL) = getFRewrite3 rws

-- | Run two backward passes side by side over the product lattice.
-- Pass 1's rewrite is tried first; if it declines, pass 2's is tried
-- (via 'thenBwdRw'). At exit nodes the incoming FactBase of pairs is split
-- into one FactBase per pass with 'mapMap'.
pairBwd :: forall m n f f' . Monad m => BwdPass m n f -> BwdPass m n f' -> BwdPass m n (f, f')
pairBwd pass1 pass2 = BwdPass lattice transfer rewrite
  where
    lattice = pairLattice (bp_lattice pass1) (bp_lattice pass2)
    transfer = mkBTransfer3 (tf tf1 tf2) (tf tm1 tm2) (tfb tl1 tl2)
      where
        tf t1 t2 n (f1, f2) = (t1 n f1, t2 n f2)
        tfb t1 t2 n fb = (t1 n $ mapMap fst fb, t2 n $ mapMap snd fb)
        (tf1, tm1, tl1) = getBTransfer3 (bp_transfer pass1)
        (tf2, tm2, tl2) = getBTransfer3 (bp_transfer pass2)
    rewrite = liftRW (bp_rewrite pass1) fst `thenBwdRw` liftRW (bp_rewrite pass2) snd
      where
        -- Adapt a rewrite over one component to the paired fact. The exit
        -- shape receives a FactBase, so its projection is mapped over it.
        liftRW :: forall f1 . BwdRewrite m n f1 -> ((f, f') -> f1) -> BwdRewrite m n (f, f')
        liftRW rws proj =
          mkBRewrite3 (lift proj rwF) (lift proj rwM) (lift (mapMap proj) rwL)
          where
            lift proj' rw n fact = liftM (liftM projRewrite) $ rw n (proj' fact)
            projRewrite (BwdRew g rws') = BwdRew g $ liftRW rws' proj
            (rwF, rwM, rwL) = getBRewrite3 rws

-- | Product of two lattices. Bottom is the pair of bottoms, and join
-- works component-wise. The join reports 'SomeChange' unless both
-- component joins report 'NoChange'.
pairLattice :: forall f f' . DataflowLattice f -> DataflowLattice f' -> DataflowLattice (f, f')
pairLattice l1 l2 =
  DataflowLattice
    { fact_name = fact_name l1 ++ " x " ++ fact_name l2
    , fact_bot  = (fact_bot l1, fact_bot l2)
    , fact_join = joinPair
    }
  where
    -- Named joinPair to avoid shadowing Control.Monad.join.
    joinPair lbl (OldFact (o1, o2)) (NewFact (n1, n2)) = (c', (f1, f2))
      where
        (c1, f1) = fact_join l1 lbl (OldFact o1) (NewFact n1)
        (c2, f2) = fact_join l2 lbl (OldFact o2) (NewFact n2)
        c' = case (c1, c2) of
               (NoChange, NoChange) -> NoChange
               _                    -> SomeChange