module Feldspar.Compiler.Optimization.Simplification where
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List hiding (insert,union)
import Data.Maybe
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Optimization.Replace
-- | Run the simplification pipeline on every function independently.
doSimplification :: [ImpFunction] -> [ImpFunction]
doSimplification funs = [doSimplificationOne f | f <- funs]
-- | Simplify a single function.
--
-- The stages are, in order: collect variable usage information, propagate
-- initial values, collect usage information again on the propagated
-- program, delete unused declarations and their assignments, and finally
-- run the 'backward' step.
doSimplificationOne :: ImpFunction -> ImpFunction
doSimplificationOne fun = backward cleaned
  where
    (annotated, _)   = computeSemInfVar fun
    (propagated, _)  = propagate Map.empty annotated
    (reannotated, _) = computeSemInfVar propagated
    cleaned          = delUnused Set.empty reannotated
-- | Program fragments for which per-variable usage information
-- ('VariableMap') can be collected.
class ComputeSemInfVar t where
-- | Returns the fragment (instances may update the usage annotations stored
-- inside it) together with the usage map computed for that fragment.
computeSemInfVar :: t -> (t, VariableMap)
-- | A function is annotated by annotating its body ('prg'); the usage map of
-- the body is returned unchanged.
instance ComputeSemInfVar ImpFunction where
  computeSemInfVar fun = (fun { prg = annotatedBody }, usage)
    where
      (annotatedBody, usage) = computeSemInfVar (prg fun)
