{-# LANGUAGE GADTs, KindSignatures, TypeFamilies, MultiParamTypeClasses , ScopedTypeVariables, PatternGuards #-} {-# OPTIONS_GHC -Wall -fno-warn-unused-imports -fno-warn-orphans -fno-warn-missing-signatures #-} ---------------------------------------------------------------------- -- | -- Module : Shady.Language.Cse -- Copyright : (c) Conal Elliott 2009 -- License : AGPLv3 -- -- Maintainer : conal@conal.net -- Stability : experimental -- -- Common subexpression elimination. -- -- TODO: Improve variable names (now \"x8\" etc). ---------------------------------------------------------------------- module Shady.Language.Cse (cse) where import Control.Applicative (pure,(<$>),(<*>)) import Data.Maybe (fromMaybe) import qualified Data.IntMap as I import System.IO.Unsafe (unsafePerformIO) import Shady.Misc import Shady.Language.Type import Shady.Language.Operator import Shady.Language.Exp import Shady.Language.Graph import Shady.Language.Reify -- V from Tid ev :: Tid a -> V a ev (Tid i t) = V ('x':show i) t children :: N a -> [NodeId] children (VN _) = [] children (ON _) = [] children (App (Tid a _) (Tid b _)) = [a,b] childrenB :: Bind -> [NodeId] childrenB (Bind _ n) = children n -- Number of references for each node. Important: partially apply, so -- that the binding list can be converted just once into an efficiently -- searchable representation. uses :: [Bind] -> (NodeId -> Int) uses = fmap (fromMaybe 0) . flip I.lookup . histogram . concatMap childrenB -- histogram :: Ord k => [k] -> I.Map k Int -- histogram = foldr (\ k -> I.insertWith (+) k 1) I.empty histogram :: [Int] -> I.IntMap Int histogram = foldr (\ k -> I.insertWith (+) k 1) I.empty -- Fast version, using an IntMap. Important: partially apply. bindsF :: forall a. [Bind] -> (Tid a -> N a) bindsF binds = \ (Tid i' a') -> extract a' (I.lookup i' m) where m :: I.IntMap Bind m = I.fromList [(i,b) | b@(Bind i _) <- binds] extract :: Type a' -> Maybe Bind -> N a' extract _ Nothing = error "bindsF: variable not found" extract a' (Just (Bind _ n)) | Just Refl <- typeOf1 n `tyEq` a' = n | otherwise = error $ "bindsF: wrong type. " ++ show (typeOf1 n) ++ " vs " ++ show a' tid :: HasType a => NodeId -> Tid a tid i = Tid i typeT letI :: (HasType a, HasType b) => NodeId -> E a -> E b -> E b letI i = letE (ev (tid i)) unGraph :: HasType a => Graph a -> E a unGraph (Graph binds root) = foldr llet (var' root) (reverse binds) where -- Wrap a let if non-trivial llet :: HasType b => Bind -> E b -> E b llet bind | trivial bind = id llet (Bind i n) = letI i (nodeE' n) -- How many uses of variable count :: NodeId -> Int count = uses binds -- Bindings as IntMap lookup psf :: Tid a -> N a psf = bindsF binds -- Too trivial to bother abstracting. trivial :: Bind -> Bool trivial (Bind _ (VN _)) = True trivial (Bind _ (ON (Lit a))) = not (abstractable a) trivial (Bind _ (ON _)) = True trivial (Bind i _) | count i < 2 = True trivial _ = False -- Like nodeE but with inlining of trivial bindings nodeE' :: N a -> E a nodeE' (VN v) = Var v nodeE' (ON o) = Op o nodeE' (App a b) = var' a :^ var' b -- Variable reference or inline var' :: HasType a => Tid a -> E a var' t@(Tid i _) | trivial (Bind i n) = nodeE' n | otherwise = Var (ev t) where n = psf t -- Possible and worthwhile to abstract. abstractable :: forall a. HasType a => a -> Bool abstractable a = case (typeOf a :: Type a) of VecT (VectorT n _) -> natToZ n > 1 _ -> False -- | Common subexpression elimination. Use with care, since it breaks -- referential transparency on the /representation/ of expressions, but -- not on their meaning. cse :: HasType a => E a -> E a cse = unsafePerformIO . fmap unGraph . reifyGraph {- -- Remove the comment braces to use the testing code {-------------------------------------------------------------------- Testing --------------------------------------------------------------------} -- Simpler version of unGraph. No inlining. unGraph' :: HasType a => Graph a -> E a unGraph' (Graph binds root) = foldr f (Var (ev root)) (reverse binds) where f :: Bind -> (forall b. HasType b => E b -> E b) f (Bind i n) = letE (ev (Tid i (typeOf1 n))) (nodeE n) nodeE (VN v) = Var v nodeE (ON o) = Op o nodeE (App u v) = Var (ev u) :^ Var (ev v) -- Convert expressions to simple SSA forms ssa :: HasType a => E a -> IO (E a) ssa = fmap unGraph' . reifyGraph -- type-specialize reify :: HasType a => E a -> IO (Graph a) reify = reifyGraph type I1 = One Int va, vb :: E I1 va = Var (var "a") vb = Var (var "b") -- test expressions e1 = va + vb :: E I1 e2 = e1 * e1 e3 = va + va :: E I1 -- For instance, -- > e2 -- (a + b) * (a + b) -- -- > reify e2 -- let [0 = App x1 x3,1 = App x2 x3,3 = App x4 x7,7 = VN b,4 = App x5 x6,6 = VN a,5 = ON (+),2 = ON (*)] in x0 -- -- > ssa e2 -- let x2 = (*) in -- let x5 = (+) in -- let x6 = a in -- let x4 = x5 x6 in -- let x7 = b in -- let x3 = x4 x7 in -- let x1 = x2 x3 in -- let x0 = x1 x3 in -- x0 -- -- > cse e2 -- let x3 = a + b in -- x3 * x3 -}