{-
 - Copyright (c) 2009, ERICSSON AB All rights reserved.
 -
 - Redistribution and use in source and binary forms, with or without
 - modification, are permitted provided that the following conditions
 - are met:
 -
 -     * Redistributions of source code must retain the above copyright
 -       notice, this list of conditions and the following disclaimer.
 -     * Redistributions in binary form must reproduce the above copyright
 -       notice, this list of conditions and the following disclaimer
 -       in the documentation and/or other materials provided with the
 -       distribution.
 -     * Neither the name of the ERICSSON AB nor the names of its
 -       contributors may be used to endorse or promote products derived
 -       from this software without specific prior written permission.
 -
 - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 -}

module Feldspar.Compiler.Optimization.Simplification where

import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List hiding (insert,union)
import Data.Maybe
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Optimization.Replace

-- | Simplify every function of a program independently.
doSimplification :: [ImpFunction] -> [ImpFunction]
doSimplification = map doSimplificationOne

-- | Simplify a single function.
--
-- The passes run bottom-to-top in the composition below: compute the
-- variable usage information, propagate initial values, recompute the
-- usage information on the propagated program, delete what is no longer
-- used and finally run the backward simplification.
doSimplificationOne :: ImpFunction -> ImpFunction
doSimplificationOne =
      backward
    . delUnused Set.empty
    . fst . computeSemInfVar
    . fst . propagate Map.empty
    . fst . computeSemInfVar

--------------------------------------------------------------
-- Computing semantic information for variable declarations --
--------------------------------------------------------------

-- | Constructs for which variable usage information can be collected.
class ComputeSemInfVar t where
    -- | Returns the construct (with the 'semInfVar' field of the
    -- declarations it contains updated) together with the usage map of
    -- the variables that are visible outside of the construct.
    computeSemInfVar :: t -> (t, VariableMap)