-- | Annotates the body, stores the collected usage information in the local
-- declarations, and returns only the usage of variables that are not
-- declared locally.
instance ComputeSemInfVar CompleteProgram where
  computeSemInfVar (CompPrg locals body) = (CompPrg annotatedLocals body', escaping)
    where
      (body', bodyUsage) = computeSemInfVar body
      (_, declUsage)     = computeSemInfVar locals
      -- Usage seen in the body merged with the usage implied by initialisers.
      combined           = addVarMap bodyUsage declUsage
      annotatedLocals    = map annotate locals
      -- Only body usage of non-local variables is reported to the caller.
      escaping           = Map.filterWithKey (\k _ -> not (declaredHere k)) bodyUsage
      declaredHere k     = any (\(Decl (Var n _ _) _ _ _) -> n == k) locals
      -- A declaration without an entry in the combined map is left untouched.
      annotate d@(Decl (Var n _ _) _ _ _) =
        maybe d (\inf -> d { semInfVar = inf }) (Map.lookup n combined)
-- | Collects variable usage for each program construct.
instance ComputeSemInfVar Program where
  computeSemInfVar Empty = (Empty, Map.empty)

  -- A primitive already carries its usage map in its semantic info.
  computeSemInfVar p@(Primitive _ seminf) = (p, varMap seminf)

  -- A sequence merges the usage of all of its parts.
  computeSemInfVar (Seq ps sem) = (Seq ps' sem, foldr addVarMap Map.empty usages)
    where
      (ps', usages) = unzip (map computeSemInfVar ps)

  -- Both branches are merged; the condition variable is recorded with an
  -- unknown number of reads.
  computeSemInfVar (IfThenElse cond@(Var cName _ _) p1 p2 sem) =
      (IfThenElse cond p1' p2' sem, addVarMap thenUsage (addVarMap elseUsage condUsage))
    where
      (p1', thenUsage) = computeSemInfVar p1
      (p2', elseUsage) = computeSemInfVar p2
      condUsage        = Map.singleton cName (SemInfVar None UnknownR)

  -- The usage of the condition program and of the body is merged and then
  -- scaled with 'multiply'. The declaration of the loop variable inside the
  -- condition program additionally gets a "read multiple times" mark.
  computeSemInfVar (SeqLoop z@(Var _ _ _) cp bp sem) =
      (SeqLoop z cp' bp' sem, iterated)
    where
      (cp', condUsage) = markCondVar (computeSemInfVar cp)
      (bp', bodyUsage) = computeSemInfVar bp
      iterated         = multiply (addVarMap condUsage bodyUsage)
      markCondVar (CompPrg locs bod, usage) = (CompPrg (map markDecl locs) bod, usage)
      markDecl d
        | var d == z = d { semInfVar = addSemInfVar (SemInfVar None MultipleR) (semInfVar d) }
        | otherwise  = d

  -- The variables read by the test expression count as used, together with
  -- the body; the merged usage is scaled with 'multiply'.
  computeSemInfVar (ParLoop ini test count body seminfo) =
      (ParLoop ini test count body' seminfo, usage)
    where
      (body', bodyUsage) = computeSemInfVar body
      testUsage          = case test of
        Expr core _ -> rightVarMap core
      usage              = multiply (addVarMap bodyUsage testUsage)
-- | Scale every usage entry of a container with 'multiplyOne'. Used for the
-- usage collected inside loops.
multiply usage = fmap multiplyOne usage

-- | Scale both the write side and the read side of one usage entry.
multiplyOne sem = sem
  { usedLeft  = multiplyLeft (usedLeft sem)
  , usedRight = multiplyRight (usedRight sem)
  }

-- | A single write becomes 'MultipleL'; every other value is kept.
multiplyLeft l = case l of
  Single _ -> MultipleL
  _        -> l

-- | Zero reads stay zero, any other counted number of reads becomes
-- 'MultipleR'; every other value is kept.
multiplyRight r = case r of
  Times 0 -> Times 0
  Times _ -> MultipleR
  _       -> r
-- | A list is annotated element-wise; the usage maps of all elements are
-- merged with 'addVarMap'.
instance (ComputeSemInfVar a) => ComputeSemInfVar [a] where
  computeSemInfVar xs = (map fst annotated, foldr (addVarMap . snd) Map.empty annotated)
    where
      annotated = map computeSemInfVar xs
-- | A declaration with an initial value counts as a single write of that
-- value with zero reads; a declaration without one contributes nothing.
instance ComputeSemInfVar Declaration where
  computeSemInfVar d@(Decl (Var n _ _) _ ini _) = (d, usage)
    where
      usage = case ini of
        Just e  -> Map.singleton n (SemInfVar (Single (Just e)) (Times 0))
        Nothing -> Map.empty
-- | Maps a variable name to the expression that may be substituted for it;
-- 'Nothing' means the variable is tracked but currently has no usable value.
type PropagateMap = Map.Map String (Maybe ImpLangExpr)
-- | Names of variables whose declarations were removed and whose writing
-- primitives are therefore candidates for deletion.
type DelSet = Set.Set String
-- | Operations of the simplification pass, implemented for every level of
-- the imperative program representation.
class Simplification a where
-- | Substitute variables according to the given map; returns the rewritten
-- fragment and the map that is valid after the fragment.
propagate :: PropagateMap -> a -> (a,PropagateMap)
-- | Remove the parts of the fragment that only write variables contained in
-- the given set.
delUnused :: DelSet -> a -> a
-- | Apply the backward step (see 'doBackward') inside the fragment.
backward :: a -> a
-- | Does the fragment write the variable with the given name?
writesVar :: a -> String -> Bool
-- | Does the fragment read the variable with the given name?
readsVar :: a -> String -> Bool
-- | All operations are delegated to the function body. The propagate map is
-- not carried out of a function, and a function as a whole is reported as
-- neither reading nor writing any variable.
instance Simplification ImpFunction where
  propagate m (Fun n ips ops cprg) = (Fun n ips ops cprg', Map.empty)
    where
      (cprg', _) = propagate m cprg
  delUnused s (Fun n ips ops cprg) = Fun n ips ops (delUnused s cprg)
  backward fun = fun { prg = backward (prg fun) }
  writesVar _ _ = False
  readsVar _ _ = False
-- | Simplification of a block consisting of local declarations and a body.
instance Simplification CompleteProgram where
  -- The incoming map is extended with the propagatable local declarations
  -- (entries of the incoming map win on a name clash); those local entries
  -- are removed again from the map that leaves the block.
  propagate m (CompPrg dl b) = (CompPrg dl b', purgePropagateMap m' dl)
    where
      (b', m') = propagate (Map.union m (makePropagateMap dl)) b

  -- Unused declarations are dropped and their names added to the delete set
  -- used for the body.
  delUnused s (CompPrg dl b) = CompPrg kept (delUnused (Set.union s unused) b)
    where
      (kept, unused) = makeUnusedSet dl

  backward (CompPrg dl b) = doBackward dl (toPrgList (backward b))

  writesVar (CompPrg _ b) var = writesVar b var
  readsVar (CompPrg _ b) var = readsVar b var
-- | Simplification of the individual program constructs.
instance Simplification Program where
  propagate m Empty = (Empty, m)
  -- The instruction is rewritten with the incoming map; the outgoing map is
  -- the one produced from the primitive's semantic info.
  propagate m (Primitive instr seminf) = (Primitive instr' seminf', m')
    where
      (instr', _)   = propagate m instr
      (seminf', m') = propagate m seminf
  propagate m (Seq ps seminf) = (Seq ps' seminf, m')
    where
      (ps', m') = propagate m ps
  -- After a branch only entries on which both branches agree stay valid.
  propagate m (IfThenElse v cp1 cp2 seminf) =
      (IfThenElse v cp1' cp2' seminf, Map.intersectionWith combineExpr m1 m2)
    where
      (cp1', m1) = propagate m cp1
      (cp2', m2) = propagate m cp2
  -- Both parts of a sequential loop start from the incoming map; their
  -- resulting maps are combined like the branches of a conditional.
  propagate m (SeqLoop v cp1 cp2 seminf) =
      (SeqLoop v cp1' cp2' seminf, Map.intersectionWith combineExpr m1 m2)
    where
      (cp1', m1) = propagate m cp1
      (cp2', m2) = propagate m cp2
  propagate m (ParLoop v i1 i2 cp seminf) = (ParLoop v i1 i2 cp' seminf, m')
    where
      (cp', m') = propagate m cp

  delUnused _ Empty = Empty
  -- A primitive is removed when every variable it writes is in the delete set.
  -- NOTE(review): 'all' is vacuously true for an empty list, so a primitive
  -- whose semantic info lists no written variable is removed as well --
  -- confirm that this is intended.
  delUnused s p@(Primitive _ seminf)
    | all (\v -> Set.member v s) (leftVars (varMap seminf)) = Empty
    | otherwise = p
  delUnused s (Seq ps seminf) = Seq (delUnused s ps) seminf
  delUnused s (IfThenElse v cp1 cp2 seminf) = IfThenElse v (delUnused s cp1) (delUnused s cp2) seminf
  delUnused s (SeqLoop v cp1 cp2 seminf) = SeqLoop v (delUnused s cp1) (delUnused s cp2) seminf
  delUnused s (ParLoop v i1 i2 cp seminf) = ParLoop v i1 i2 (delUnused s cp) seminf

  backward (Seq ps seminf) = Seq (backward ps) seminf
  backward (IfThenElse v cp1 cp2 seminf) = IfThenElse v (backward cp1) (backward cp2) seminf
  backward (ParLoop v i1 i2 cp seminf) = ParLoop v i1 i2 (backward cp) seminf
  backward (SeqLoop v cp1 cp2 seminf) = SeqLoop v (backward cp1) (backward cp2) seminf
  backward x = x

  writesVar Empty _ = False
  writesVar (Primitive i _) var = writesVar i var
  writesVar (Seq ps _) var = writesVar ps var
  writesVar (IfThenElse _ cp1 cp2 _) var = writesVar cp1 var || writesVar cp2 var
  writesVar (SeqLoop _ cp1 cp2 _) var = writesVar cp1 var || writesVar cp2 var
  writesVar (ParLoop _ _ _ cp _) var = writesVar cp var

  readsVar Empty _ = False
  readsVar (Primitive i _) var = readsVar i var
  readsVar (Seq ps _) var = readsVar ps var
  readsVar (IfThenElse v cp1 cp2 _) var = (name v == var) || readsVar cp1 var || readsVar cp2 var
  -- The loop condition variable is read by the loop itself, exactly like the
  -- condition variable of a conditional.
  readsVar (SeqLoop v cp1 cp2 _) var = (name v == var) || readsVar cp1 var || readsVar cp2 var
  -- The test expression is evaluated by the loop, so the variables it
  -- mentions are read (the usage analysis counts them as well).
  readsVar (ParLoop _ test _ cp _) var = readsVar test var || readsVar cp var
-- | Lists are processed element-wise. 'propagate' threads the map from left
-- to right, so every element sees the map produced by its predecessor.
instance (Simplification a) => Simplification [a] where
  propagate m xs = (xs', m')
    where
      (m', xs') = mapAccumL step m xs
      step acc x = let (x', acc') = propagate acc x in (acc', x')
  delUnused s xs = [delUnused s x | x <- xs]
  backward xs = [backward x | x <- xs]
  writesVar xs var = or [writesVar x var | x <- xs]
  readsVar xs var = or [readsVar x var | x <- xs]
-- | Instructions are rewritten by 'propagate' but never change the map
-- themselves; 'delUnused' and 'backward' leave them untouched.
instance Simplification Instruction where
  propagate m (Assign lhs rhs) = (Assign lhs' rhs', m)
    where
      (lhs', _) = propagate m lhs
      (rhs', _) = propagate m rhs
  propagate m (CFun fname ps) = (CFun fname [fst (propagate m p) | p <- ps], m)

  delUnused _ instr = instr
  backward instr = instr

  writesVar (Assign lhs _) var = writesVar lhs var
  writesVar (CFun _ ps) var = or [writesVar p var | p <- ps]

  -- The target of an assignment only reads the variables reported by
  -- 'readsVarHelp'; the right-hand side reads whatever it mentions.
  readsVar (Assign lhs rhs) var = readsVarHelp lhs var || readsVar rhs var
  readsVar (CFun _ ps) var = or [readsVar p var | p <- ps]
-- | Propagation through the semantic info of a primitive: the usage map is
-- rewritten to reflect the substitutions, and the propagate map is updated
-- according to what the primitive writes.
instance Simplification SemInfPrim where
  propagate m seminf = (seminf { varMap = seminf' }, updated)
    where
      -- Map valid after the primitive: first record what the primitive
      -- writes, then drop entries invalidated by those writes.
      updated = Map.map upd2 (Map.mapWithKey upd1 m)

      -- A tracked variable keeps its entry when the primitive does not write
      -- it, takes the written value on a single write, and becomes unknown
      -- for any other kind of write.
      upd1 vname expr = case Map.lookup vname seminf' of
        Nothing  -> expr
        Just sem -> case usedLeft sem of
          None     -> expr
          Single e -> e
          _        -> Nothing

      -- An entry is dropped when 'contains' reports one of the variables
      -- written by this primitive for its expression.
      upd2 expr = case expr of
        Nothing -> expr
        Just e
          | any (\v -> contains v e) (leftVars seminf') -> Nothing
          | otherwise                                   -> expr

      -- Usage map after substitution. 'Map.foldrWithKey' is the documented
      -- replacement of the removed 'Map.foldWithKey' and folds identically.
      seminf' = Map.foldrWithKey prop Map.empty (varMap seminf)

      -- Contribution of one original entry: its (rewritten) write side with
      -- zero reads, plus its read side attributed to the variables that are
      -- actually read after substitution.
      prop :: String -> SemInfVar -> Map.Map String SemInfVar -> Map.Map String SemInfVar
      prop vname sem other =
        addVarMap other
          (addVarMap (Map.singleton vname (SemInfVar (propLeft (usedLeft sem)) (Times 0)))
                     (propRight vname (usedRight sem)))

      -- A known written expression is itself rewritten with the incoming map.
      propLeft (Single (Just expr)) = Single (Just (fst (propagate m expr)))
      propLeft x                    = x

      -- Reads of a substituted variable turn into reads of the variables of
      -- the substituted expression, scaled by the original read count.
      propRight :: String -> RightUse -> Map.Map String SemInfVar
      propRight vname right = case Map.lookup vname m of
        Just (Just e) -> Map.map (mult right) (rightVarMap e)
        _             -> Map.singleton vname (SemInfVar None right)

      -- Combine the read count of the replaced variable with the read count
      -- of a variable inside the replacement expression.
      mult UnknownR sem = sem { usedRight = UnknownR }
      mult (Times n) sem = case usedRight sem of
        Times n' -> sem { usedRight = Times (n * n') }
        _        -> sem
      mult MultipleR sem = case usedRight sem of
        UnknownR -> sem
        _        -> sem { usedRight = MultipleR }

  delUnused _ = id
  backward = id
  writesVar _ _ = False
  readsVar _ _ = False
-- | Expressions: 'propagate' substitutes variables, the map is returned
-- unchanged.
instance Simplification ImpLangExpr where
  -- A plain variable with a known value is replaced by that value, which is
  -- itself propagated again.
  propagate m i@(Expr (LeftExpr (LVar (Var n _ _))) _) =
    case Map.lookup n m of
      Just (Just expr) -> (fst (propagate m expr), m)
      _                -> (i, m)
  propagate m (Expr (LeftExpr x) t) = (Expr (LeftExpr x') t, m)
    where
      (x', _) = propagate m x
  propagate m (Expr (FunCall r s is) t) = (Expr (FunCall r s [fst (propagate m e) | e <- is]) t, m)
  -- Every other expression is left as it is.
  propagate m x = (x, m)

  delUnused _ e = e
  backward e = e

  writesVar (Expr (LeftExpr lv) _) var = writesVar lv var
  writesVar _ _ = False

  readsVar (Expr (LeftExpr lv) _) var = readsVar lv var
  readsVar (Expr (AddressOf lv) _) var = readsVar lv var
  readsVar (Expr (ConstExpr _) _) _ = False
  readsVar (Expr (FunCall _ _ es) _) var = or [readsVar e var | e <- es]
-- | Left values: 'propagate' substitutes variables, the map is returned
-- unchanged.
instance Simplification LeftValue where
  -- A variable with a known value is replaced by the left value obtained
  -- from the propagated expression via 'getLeftValue'.
  -- NOTE(review): this relies on 'getLeftValue' accepting whatever
  -- expression the map holds -- confirm for non-variable expressions.
  propagate m l@(LVar (Var n _ _)) =
    case Map.lookup n m of
      Just (Just expr) -> (getLeftValue (fst (propagate m expr)), m)
      _                -> (l, m)
  propagate m (ArrayElem lv ile) = (ArrayElem lv' ile', m)
    where
      (lv', _)  = propagate m lv
      (ile', _) = propagate m ile
  propagate m (PointedVal lv) = (PointedVal (fst (propagate m lv)), m)

  delUnused _ lv = lv
  backward lv = lv

  -- Writing an element or a pointed value counts as writing the underlying
  -- variable.
  writesVar (LVar v) var = name v == var
  writesVar (ArrayElem lv _) var = writesVar lv var
  writesVar (PointedVal lv) var = writesVar lv var

  readsVar (LVar v) var = name v == var
  readsVar (ArrayElem lv ix) var = readsVar lv var || readsVar ix var
  readsVar (PointedVal lv) var = readsVar lv var
-- | Parameters of a function call: only output parameters can write, and an
-- output parameter reads just what 'readsVarHelp' reports for its left value.
instance Simplification Parameter where
  propagate m (In ile) = (In ile', m)
    where
      (ile', _) = propagate m ile
  propagate m (Out (k, ile)) = (Out (k, ile'), m)
    where
      (ile', _) = propagate m ile

  delUnused _ p = p
  backward p = p

  writesVar (In _) _ = False
  writesVar (Out (_, e)) var = writesVar e var

  readsVar (In e) var = readsVar e var
  readsVar (Out (_, e)) var = readsVarHelp (getLeftValue e) var
-- | Build the propagate map contributed by a list of declarations.
--
-- A declaration contributes its initial value when it is written exactly
-- once with a known expression and, in addition, is either read exactly once
-- or initialised with a simple expression (a plain variable or a constant of
-- simple type). On duplicate names the earliest declaration wins.
makePropagateMap :: [Declaration] -> PropagateMap
makePropagateMap dl = Map.unions (map entry dl)
  where
    entry d = case usedLeft (semInfVar d) of
      Single (Just e)
        | readOnce d || simpleExpr e -> Map.singleton (name (var d)) (initVal d)
      _ -> Map.empty
    readOnce d = usedRight (semInfVar d) == Times 1
    simpleExpr e = case e of
      Expr (LeftExpr (LVar _)) _ -> True
      Expr (ConstExpr _) t       -> simpleType t
      _                          -> False
-- | Remove from the map every key that the given declarations would
-- contribute via 'makePropagateMap'.
purgePropagateMap :: PropagateMap -> [Declaration] -> PropagateMap
purgePropagateMap m dl = Map.difference m (makePropagateMap dl)
-- | Merge the values two control-flow paths have for the same variable: the
-- value survives only when both paths agree, otherwise it is unknown.
combineExpr :: Maybe ImpLangExpr -> Maybe ImpLangExpr -> Maybe ImpLangExpr
combineExpr e1 e2 = if e1 == e2 then e1 else Nothing
-- | Split declarations into those that are kept and the names of those whose
-- read count is @Times 0@; the latter are dropped from the declaration list.
makeUnusedSet :: [Declaration] -> ([Declaration],DelSet)
makeUnusedSet = foldr classify ([], Set.empty)
  where
    -- The lazy pattern keeps the fold as lazy as plain recursion would be.
    classify d ~(kept, unused) = case usedRight (semInfVar d) of
      Times 0 -> (kept, Set.insert (name (var d)) unused)
      _       -> (d : kept, unused)
-- | Does using the left value as the target of a write read the given
-- variable?
--
-- The target variable itself is not read; only index expressions are. A
-- dereference is looked through, so index expressions nested below a
-- 'PointedVal' are reported as well (previously they were ignored).
readsVarHelp :: LeftValue -> String -> Bool
readsVarHelp (LVar _) _ = False
readsVarHelp (ArrayElem lv ix) var = readsVarHelp lv var || readsVar ix var
readsVarHelp (PointedVal lv) var = readsVarHelp lv var
-- | Apply 'backwardRec' repeatedly until it reports that nothing changed,
-- then wrap the resulting declarations and programs into a block.
doBackward :: [Declaration] -> [Program] -> CompleteProgram
doBackward ds ps = case backwardRec ds ([], ps) of
  (True, ds', ps')  -> doBackward ds' ps'
  (False, ds', ps') -> CompPrg ds' (Seq ps' [])
-- | Scan the programs from left to right for the first one that
-- 'backwardPossible' accepts.
--
-- The first list of the pair holds the programs already visited, most recent
-- first. On a hit the accepted program is dropped, the replacement is
-- applied with 'backwardRepl', the initialisation program is put in front,
-- and @True@ is returned as first component. Without a hit the input is
-- returned in its original order together with @False@.
backwardRec :: [Declaration] -> ([Program],[Program]) -> (Bool, [Declaration], [Program])
backwardRec ds (done, []) = (False, ds, reverse done)
backwardRec ds (done, cur : rest) =
  case backwardPossible ds done cur rest of
    Just (target, source, initPrg) ->
      let (ds', ps') = backwardRepl target source ds done rest
      in (True, ds', initPrg : ps')
    Nothing -> backwardRec ds (cur : done, rest)
-- | Decide whether program @y@ is a copy that can be eliminated by writing
-- directly into its target.
--
-- @xs@ are the programs before @y@ (nearest first), @ys@ the programs after
-- it. @y@ qualifies when it is a primitive flagged @True@ in its semantic
-- info that either assigns a plain variable to a left value or calls a
-- function whose name starts with @"copy"@ with that shape of arguments.
-- The result is the target left value, the name of the copied variable and
-- the program that initialises the target ('Empty' if the variable had no
-- initial value).
backwardPossible :: [Declaration] -> [Program] -> Program -> [Program] -> Maybe (LeftValue,String,Program)
backwardPossible ds xs y ys = case y of
    Primitive (Assign left (Expr (LeftExpr (LVar (Var src _ _))) _)) (SemInfPrim _ True)
      -> check left src
    Primitive (CFun fname [In (Expr (LeftExpr (LVar (Var src _ _))) _), Out (_, Expr (LeftExpr left) _)]) (SemInfPrim _ True)
      | isPrefixOf "copy" fname -> check left src
      | otherwise               -> Nothing
    _ -> Nothing
  where
    check left src = case declarationOK of
        Just initPrg | beforeOK && afterOK -> Just (left, src, initPrg)
        _                                  -> Nothing
      where
        -- The copied variable must be declared in @ds@; an initial value is
        -- only acceptable when it has a simple type, in which case it is
        -- turned into an assignment to the target.
        declarationOK = do
          d <- find (declares src) ds
          case initVal d of
            Nothing -> Just Empty
            Just expr
              | simpleType (exprType expr) ->
                  Just (Primitive (Assign left expr) (SemInfPrim Map.empty False))
              | otherwise -> Nothing

        -- The preceding programs must be conflict free; if they use the
        -- target at all, the copied variable must not need an initialisation.
        beforeOK = conflictFree && (not targetUsed || needsNoInit)
        (conflictFree, targetUsed) = useBefore
        needsNoInit = case declarationOK of
          Just Empty -> True
          _          -> False

        -- No later program may touch the copied variable.
        afterOK = all (\p -> not (readsVar p src || writesVar p src)) ys

        -- Walk backwards from @y@, tracking whether the target has been used
        -- so far and whether a conflicting use of the copied variable occurred.
        useBefore = foldl' step (True, False) xs
        step (ok, out) prg = (ok && not conflict, out || outRead || outWritten)
          where
            conflict   = ((out || outWritten) && (varWritten || varRead))
                      || (outRead && varRead)
            outWritten = prg `writesVar` outName
            outRead    = prg `readsVar` outName
            outName    = getVarName left
            varWritten = prg `writesVar` src
            varRead    = prg `readsVar` src
-- | Drop the declaration of @var@ and replace @var@ by @lv@ (using
-- 'replaceLExpr') in the programs before (@xs@, given in reverse order) and
-- after (@ys@) the eliminated copy.
backwardRepl :: LeftValue -> String -> [Declaration] -> [Program] -> [Program] -> ([Declaration], [Program])
backwardRepl lv var ds xs ys = (remaining, replaced)
  where
    remaining = [d | d <- ds, not (declares var d)]
    replaced  = replaceLExpr (reverse xs ++ ys) (var, lv)
-- | View a program as a list of programs: the parts of a sequence, or a
-- singleton list for anything else.
toPrgList :: Program -> [Program]
toPrgList p = case p of
  Seq ps _ -> ps
  _        -> [p]
-- | Does the declaration declare the variable with the given name?
declares :: String -> Declaration -> Bool
declares n d = declaredName == n
  where
    declaredName = name (var d)