{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
-- | This module implements common-subexpression elimination.  This
-- module does not actually remove the duplicate, but only replaces
-- one with a reference to the other.  E.g:
--
-- @
--   let a = x + y
--   let b = x + y
-- @
--
-- becomes:
--
-- @
--   let a = x + y
--   let b = a
-- @
--
-- After which copy propagation in the simplifier will actually remove
-- the definition of @b@.
--
-- Our CSE is still rather stupid.  No normalisation is performed, so
-- the expressions @x+y@ and @y+x@ will be considered distinct.
-- Furthermore, no expression with its own binding will be considered
-- equal to any other, since the variable names will be distinct.
-- This affects SOACs in particular.
module Futhark.Optimise.CSE
       ( performCSE
       , CSEInOp
       )
       where

import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Semigroup ((<>))

import Futhark.Analysis.Alias
import Futhark.Representation.AST
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
  (removeFunDefAliases, Aliases, consumedInStms)
import qualified Futhark.Representation.Kernels.Kernel as Kernel
import qualified Futhark.Representation.Kernels.KernelExp as KernelExp
import qualified Futhark.Representation.SOACS.SOAC as SOAC
import qualified Futhark.Representation.ExplicitMemory as ExplicitMemory
import Futhark.Transform.Substitute
import Futhark.Pass

-- | Perform CSE on every function in a program.  Each function is
-- first annotated by 'analyseFun', then optimised, and finally has
-- its alias annotations stripped again by 'removeFunDefAliases'.
--
-- When the 'Bool' argument is 'False', statements that bind
-- array-typed values are never merged.
performCSE :: (Attributes lore, CanBeAliased (Op lore),
               CSEInOp (OpWithAliases (Op lore))) =>
              Bool -> Pass lore lore
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $
    intraproceduralTransformation $ \fundef ->
      return $
        removeFunDefAliases $
          cseInFunDef cse_arrays $
            analyseFun fundef

-- | Run CSE over the body of a single function, starting from an
-- empty substitution state.
cseInFunDef :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
               Bool -> FunDef lore -> FunDef lore
cseInFunDef cse_arrays fundec = fundec { funDefBody = optimised }
  where optimised =
          runReader (cseInBody (funDefBody fundec)) (newCSEState cse_arrays)

-- | The monad in which CSE runs: a reader over the substitutions
-- collected so far.
type CSEM lore = Reader (CSEState lore)

-- | CSE the statements of a body, then apply the name substitutions
-- that are in effect at the end of the body to its result.
cseInBody :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
             Body lore -> CSEM lore (Body lore)
cseInBody (Body bodyattr stms res) =
  cseInStms (consumedInStms stms res) (stmsToList stms) $
    -- The statement list starts out empty here; 'cseInStms' prepends
    -- the processed statements on the way back out.
    asks $ \(CSEState (_, nsubsts) _) ->
      Body bodyattr mempty (substituteNames nsubsts res)

-- | CSE inside the body of a lambda.
cseInLambda :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
               Lambda lore -> CSEM lore (Lambda lore)
cseInLambda lam =
  (\body -> lam { lambdaBody = body }) <$> cseInBody (lambdaBody lam)

-- | Process a list of statements in order.  The final argument
-- produces the tail of the body (in practice just the result), and is
-- run in the environment produced by all of the statements.  The
-- first argument is the set of names that must never be subject to
-- CSE.
cseInStms :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
             Names -> [Stm lore]
          -> CSEM lore (Body lore)
          -> CSEM lore (Body lore)
cseInStms consumed stms m = foldr step m stms
  where -- Handle one statement, continue with the remaining ones, and
        -- put the (possibly rewritten) statement in front of them.
        step stm rest =
          cseInStm consumed stm $ \replacements -> do
            Body bodyattr later res <- rest
            replacements' <- mapM descend replacements
            return $ Body bodyattr (stmsFromList replacements' <> later) res

        -- Recurse into bodies and operations nested in the expression.
        descend stm =
          (\e -> stm { stmExp = e }) <$> mapExpM nested (stmExp stm)

        nested = identityMapper { mapOnBody = const cseInBody
                                , mapOnOp = cseInOp
                                }

-- | Process a single statement and hand its replacement to the
-- continuation, which is run in a suitably extended environment:
--
-- * If the statement is not eligible, it is passed on with only the
--   name substitutions applied.
--
-- * If an identical expression (with identical attribute) has been
--   seen before, the statement is replaced by one copy statement per
--   bound name, and the names are recorded as substitutable.
--
-- * Otherwise the expression is recorded for later statements.
cseInStm :: Attributes lore =>
            Names -> Stm lore
         -> ([Stm lore] -> CSEM lore a)
         -> CSEM lore a
cseInStm consumed (Let pat (StmAux cs eattr) e) m = do
  CSEState (esubsts, nsubsts) cse_arrays <- ask
  let e' = substituteNames nsubsts e
      pat' = substituteNames nsubsts pat
      renamed = Let pat' (StmAux cs eattr) e'
  if any (ineligible cse_arrays) (patternValueElements pat)
    then m [renamed]
    else case M.lookup (eattr, e') esubsts of
           Just subpat ->
             local (addNameSubst pat' subpat) $
               m $ zipWith copyStm (patternNames pat') (patternElements subpat)
           Nothing ->
             local (addExpSubst pat' eattr e') $ m [renamed]
  where -- A pattern element blocks CSE if it is a memory block, if it
        -- is an array and arrays are not to be merged, or if its name
        -- is in the consumed set.
        ineligible allow_arrays pe =
          case patElemType pe of
            Mem{} -> True
            Array{} | not allow_arrays -> True
            _ -> patElemName pe `S.member` consumed

        -- Bind @name@ to the variable of the earlier pattern element,
        -- reusing that element for everything but the name.
        copyStm name earlier =
          Let (Pattern [] [earlier { patElemName = name }]) (StmAux cs eattr)
              (BasicOp (SubExp (Var (patElemName earlier))))

-- | Expressions (paired with their attribute) seen so far, mapped to
-- the pattern they were bound to.
type ExpressionSubstitutions lore = M.Map
                                    (ExpAttr lore, Exp lore)
                                    (Pattern lore)

-- | Names that are to be replaced by other names.
type NameSubstitutions = M.Map VName VName

-- | The environment of the CSE pass: the substitutions in effect, and
-- whether array-typed bindings may be merged.
data CSEState lore = CSEState
                     { _cseSubstitutions :: (ExpressionSubstitutions lore, NameSubstitutions)
                     , _cseArrays :: Bool
                     }

-- | A state with no substitutions and the given array setting.
newCSEState :: Bool -> CSEState lore
newCSEState cse_arrays = CSEState (M.empty, M.empty) cse_arrays

-- | Map the names of the first pattern to the names of the second,
-- pairing them up positionally.
mkSubsts :: PatternT attr -> PatternT attr -> M.Map VName VName
mkSubsts pat vs = M.fromList (patternNames pat `zip` patternNames vs)

-- | Record that the names of the first pattern are to be replaced by
-- the names of the second.  New entries take precedence over existing
-- ones.
addNameSubst :: PatternT attr -> PatternT attr -> CSEState lore -> CSEState lore
addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts, M.union (mkSubsts pat subpat) nsubsts) cse_arrays

-- | Record that the given expression, with the given attribute, is
-- available under the given pattern.
addExpSubst :: Attributes lore =>
               Pattern lore -> ExpAttr lore -> Exp lore
            -> CSEState lore
            -> CSEState lore
addExpSubst pat eattr e (CSEState (esubsts, nsubsts) cse_arrays) =
  CSEState (esubsts', nsubsts) cse_arrays
  where esubsts' = M.insert (eattr, e) pat esubsts

-- | The operations that permit CSE.
class CSEInOp op where
  -- | Perform CSE within any nested expressions.
  cseInOp :: op -> CSEM lore op

instance CSEInOp () where
  cseInOp () = pure ()

-- | Run a CSE computation in a fresh state that inherits nothing but
-- the array setting from the current one.
subCSE :: CSEM lore r -> CSEM otherlore r
subCSE m =
  asks $ \(CSEState _ cse_arrays) -> runReader m (newCSEState cse_arrays)

instance (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
         CSEInOp (Kernel.Kernel lore) where
  cseInOp kernel =
    subCSE $
      Kernel.mapKernelM
        (Kernel.KernelMapper return cseInLambda cseInBody
                             return return cseInKernelBody)
        kernel

-- | CSE the statements of a kernel body.  The kernel result is left
-- untouched; the statements are processed as a body with an empty
-- result.
cseInKernelBody :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
                   Kernel.KernelBody lore -> CSEM lore (Kernel.KernelBody lore)
cseInKernelBody (Kernel.KernelBody bodyattr stms res) =
  rebuild <$> cseInBody (Body bodyattr stms [])
  where rebuild (Body _ stms' _) = Kernel.KernelBody bodyattr stms' res

instance (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
         CSEInOp (KernelExp.KernelExp lore) where
  cseInOp kexp =
    case kexp of
      KernelExp.Combine cspace ts active body ->
        subCSE $ do
          body' <- cseInBody body
          return $ KernelExp.Combine cspace ts active body'
      KernelExp.GroupReduce w lam input ->
        subCSE $ do
          lam' <- cseInLambda lam
          return $ KernelExp.GroupReduce w lam' input
      KernelExp.GroupStream w max_chunk lam nes arrs ->
        subCSE $ do
          lam' <- cseInGroupStreamLambda lam
          return $ KernelExp.GroupStream w max_chunk lam' nes arrs
      -- Every other kernel expression is returned unchanged.
      _ ->
        return kexp

-- | CSE inside the body of a group stream lambda.
cseInGroupStreamLambda :: (Attributes lore, Aliased lore, CSEInOp (Op lore)) =>
                          KernelExp.GroupStreamLambda lore
                       -> CSEM lore (KernelExp.GroupStreamLambda lore)
cseInGroupStreamLambda lam =
  (\body -> lam { KernelExp.groupStreamLambdaBody = body }) <$>
  cseInBody (KernelExp.groupStreamLambdaBody lam)

instance CSEInOp op => CSEInOp (ExplicitMemory.MemOp op) where
  cseInOp memop =
    case memop of
      ExplicitMemory.Alloc{} -> return memop
      ExplicitMemory.Inner k -> fmap ExplicitMemory.Inner (subCSE (cseInOp k))

instance (Attributes lore,
          CanBeAliased (Op lore),
          CSEInOp (OpWithAliases (Op lore))) =>
         CSEInOp (SOAC.SOAC (Aliases lore)) where
  cseInOp soac =
    subCSE $ SOAC.mapSOACM (SOAC.SOACMapper return cseInLambda return) soac