instance ComputeSemInfVar ImpFunction where
    computeSemInfVar fun = (fun { prg = body' }, usage)
      where
        (body', usage) = computeSemInfVar (prg fun)

instance ComputeSemInfVar CompleteProgram where
    -- The declarations are annotated with the usage found in the body
    -- combined with the usage implied by their own initialisation; the
    -- returned map is the usage of the body without the names declared
    -- here.
    computeSemInfVar (CompPrg locals body) = (CompPrg locals' body', outerUsage)
      where
        (body', bodyUsage) = computeSemInfVar body
        (_, declUsage)     = computeSemInfVar locals
        -- Shared by all declarations instead of being rebuilt per lookup.
        localUsage         = addVarMap bodyUsage declUsage
        locals'            = map updateLocal locals
        outerUsage         = Map.filterWithKey (\k _ -> not (isLocal k)) bodyUsage
        isLocal k          = any (declares k) locals
        updateLocal d@(Decl (Var n _ _) _ _ _) =
            case Map.lookup n localUsage of
                Nothing  -> d
                Just inf -> d { semInfVar = inf }

instance ComputeSemInfVar Program where
    computeSemInfVar Empty = (Empty, Map.empty)
    -- Primitives already carry their usage map.
    computeSemInfVar p@(Primitive _ seminf) = (p, varMap seminf)
    computeSemInfVar (Seq ps sem) = (Seq ps' sem, foldr addVarMap Map.empty usages)
      where
        (ps', usages) = unzip (map computeSemInfVar ps)
    -- The condition variable is recorded as 'SemInfVar None UnknownR' and
    -- combined with the usage of both branches.
    computeSemInfVar (IfThenElse cond@(Var cName _ _) thenPrg elsePrg sem) =
        ( IfThenElse cond thenPrg' elsePrg' sem
        , foldr addVarMap condUsage [thenUsage, elseUsage]
        )
      where
        (thenPrg', thenUsage) = computeSemInfVar thenPrg
        (elsePrg', elseUsage) = computeSemInfVar elsePrg
        condUsage             = Map.singleton cName $ SemInfVar None UnknownR
    -- The usage of the condition program and of the body is 'multiply'-ed.
    -- A declaration of the condition variable inside the condition
    -- program additionally gets 'SemInfVar None MultipleR' added.
    computeSemInfVar (SeqLoop z@(Var _ _ _) condPrg bodyPrg sem) =
        (SeqLoop z condPrg' bodyPrg' sem, multiply (addVarMap condUsage bodyUsage))
      where
        (condPrg', condUsage) = markCondVar (computeSemInfVar condPrg)
        (bodyPrg', bodyUsage) = computeSemInfVar bodyPrg
        markCondVar (CompPrg locs bod, usage) = (CompPrg (map markDecl locs) bod, usage)
        markDecl d
            | var d == z = d { semInfVar = addSemInfVar (SemInfVar None MultipleR) $ semInfVar d }
            | otherwise  = d
    -- The variables read by the loop's test expression count as used,
    -- together with everything used in the body; the result is 'multiply'-ed.
    computeSemInfVar (ParLoop ini test count body seminfo) =
        (ParLoop ini test count body' seminfo, multiply (addVarMap bodyUsage testUsage))
      where
        (body', bodyUsage)     = computeSemInfVar body
        testUsage              = rightVarMap (exprCore test)
        exprCore (Expr core _) = core

-- | Adjust usage information for code that sits inside a loop: apply
-- 'multiplyOne' to every entry.
multiply :: Functor f => f SemInfVar -> f SemInfVar
multiply = fmap multiplyOne

-- | Apply 'multiplyLeft' and 'multiplyRight' to the two usage fields.
multiplyOne :: SemInfVar -> SemInfVar
multiplyOne sem = sem
    { usedLeft  = multiplyLeft $ usedLeft sem
    , usedRight = multiplyRight $ usedRight sem
    }

-- | A single left use becomes 'MultipleL'; every other value is kept.
multiplyLeft (Single _) = MultipleL
multiplyLeft l          = l

-- | @Times 0@ is kept, any other 'Times' becomes 'MultipleR'; every other
-- value is kept.
multiplyRight :: RightUse -> RightUse
multiplyRight (Times 0) = Times 0
multiplyRight (Times _) = MultipleR
multiplyRight r         = r

instance (ComputeSemInfVar a) => ComputeSemInfVar [a] where
    -- Elements are processed independently and their maps are combined.
    computeSemInfVar xs = (xs', foldr addVarMap Map.empty usages)
      where
        (xs', usages) = unzip (map computeSemInfVar xs)

instance ComputeSemInfVar Declaration where
    -- A declaration with an initial value contributes one left use
    -- (carrying that value) and no right use.
    computeSemInfVar d@(Decl (Var n _ _) _ (Just ini) _) =
        (d, Map.singleton n $ SemInfVar (Single $ Just ini) (Times 0))
    -- A declaration without initial value contributes nothing.
    computeSemInfVar d@(Decl (Var _ _ _) _ Nothing _) = (d, Map.empty)

--------------------
-- Simplification --
--------------------

-- | For each variable name the expression that may replace it, or
-- 'Nothing' when the variable must not be replaced.
type PropagateMap = Map.Map String (Maybe ImpLangExpr)

-- | Names of the variables whose declaration has been removed.
type DelSet = Set.Set String

-- | The individual steps of the simplification, defined for every
-- syntactic category of the imperative representation.
class Simplification a where
    -- | Replace variables according to the map; returns the rewritten
    -- construct and the map to use for the code that follows it.
    propagate :: PropagateMap -> a -> (a,PropagateMap)
    -- | Remove code that only writes variables contained in the set.
    delUnused :: DelSet -> a -> a
    -- | Backward simplification (see 'doBackward').
    backward  :: a -> a
    -- | Does the construct write the variable with the given name?
    writesVar :: a -> String -> Bool
    -- | Does the construct read the variable with the given name?
    readsVar  :: a -> String -> Bool

instance Simplification ImpFunction where
    -- The resulting map is dropped at function level.
    propagate m (Fun n ips ops cprg) = (Fun n ips ops (fst (propagate m cprg)), Map.empty)
    delUnused s (Fun n ips ops cprg) = Fun n ips ops (delUnused s cprg)
    backward fun = fun { prg = backward (prg fun) }
    writesVar _ _ = False -- Should not be used.
    readsVar _ _ = False -- Should not be used.
instance Simplification CompleteProgram where
    -- The propagatable local declarations are added to the incoming map
    -- for the body (on a name clash the incoming entry wins) and are
    -- removed again from the map that is handed to the following code.
    propagate m (CompPrg dl b) = (CompPrg dl b', purgePropagateMap m' dl)
      where
        (b', m') = propagate (Map.union m $ makePropagateMap dl) b
    -- Declarations that are never read are dropped; their names extend
    -- the set used to prune the body.
    delUnused s (CompPrg dl b) = CompPrg kept (delUnused (Set.union s removed) b)
      where
        (kept, removed) = makeUnusedSet dl
    backward (CompPrg dl b) = doBackward dl $ toPrgList $ backward b
    writesVar (CompPrg _ b) x = writesVar b x
    readsVar (CompPrg _ b) x = readsVar b x

instance Simplification Program where
    propagate m Empty = (Empty, m)
    -- The instruction and its semantic information are both rewritten
    -- with the incoming map; the outgoing map comes from the latter.
    propagate m (Primitive instr seminf) = (Primitive instr' seminf', m')
      where
        (instr', _)   = propagate m instr
        (seminf', m') = propagate m seminf
    propagate m (Seq ps seminf) = (Seq ps' seminf, m')
      where
        (ps', m') = propagate m ps
    -- After a branch only the entries both sub-programs agree on survive.
    propagate m (IfThenElse v cp1 cp2 seminf) =
        (IfThenElse v cp1' cp2' seminf, Map.intersectionWith combineExpr m1 m2)
      where
        (cp1', m1) = propagate m cp1
        (cp2', m2) = propagate m cp2
    propagate m (SeqLoop v cp1 cp2 seminf) =
        (SeqLoop v cp1' cp2' seminf, Map.intersectionWith combineExpr m1 m2)
      where
        (cp1', m1) = propagate m cp1
        (cp2', m2) = propagate m cp2
    -- NOTE(review): the expressions in the loop header (i1, i2) are not
    -- rewritten here, only the body is -- confirm this is intended.
    propagate m (ParLoop v i1 i2 cp seminf) = (ParLoop v i1 i2 cp' seminf, m')
      where
        (cp', m') = propagate m cp

    delUnused _ Empty = Empty
    -- A primitive is removed when every variable it writes is in the set.
    -- NOTE(review): 'all' is True for an empty list, so a primitive for
    -- which 'leftVars' is empty is removed as well -- confirm intended.
    delUnused s p@(Primitive _ seminf)
        | all (\v -> Set.member v s) $ leftVars $ varMap seminf = Empty
        | otherwise = p
    delUnused s (Seq ps seminf) = Seq (delUnused s ps) seminf
    delUnused s (IfThenElse v cp1 cp2 seminf) = IfThenElse v (delUnused s cp1) (delUnused s cp2) seminf
    delUnused s (SeqLoop v cp1 cp2 seminf) = SeqLoop v (delUnused s cp1) (delUnused s cp2) seminf
    delUnused s (ParLoop v i1 i2 cp seminf) = ParLoop v i1 i2 (delUnused s cp) seminf

    backward (Seq ps seminf) = Seq (backward ps) seminf
    backward (IfThenElse v cp1 cp2 seminf) = IfThenElse v (backward cp1) (backward cp2) seminf
    backward (ParLoop v i1 i2 cp seminf) = ParLoop v i1 i2 (backward cp) seminf
    backward (SeqLoop v cp1 cp2 seminf) = SeqLoop v (backward cp1) (backward cp2) seminf
    backward x = x

    writesVar Empty _ = False
    writesVar (Primitive i _) x = writesVar i x
    writesVar (Seq ps _) x = writesVar ps x
    writesVar (IfThenElse _ cp1 cp2 _) x = writesVar cp1 x || writesVar cp2 x
    writesVar (SeqLoop _ cp1 cp2 _) x = writesVar cp1 x || writesVar cp2 x
    writesVar (ParLoop _ _ _ cp _) x = writesVar cp x

    readsVar Empty _ = False
    readsVar (Primitive i _) x = readsVar i x
    readsVar (Seq ps _) x = readsVar ps x
    -- The condition variable of a branch counts as read.
    readsVar (IfThenElse v cp1 cp2 _) x = (name v == x) || readsVar cp1 x || readsVar cp2 x
    -- NOTE(review): unlike the branch above, the loop headers (condition
    -- variable of SeqLoop, header fields of ParLoop) are not inspected.
    readsVar (SeqLoop _ cp1 cp2 _) x = readsVar cp1 x || readsVar cp2 x
    readsVar (ParLoop _ _ _ cp _) x = readsVar cp x

instance (Simplification a) => Simplification [a] where
    -- The map produced by one element is the input of the next one.
    propagate m [] = ([], m)
    propagate m (x:xs) = (x' : xs', m'')
      where
        (x', m')   = propagate m x
        (xs', m'') = propagate m' xs
    delUnused s = map (delUnused s)
    backward = map backward
    writesVar xs x = any (\e -> writesVar e x) xs
    readsVar xs x = any (\e -> readsVar e x) xs

instance Simplification Instruction where
    -- Instructions never change the map themselves; that is done by the
    -- 'SemInfPrim' attached to the enclosing primitive.
    propagate m (Assign left right) = (Assign (fst $ propagate m left) (fst $ propagate m right), m)
    propagate m (CFun fname ps) = (CFun fname $ map (fst . propagate m) ps, m)
    delUnused _ = id
    backward = id
    writesVar (Assign left _) x = writesVar left x
    writesVar (CFun _ ps) x = any (\p -> writesVar p x) ps
    -- The assigned variable itself is not read, see 'readsVarHelp'.
    readsVar (Assign left right) x = readsVarHelp left x || readsVar right x
    readsVar (CFun _ ps) x = any (\p -> readsVar p x) ps

instance Simplification SemInfPrim where
    -- Rewrites the usage map of a primitive as if the replacements in @m@
    -- had been applied, and updates @m@ according to what the primitive
    -- writes.
    propagate m seminf = (seminf { varMap = seminf' }, updated)
      where
        updated = Map.map invalidate $ Map.mapWithKey reassign m

        -- A variable written exactly once by this primitive is mapped to
        -- the value recorded for that write, a variable written in any
        -- other way is no longer replaceable.
        reassign key expr = case Map.lookup key seminf' of
            Nothing  -> expr
            Just sem -> case usedLeft sem of
                None     -> expr
                Single e -> e
                _        -> Nothing

        -- A replacement expression for which 'contains' holds for one of
        -- the variables written by this primitive is dropped.
        invalidate Nothing = Nothing
        invalidate expr@(Just e)
            | any (\v -> contains v e) $ leftVars seminf' = Nothing
            | otherwise = expr

        -- 'Map.foldrWithKey' is the current name of the former
        -- 'Map.foldWithKey'.
        seminf' = Map.foldrWithKey prop Map.empty $ varMap seminf

        prop :: String -> SemInfVar -> Map.Map String SemInfVar -> Map.Map String SemInfVar
        prop key sem other =
            addVarMap other
                $ addVarMap (Map.singleton key $ SemInfVar (propLeft $ usedLeft sem) (Times 0))
                $ propRight key $ usedRight sem

        propLeft (Single (Just expr)) = Single $ Just $ fst $ propagate m expr
        propLeft x = x

        -- When the variable is replaced by an expression, its right uses
        -- are transferred to the variables of that expression.
        propRight :: String -> RightUse -> Map.Map String SemInfVar
        propRight key right = case Map.lookup key m of
            Just (Just e) -> Map.map (mult right) $ rightVarMap e
            _             -> Map.singleton key (SemInfVar None right)

        mult UnknownR sem = sem { usedRight = UnknownR }
        mult (Times n) sem = case usedRight sem of
            Times n' -> sem { usedRight = Times $ n*n' }
            _        -> sem
        mult MultipleR sem = case usedRight sem of
            UnknownR -> sem
            _        -> sem { usedRight = MultipleR }
    delUnused _ = id
    backward = id
    writesVar _ _ = False
    readsVar _ _ = False

instance Simplification ImpLangExpr where
    -- A plain variable is replaced by its (recursively propagated)
    -- expression if the map holds one for it.
    propagate m i@(Expr (LeftExpr (LVar (Var n _ _))) _) =
        case Map.lookup n m of
            Just (Just expr) -> (fst $ propagate m expr, m)
            _                -> (i, m)
    propagate m (Expr (LeftExpr x) t) = (Expr (LeftExpr (fst $ propagate m x)) t, m)
    propagate m (Expr (FunCall r s is) t) = (Expr (FunCall r s (map (fst . propagate m) is)) t, m)
    propagate m x = (x, m)
    delUnused _ = id
    backward = id
    writesVar (Expr (LeftExpr lv) _) x = writesVar lv x
    writesVar _ _ = False
    readsVar (Expr (LeftExpr lv) _) x = readsVar lv x
    readsVar (Expr (AddressOf lv) _) x = readsVar lv x
    readsVar (Expr (ConstExpr _) _) _ = False
    readsVar (Expr (FunCall _ _ es) _) x = any (\e -> readsVar e x) es

instance Simplification LeftValue where
    -- Same replacement as for expressions; the replacing expression is
    -- converted with 'getLeftValue'.
    propagate m l@(LVar (Var n _ _)) =
        case Map.lookup n m of
            Just (Just expr) -> (getLeftValue $ fst $ propagate m expr, m)
            _                -> (l, m)
    propagate m (ArrayElem lv ile) = (ArrayElem (fst $ propagate m lv) (fst $ propagate m ile), m)
    propagate m (PointedVal lv) = (PointedVal (fst $ propagate m lv), m)
    delUnused _ = id
    backward = id
    writesVar (LVar v) x = name v == x
    writesVar (ArrayElem lv _) x = writesVar lv x
    writesVar (PointedVal lv) x = writesVar lv x
    readsVar (LVar v) x = name v == x
    readsVar (ArrayElem lv ix) x = readsVar lv x || readsVar ix x
    readsVar (PointedVal lv) x = readsVar lv x

instance Simplification Parameter where
    propagate m (In ile) = (In (fst $ propagate m ile), m)
    propagate m (Out (k, ile)) = (Out (k, fst $ propagate m ile), m)
    delUnused _ = id
    backward = id
    writesVar (In _) _ = False
    writesVar (Out (_, e)) x = writesVar e x
    readsVar (In e) x = readsVar e x
    readsVar (Out (_, e)) x = readsVarHelp (getLeftValue e) x

-- | The replacements introduced by a list of declarations.
--
-- A declaration takes part if its recorded left use is a single one with
-- a known value and it is either read exactly once or that value is
-- simple (a plain variable, or a constant of a 'simpleType'). The
-- variable is then mapped to its 'initVal'. On duplicate names the
-- earlier declaration wins.
makePropagateMap :: [Declaration] -> PropagateMap
makePropagateMap dl = Map.unions $ map entry dl
  where
    entry d = case usedLeft $ semInfVar d of
        Single (Just e)
            | usedRight (semInfVar d) == Times 1 || simpleExpr e
                -> Map.singleton (name $ var d) $ initVal d
            | otherwise -> Map.empty
        _ -> Map.empty
    simpleExpr (Expr (LeftExpr (LVar _)) _) = True
    simpleExpr (Expr (ConstExpr _) t)       = simpleType t
    simpleExpr _                            = False

-- | Remove from the map every key that 'makePropagateMap' yields for the
-- given declarations.
purgePropagateMap :: PropagateMap -> [Declaration] -> PropagateMap
purgePropagateMap m dl = Map.difference m (makePropagateMap dl)

-- | Combine the entries of two control flow paths: the entry is kept
-- only if both paths agree on it.
combineExpr :: Maybe ImpLangExpr -> Maybe ImpLangExpr -> Maybe ImpLangExpr
combineExpr e1 e2
    | e1 == e2  = e1
    | otherwise = Nothing

-- | Split off the declarations whose right use is @Times 0@. Returns the
-- remaining declarations (in their original order) and the names of the
-- removed ones.
makeUnusedSet :: [Declaration] -> ([Declaration],DelSet)
makeUnusedSet ds = (kept, Set.fromList $ map (name . var) unused)
  where
    (unused, kept) = partition neverRead ds
    neverRead d = case usedRight $ semInfVar d of
        Times 0 -> True
        _       -> False

-- | Variant of 'readsVar' for a left value that is being written: only
-- the index expressions of array elements count as read.
readsVarHelp :: LeftValue -> String -> Bool
readsVarHelp (LVar _) _ = False
readsVarHelp (ArrayElem lv ix) x = readsVarHelp lv x || readsVar ix x
readsVarHelp (PointedVal _) _ = False

-----------------------------
-- Backward simplification --
-----------------------------

-- | Apply 'backwardRec' to the statement list until it reports that
-- nothing was changed, then rebuild the complete program.
doBackward :: [Declaration] -> [Program] -> CompleteProgram
doBackward ds ps
    | changed   = doBackward ds' ps'
    | otherwise = CompPrg ds' (Seq ps' [])
  where
    (changed, ds', ps') = backwardRec ds ([], ps)

-- | Walk through the statement list and perform the first possible
-- replacement.
--
-- The pair holds the statements already passed (in reverse order) and
-- the statements still to be visited. The 'Bool' of the result tells
-- whether a replacement has been performed.
backwardRec :: [Declaration] -> ([Program],[Program]) -> (Bool, [Declaration], [Program])
backwardRec ds (xs, []) = (False, ds, reverse xs)
backwardRec ds (xs, y:ys) = case backwardPossible ds xs y ys of
    Nothing -> backwardRec ds (y:xs, ys)
    -- The copy statement @y@ itself is dropped from the result.
    Just (target, source, initPrg) -> (True, ds', initPrg : ps')
      where
        (ds', ps') = backwardRepl target source ds xs ys

-- | Check whether statement @y@ copies a local variable into a left value
-- in a way that allows the variable to be replaced by that left value.
--
-- @xs@ are the statements before @y@ in reverse order, @ys@ the ones
-- after it. Two shapes of @y@ are recognised, both only when the flag of
-- the attached 'SemInfPrim' is 'True': an assignment whose right hand
-- side is a plain variable, and a call of a function whose name starts
-- with @"copy"@ taking one input variable and one output left value.
--
-- The result holds the left value, the name of the copied variable and
-- the statement that takes over the initialisation of the variable.
backwardPossible :: [Declaration] -> [Program] -> Program -> [Program] -> Maybe (LeftValue,String,Program)
backwardPossible ds xs y ys = case y of
    Primitive (Assign left (Expr (LeftExpr (LVar (Var src _ _))) _)) (SemInfPrim _ True) ->
        check left src
    Primitive (CFun fname [ In (Expr (LeftExpr (LVar (Var src _ _))) _)
                          , Out (_, Expr (LeftExpr left) _) ]) (SemInfPrim _ True)
        | isPrefixOf "copy" fname -> check left src -- TODO: eliminate string constant
        | otherwise -> Nothing
    _ -> Nothing
  where
    check left src = case declarationOK of
        Just initPrg | beforeOK && afterOK -> Just (left, src, initPrg)
        _ -> Nothing
      where
        -- The variable has to be declared here. An initial value of a
        -- simple type is turned into an assignment to the left value,
        -- any other initial value prevents the replacement.
        declarationOK = case find (declares src) ds of
            Nothing -> Nothing
            Just d -> case initVal d of
                Nothing -> Just Empty
                Just expr
                    | simpleType (exprType expr) ->
                        Just $ Primitive (Assign left expr) $ SemInfPrim Map.empty False
                    | otherwise -> Nothing

        -- If the output is used before the copy, the replacement is only
        -- allowed when no initialising assignment has to be inserted.
        beforeOK = case useBefore of
            (False, _)    -> False
            (True, False) -> True
            (True, True)  -> case declarationOK of
                Just Empty -> True
                _          -> False

        -- The variable must not be used after the copy at all.
        afterOK = not $ any (\p -> readsVar p src || writesVar p src) ys

        -- Folded from the statement right before the copy towards the
        -- beginning, since @xs@ is in reverse order. The first component
        -- tells whether all statements seen so far are compatible, the
        -- second whether the output has been used in any of them.
        useBefore = foldl' step (True, False) xs

        step (ok, out) stmt = (ok', out')
          where
            out' = out || outRead || outWritten
            ok' | not ok                                = False
                | out && (varWritten || varRead)        = False
                | outWritten && (varWritten || varRead) = False
                | outRead && varRead                    = False
                | otherwise                             = True
            outWritten = stmt `writesVar` outName
            outRead    = stmt `readsVar` outName
            outName    = getVarName left
            varWritten = stmt `writesVar` src
            varRead    = stmt `readsVar` src

-- | Perform the replacement: remove the declaration of the variable and
-- apply 'replaceLExpr' with the (variable, left value) pair to the
-- statements before (given in reverse order) and after the copy.
backwardRepl :: LeftValue -> String -> [Declaration] -> [Program] -> [Program] -> ([Declaration], [Program])
backwardRepl lv v ds xs ys =
    ( filter (not . declares v) ds
    , replaceLExpr (reverse xs ++ ys) (v, lv)
    )

-- | The statements of a sequence; any other program becomes a one
-- element list.
toPrgList :: Program -> [Program]
toPrgList (Seq ps _) = ps
toPrgList p = [p]

-- | Does the declaration declare the variable with the given name?
declares :: String -> Declaration -> Bool
declares n d = n == name (var d)