{-
 - Copyright (c) 2009, ERICSSON AB All rights reserved.
 - 
 - Redistribution and use in source and binary forms, with or without
 - modification, are permitted provided that the following conditions
 - are met:
 - 
 -     * Redistributions of source code must retain the above copyright
 -     notice,
 -       this list of conditions and the following disclaimer.
 -     * Redistributions in binary form must reproduce the above copyright
 -       notice, this list of conditions and the following disclaimer
 -       in the documentation and/or other materials provided with the
 -       distribution.
 -     * Neither the name of the ERICSSON AB nor the names of its
 -     contributors
 -       may be used to endorse or promote products derived from this
 -       software without specific prior written permission.
 - 
 - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 -}

module Feldspar.Compiler.Optimization.Simplification where

import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List hiding (insert,union)
import Data.Maybe
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Optimization.Replace

-- | Run the simplification pipeline ('doSimplificationOne') on every
-- function, each one independently of the others.
doSimplification :: [ImpFunction] -> [ImpFunction]
doSimplification funs = [doSimplificationOne f | f <- funs]

-- | Simplify one function. The stages run in this order:
--
--   1. annotate declarations with variable-usage information,
--   2. forward value propagation ('propagate'),
--   3. re-annotate, since propagation changes how variables are used,
--   4. drop declarations with no recorded reads and code that only writes
--      them ('delUnused'),
--   5. backward copy elimination ('backward').
doSimplificationOne :: ImpFunction -> ImpFunction
doSimplificationOne fun = backward pruned
  where
    (annotated, _)   = computeSemInfVar fun
    (propagated, _)  = propagate Map.empty annotated
    (reannotated, _) = computeSemInfVar propagated
    pruned           = delUnused Set.empty reannotated

--------------------------------------------------------------
-- Computing semantic information for variable declarations --
--------------------------------------------------------------

-- | Constructs from which per-variable usage information ('VariableMap') can
-- be collected. The first component of the result is the input, possibly
-- rewritten so that local declarations inside it carry the collected
-- information; the second is the usage information visible from outside.
class ComputeSemInfVar t where
    computeSemInfVar :: t -> (t, VariableMap)

-- | Analyse the function body; only the body ('prg') of the function is
-- replaced by its annotated version.
instance ComputeSemInfVar ImpFunction where
    computeSemInfVar fun =
        let (body', vars) = computeSemInfVar (prg fun)
        in (fun { prg = body' }, vars)

-- | Analyse a block with local declarations.
--
-- Usage information from the body and from the declarations themselves is
-- merged (with 'addVarMap') and stored in each local's 'semInfVar'. Entries
-- for the locals are then removed from the returned map, because those
-- variables are not visible outside this block.
instance ComputeSemInfVar CompleteProgram where
    computeSemInfVar (CompPrg locals body) = (CompPrg locs body', rest)
        where
            (body', bodyVars) = computeSemInfVar body
            (_, declVars) = computeSemInfVar locals
            -- Computed once and shared by every 'updateLocal' call.
            allVars = addVarMap bodyVars declVars
            locs = map updateLocal locals
            -- Usage information about variables declared outside this block.
            rest = Map.filterWithKey (\k _ -> not $ isLocal k) bodyVars
            -- 'any' stops at the first match and needs no 'Eq' on declarations.
            isLocal n = any (\(Decl (Var n' _ _) _ _ _) -> n' == n) locals
            updateLocal d@(Decl (Var n _ _) _ _ _) = case Map.lookup n allVars of
                Nothing -> d
                Just inf -> d { semInfVar = inf }

-- | Collect usage information for a statement.
--
-- Loop bodies go through 'multiply', since they may execute more than once.
-- The condition variable of an 'IfThenElse' is recorded as read with an
-- unknown count. For a 'SeqLoop', the condition variable is instead marked
-- as read multiple times on its declaration, if that declaration is among
-- the condition block's locals.
instance ComputeSemInfVar Program where
    computeSemInfVar Empty = (Empty, Map.empty)
    computeSemInfVar prim@(Primitive _ seminf) = (prim, varMap seminf)
    computeSemInfVar (Seq ps sem) =
        let (ps', vars) = computeSemInfVar ps
        in (Seq ps' sem, vars)
    computeSemInfVar (IfThenElse cond@(Var cName _ _) p1 p2 sem) =
        let (p1', vars1) = computeSemInfVar p1
            (p2', vars2) = computeSemInfVar p2
            condRead = Map.singleton cName (SemInfVar None UnknownR)
        in (IfThenElse cond p1' p2' sem, vars1 `addVarMap` (vars2 `addVarMap` condRead))
    computeSemInfVar (SeqLoop z@Var{} cp bp sem) =
        let (CompPrg condLocals condBody, condVars) = computeSemInfVar cp
            markCond d
                | var d == z = d { semInfVar = addSemInfVar (SemInfVar None MultipleR) (semInfVar d) }
                | otherwise  = d
            (bp', bodyVars) = computeSemInfVar bp
        in ( SeqLoop z (CompPrg (map markCond condLocals) condBody) bp' sem
           , multiply (addVarMap condVars bodyVars) )
    computeSemInfVar (ParLoop lv test count body seminfo) =
        let (body', bodyVars) = computeSemInfVar body
            Expr testCore _ = test
        in (ParLoop lv test count body' seminfo, multiply (addVarMap bodyVars (rightVarMap testCore)))

-- | Adjust usage information for code that may execute more than once (loop
-- bodies). A single assignment becomes 'MultipleL', and a non-zero 'Times'
-- count becomes 'MultipleR'. @Times 0@ and all other values are kept.
multiply m = fmap multiplyOne m

-- | 'multiply' for the information of one variable.
multiplyOne sem = sem { usedLeft  = multiplyLeft (usedLeft sem)
                      , usedRight = multiplyRight (usedRight sem) }

multiplyLeft l = case l of
    Single _ -> MultipleL
    _        -> l

multiplyRight r = case r of
    Times 0 -> r
    Times _ -> MultipleR
    _       -> r

-- | Analyse each element. The per-element maps are merged with 'addVarMap'
-- in a right fold that starts from the empty map.
instance (ComputeSemInfVar a) => ComputeSemInfVar [a] where
    computeSemInfVar xs =
        let (xs', maps) = unzip (map computeSemInfVar xs)
        in (xs', foldr addVarMap Map.empty maps)

-- | A declaration with an initialiser counts as a single assignment of that
-- value with no reads. A declaration without one contributes nothing.
instance ComputeSemInfVar Declaration where
    computeSemInfVar decl@(Decl (Var vname _ _) _ mInit _) = (decl, initInfo mInit)
      where
        initInfo Nothing    = Map.empty
        initInfo (Just ini) = Map.singleton vname (SemInfVar (Single (Just ini)) (Times 0))

--------------------
-- Simplification --
--------------------

-- | Values tracked during forward propagation, keyed by variable name.
-- @Just e@ means uses of the variable are replaced by @e@. @Nothing@ means
-- the variable is tracked but must not be substituted.
type PropagateMap = Map.Map String (Maybe ImpLangExpr)
-- | Variable names for 'delUnused': a primitive whose written variables all
-- belong to this set is deleted.
type DelSet = Set.Set String

-- | Operations used by the forward (propagation), dead-code and backward
-- (copy elimination) simplification passes.
class Simplification a where
    -- | Substitute known values from the map; also return the map as it
    -- stands after this construct.
    propagate   :: PropagateMap -> a -> (a,PropagateMap)
    -- | Delete primitives that only write variables in the set. Blocks also
    -- drop, and add to the set, their locals that have no recorded reads.
    delUnused   :: DelSet -> a -> a
    -- | Apply backward copy elimination.
    backward    :: a -> a
    -- | Does this construct write the named variable?
    writesVar   :: a -> String -> Bool
    -- | Does this construct read the named variable?
    readsVar    :: a -> String -> Bool

-- | Functions delegate every pass to their body. 'propagate' passes no
-- information out of a function: the returned map is empty.
instance Simplification ImpFunction where
    propagate m (Fun n ips ops body) = (Fun n ips ops body', Map.empty)
      where (body', _) = propagate m body
    delUnused s (Fun n ips ops body) = Fun n ips ops (delUnused s body)
    backward fun = fun { prg = backward (prg fun) }
    -- Should not be used on whole functions.
    writesVar _ _ = False
    readsVar _ _ = False

-- | Blocks with local declarations.
--
-- * 'propagate' seeds the map with the locals' propagatable values
--   ('makePropagateMap') and removes those entries again before returning.
-- * 'delUnused' drops locals that have no recorded reads, and deletes code
--   that only writes them.
-- * 'backward' simplifies the body, then runs backward copy elimination over
--   its top-level statements.
instance Simplification CompleteProgram where
    propagate m (CompPrg decls body) =
        let (body', m') = propagate (Map.union m (makePropagateMap decls)) body
        in (CompPrg decls body', purgePropagateMap m' decls)
    delUnused s (CompPrg decls body) =
        let (kept, unused) = makeUnusedSet decls
        in CompPrg kept (delUnused (Set.union s unused) body)
    backward (CompPrg decls body) = doBackward decls (toPrgList (backward body))
    writesVar (CompPrg _ body) = writesVar body
    readsVar (CompPrg _ body) = readsVar body

-- | Statements.
--
-- In 'propagate', both branches of an 'IfThenElse' and both parts of a
-- 'SeqLoop' start from the incoming map. Afterwards, only variables on which
-- the two resulting maps agree ('combineExpr') keep a known value.
instance Simplification Program where
    propagate m prog = case prog of
        Empty -> (Empty, m)
        Primitive instr seminf ->
            let (seminf', m') = propagate m seminf
            in (Primitive (fst $ propagate m instr) seminf', m')
        Seq ps seminf ->
            let (ps', m') = propagate m ps
            in (Seq ps' seminf, m')
        IfThenElse v cp1 cp2 seminf ->
            let (cp1', m1) = propagate m cp1
                (cp2', m2) = propagate m cp2
            in (IfThenElse v cp1' cp2' seminf, Map.intersectionWith combineExpr m1 m2)
        SeqLoop v cp1 cp2 seminf ->
            let (cp1', m1) = propagate m cp1
                (cp2', m2) = propagate m cp2
            in (SeqLoop v cp1' cp2' seminf, Map.intersectionWith combineExpr m1 m2)
        ParLoop v i1 i2 cp seminf ->
            let (cp', m') = propagate m cp
            in (ParLoop v i1 i2 cp' seminf, m')
    delUnused s prog = case prog of
        Empty -> Empty
        Primitive _ seminf
            | all (`Set.member` s) (leftVars $ varMap seminf) -> Empty
            | otherwise                                        -> prog
        Seq ps seminf -> Seq (delUnused s ps) seminf
        IfThenElse v cp1 cp2 seminf -> IfThenElse v (delUnused s cp1) (delUnused s cp2) seminf
        SeqLoop v cp1 cp2 seminf -> SeqLoop v (delUnused s cp1) (delUnused s cp2) seminf
        ParLoop v i1 i2 cp seminf -> ParLoop v i1 i2 (delUnused s cp) seminf
    backward prog = case prog of
        Seq ps seminf -> Seq (backward ps) seminf
        IfThenElse v cp1 cp2 seminf -> IfThenElse v (backward cp1) (backward cp2) seminf
        ParLoop v i1 i2 cp seminf -> ParLoop v i1 i2 (backward cp) seminf
        SeqLoop v cp1 cp2 seminf -> SeqLoop v (backward cp1) (backward cp2) seminf
        _ -> prog
    writesVar prog target = case prog of
        Empty -> False
        Primitive i _ -> writesVar i target
        Seq ps _ -> writesVar ps target
        IfThenElse _ cp1 cp2 _ -> writesVar cp1 target || writesVar cp2 target
        SeqLoop _ cp1 cp2 _ -> writesVar cp1 target || writesVar cp2 target
        ParLoop _ _ _ cp _ -> writesVar cp target
    readsVar prog target = case prog of
        Empty -> False
        Primitive i _ -> readsVar i target
        Seq ps _ -> readsVar ps target
        -- The branch condition variable itself is read.
        IfThenElse v cp1 cp2 _ -> name v == target || readsVar cp1 target || readsVar cp2 target
        SeqLoop _ cp1 cp2 _ -> readsVar cp1 target || readsVar cp2 target
        ParLoop _ _ _ cp _ -> readsVar cp target

-- | Sequences: 'propagate' threads the map through the elements from left to
-- right. The other operations apply element-wise.
instance (Simplification a) => Simplification [a] where
    propagate m xs = (xs', m')
      where
        (m', xs') = mapAccumL step m xs
        step acc x = let (x', acc') = propagate acc x in (acc', x')
    delUnused s = map (delUnused s)
    backward = map backward
    writesVar xs target = any (`writesVar` target) xs
    readsVar xs target = any (`readsVar` target) xs

-- | Single instructions. Propagation substitutes inside operands and leaves
-- the map unchanged. Writing an assignment target only counts as a read of
-- the variables its indices mention ('readsVarHelp').
instance Simplification Instruction where
    propagate m instr = case instr of
        Assign lhs rhs -> (Assign (fst $ propagate m lhs) (fst $ propagate m rhs), m)
        CFun fname params -> (CFun fname [fst (propagate m p) | p <- params], m)
    delUnused _ = id
    backward = id
    writesVar instr target = case instr of
        Assign lhs _ -> writesVar lhs target
        CFun _ params -> any (`writesVar` target) params
    readsVar instr target = case instr of
        Assign lhs rhs -> readsVarHelp lhs target || readsVar rhs target
        CFun _ params -> any (`readsVar` target) params

-- | Forward propagation through the usage information of one primitive.
--
-- * The primitive's 'varMap' is rewritten. Known values from the incoming
--   map are substituted into recorded single-assignment values
--   ('propLeft'). A right-hand use of a variable with a known value is
--   replaced by uses of the variables in that value, scaled by the original
--   use count ('propRight').
-- * The outgoing map is the incoming one, with entries updated for variables
--   this primitive assigns ('upd1'). Entries whose value mentions a variable
--   written here are invalidated ('upd2').
instance Simplification SemInfPrim where
    propagate m seminf = (seminf{ varMap = seminf' }, updated) where
        updated = Map.map upd2 $ Map.mapWithKey upd1 m
        -- A single assignment replaces the entry with the recorded value;
        -- any other kind of assignment makes the value unknown.
        upd1 name expr = case Map.lookup name seminf' of
            Nothing -> expr
            Just sem -> case usedLeft sem of
                None -> expr
                Single e -> e
                _ -> Nothing
        -- A value that mentions (per 'contains') a variable written by this
        -- primitive is stale afterwards.
        upd2 expr = case expr of
            Nothing -> expr
            Just e
                | any (\v -> contains v e) $ leftVars seminf' -> Nothing
                | otherwise -> expr
        -- 'Map.foldrWithKey' replaces the deprecated 'Map.foldWithKey' (same
        -- right fold), which current containers releases no longer export.
        seminf' = Map.foldrWithKey prop Map.empty $ varMap seminf
        prop :: String -> SemInfVar -> Map.Map String SemInfVar -> Map.Map String SemInfVar
        prop name sem other
            = addVarMap other $ addVarMap (Map.singleton name $ SemInfVar (propLeft $ usedLeft sem) (Times 0)) $ propRight name $ usedRight sem
        propLeft (Single (Just expr)) = Single $ Just $ fst $ propagate m expr
        propLeft x = x
        propRight :: String -> RightUse -> Map.Map String SemInfVar
        propRight name right = case Map.lookup name m of
            Just (Just e) -> Map.map (mult right) $ rightVarMap e
            _ -> Map.singleton name (SemInfVar None right)
        -- Scale the read count of one variable of a substituted value by the
        -- read count of the variable it replaces.
        mult UnknownR sem = sem{ usedRight=UnknownR }
        mult (Times n) sem = case usedRight sem of
            Times n'    -> sem{ usedRight = Times $ n*n' }
            _           -> sem
        mult MultipleR sem = case usedRight sem of
            UnknownR    -> sem
            _           -> sem{ usedRight = MultipleR }
    delUnused _ = id
    backward = id
    writesVar _ _ = False
    readsVar _ _ = False

-- | Expressions. A plain variable with a known value is replaced by that
-- value, which is propagated into as well. Propagation also descends into
-- left-value expressions and function-call arguments; all other expressions
-- are left unchanged.
instance Simplification ImpLangExpr where
    -- A single 'Map.lookup' replaces the 'Map.member' check followed by the
    -- partial 'Map.!'.
    propagate m e@(Expr (LeftExpr (LVar (Var n _ _))) _) = case Map.lookup n m of
        Just (Just val) -> (fst $ propagate m val, m)
        _ -> (e, m)
    propagate m (Expr (LeftExpr x) t) = (Expr (LeftExpr (fst $ propagate m x)) t, m)
    propagate m (Expr (FunCall r s is) t) = (Expr (FunCall r s (map (fst . propagate m) is)) t, m)
    propagate m x = (x,m)
    delUnused _ = id
    backward = id
    writesVar (Expr (LeftExpr lv) _) var = writesVar lv var
    writesVar _ _ = False
    readsVar (Expr (LeftExpr lv) _) var = readsVar lv var
    readsVar (Expr (AddressOf lv) _) var = readsVar lv var
    readsVar (Expr (ConstExpr _) _) _ = False
    readsVar (Expr (FunCall _ _ es) _) var = any (\e -> readsVar e var) es

-- | Left values. A variable with a known value is replaced by that value,
-- propagated into and converted back with 'getLeftValue'. Writing an array
-- element or pointed value counts as writing the underlying variable.
instance Simplification LeftValue where
    -- A single 'Map.lookup' replaces the 'Map.member' check followed by the
    -- partial 'Map.!'.
    propagate m l@(LVar (Var n _ _)) = case Map.lookup n m of
        Just (Just expr) -> (getLeftValue $ fst $ propagate m expr, m)
        _ -> (l, m)
    propagate m (ArrayElem lv ile) = (ArrayElem (fst $ propagate m lv) (fst $ propagate m ile), m)
    propagate m (PointedVal lv) = (PointedVal (fst $ propagate m lv), m)
    delUnused _ = id
    backward = id
    writesVar (LVar v) var = name v == var
    writesVar (ArrayElem lv _) var = writesVar lv var
    writesVar (PointedVal lv) var = writesVar lv var
    readsVar (LVar v) var = name v == var
    readsVar (ArrayElem lv idx) var = readsVar lv var || readsVar idx var
    readsVar (PointedVal lv) var = readsVar lv var

-- | Call parameters. 'In' arguments are only read. An 'Out' argument is
-- written, and writing it reads only what 'readsVarHelp' reports for its
-- left value.
instance Simplification Parameter where
    propagate m param = case param of
        In e       -> (In (fst $ propagate m e), m)
        Out (k, e) -> (Out (k, fst $ propagate m e), m)
    delUnused _ = id
    backward = id
    writesVar param target = case param of
        In _       -> False
        Out (_, e) -> writesVar e target
    readsVar param target = case param of
        In e       -> readsVar e target
        Out (_, e) -> readsVarHelp (getLeftValue e) target

-- | Build the propagation entries for a block's declarations.
--
-- A declaration contributes an entry when its variable is assigned once
-- with a known value ('Single' @(Just e)@), and either it is read exactly
-- once or @e@ is a plain variable or a constant of simple type. The entry
-- holds the declaration's initialiser ('initVal'), which may be 'Nothing'.
-- On duplicate names the earliest declaration wins ('Map.unions' is
-- left-biased).
makePropagateMap :: [Declaration] -> PropagateMap
makePropagateMap dl = Map.unions (map makePropagateMap' dl) where
    makePropagateMap' d = case usedLeft $ semInfVar d of
        Single (Just e)
            | usedRight (semInfVar d) == Times 1 || simpleExpr e    -> Map.singleton (name $ var d) $ initVal d
        -- Also reached when the guard above fails. Previously this was the
        -- variable pattern 'otherwise', which shadowed Prelude's 'otherwise'.
        _                                                           -> Map.empty
    simpleExpr (Expr (LeftExpr (LVar _)) _) = True
    simpleExpr (Expr (ConstExpr _) t) = simpleType t
    simpleExpr _ = False

-- | Remove every key that 'makePropagateMap' produces for the given
-- declarations. Used when leaving a block, so that entries for its locals
-- are not returned to the enclosing code.
purgePropagateMap :: PropagateMap -> [Declaration] -> PropagateMap
purgePropagateMap m dl = m `Map.difference` makePropagateMap dl

-- | Merge the values a variable has on two control-flow paths. The value is
-- kept only when both paths agree; otherwise it becomes unknown ('Nothing').
--
-- The signature is generalised from @Maybe ImpLangExpr@ to any 'Eq' type.
-- Existing callers are unaffected.
combineExpr :: Eq a => Maybe a -> Maybe a -> Maybe a
combineExpr e1 e2
    | e1 == e2  = e1
    | otherwise = Nothing

-- | Split declarations into those to keep, in their original order, and the
-- names of those whose variable has no recorded right-hand use (@Times 0@).
makeUnusedSet :: [Declaration] -> ([Declaration],DelSet)
makeUnusedSet = foldr classify ([], Set.empty)
  where
    -- Lazy pattern, so the rest of the list is only analysed on demand.
    classify d ~(kept, unused) = case usedRight (semInfVar d) of
        Times 0 -> (kept, Set.insert (name $ var d) unused)
        _       -> (d : kept, unused)

-- | Does writing to this left value read the named variable? Only index
-- expressions of array accesses count. The variable named by the left value
-- itself is not considered read.
readsVarHelp :: LeftValue -> String -> Bool
readsVarHelp (LVar _) _ = False
readsVarHelp (ArrayElem lv idx) v = readsVarHelp lv v || readsVar idx v
-- Writing through a pointer still evaluates any index expressions inside the
-- pointed-to left value, so recurse instead of always answering False.
-- For @PointedVal (LVar _)@ the result is still False.
readsVarHelp (PointedVal lv) v = readsVarHelp lv v

-----------------------------
-- Backward simplification --
-----------------------------

-- | Repeat 'backwardRec' until it finds no more copies to eliminate, then
-- rebuild the block from the remaining declarations and statements.
doBackward :: [Declaration] -> [Program] -> CompleteProgram
doBackward ds ps =
    let (changed, ds', ps') = backwardRec ds ([], ps)
    in if changed then doBackward ds' ps' else CompPrg ds' (Seq ps' [])

-- | Walk the statements from left to right, keeping those already visited
-- in reverse order in the first component. At the first statement for which
-- 'backwardPossible' succeeds, the copy is eliminated ('backwardRepl') and
-- the replacement statement is put first; the result is flagged 'True'.
-- If no statement qualifies, the input is returned unchanged with 'False'.
backwardRec :: [Declaration] -> ([Program],[Program]) -> (Bool, [Declaration], [Program])
backwardRec ds (before, rest) = case rest of
    [] -> (False, ds, reverse before)
    y : ys -> case backwardPossible ds before y ys of
        Nothing -> backwardRec ds (y : before, ys)
        Just (target, varName, initPrg) ->
            let (ds', prgs) = backwardRepl target varName ds before ys
            in (True, ds', initPrg : prgs)

-- | Decide whether statement @y@ is a copy that can be eliminated backwards.
--
-- @y@ must be one of:
--
-- * an assignment @left = name@ whose right-hand side is a plain variable;
-- * a call to a function whose name starts with @copy@, with input @name@
--   and output @left@.
--
-- In both cases the primitive's 'SemInfPrim' flag must be 'True'. The copy
-- is removed by making the other statements use @left@ instead of @name@
-- (see 'backwardRepl').
--
-- @xs@ holds the statements before @y@ in reverse order (nearest first); @ys@
-- holds those after it. On success the result is the target left value, the
-- eliminated variable name, and a statement standing in for @name@'s
-- initialiser: @left = init@ for a simple-typed initialiser, or 'Empty' when
-- there is none.
backwardPossible :: [Declaration] -> [Program] -> Program -> [Program] -> Maybe (LeftValue,String,Program)
backwardPossible ds xs y ys = case y of
    (Primitive (Assign left (Expr (LeftExpr (LVar (Var name _ _))) _)) (SemInfPrim _ True))
        -> check left name
    (Primitive (CFun fname [In (Expr (LeftExpr (LVar (Var name _ _))) _), Out (_,(Expr (LeftExpr left) _))]) (SemInfPrim _ True))
        | isPrefixOf "copy" fname -> check left name -- TODO: eliminate string constant
        | otherwise -> Nothing
    _   -> Nothing
    where
        -- Pattern match instead of 'isJust' + 'fromJust'.
        check left name = case declarationOK of
            Just initPrg | beforeOK && afterOK -> Just (left, name, initPrg)
            _                                  -> Nothing
          where
            -- @name@ must be declared in this block, with no initialiser or
            -- a simple-typed one (which becomes @left = init@).
            declarationOK = case find (declares name) ds of
                Just d  -> case initVal d of
                    Nothing -> Just Empty
                    Just expr
                        | simpleType (exprType expr) -> Just $ Primitive (Assign left expr) $ SemInfPrim Map.empty False
                        | otherwise -> Nothing
                Nothing -> Nothing
            -- The backward scan must find no conflict. If @left@ is used
            -- anywhere before @y@, only an 'Empty' replacement is allowed,
            -- since a @left = init@ statement is placed before all of them.
            beforeOK = case useBefore of
                (False, _)      -> False
                (True, False)   -> True
                (True, True)    -> case declarationOK of
                    Nothing         -> False
                    Just Empty      -> True
                    Just _          -> False
            -- @name@ must be neither read nor written after @y@.
            afterOK = not $ any (\p -> readsVar p name || writesVar p name) ys
            useBefore = foldl' step (True,False) xs
            -- Scan backwards from @y@. @out@ records whether @left@ is read
            -- or written by a statement between the current one and @y@.
            -- The scan fails when the current statement uses @name@ while
            -- @left@ is used later, or when @left@ and @name@ conflict
            -- within the same statement.
            step (ok,out) prg = (ok',out') where
                out' = out || outRead || outWritten
                ok'
                    | not ok                                = False
                    | out && (varWritten || varRead)        = False
                    | outWritten && (varWritten || varRead) = False
                    | outRead && varRead                    = False
                    | otherwise                             = True
                outWritten = prg `writesVar` outName
                outRead = prg `readsVar` outName
                outName = getVarName left
                varWritten = prg `writesVar` name
                varRead = prg `readsVar` name

-- | Eliminate variable @var@ in favour of left value @lv@. The declaration
-- of @var@ is dropped, and @var@ is replaced by @lv@ (via 'replaceLExpr') in
-- the statements before the copy (given reversed) and after it. The copy
-- itself is not part of either list and so disappears.
backwardRepl :: LeftValue -> String -> [Declaration] -> [Program] -> [Program] -> ([Declaration], [Program])
backwardRepl lv var ds xs ys = (remaining, replaced)
  where
    remaining = [d | d <- ds, not (declares var d)]
    replaced  = replaceLExpr (reverse xs ++ ys) (var, lv)

-- | The top-level statements of a program: the children of a 'Seq', or the
-- program itself as a one-element list.
toPrgList :: Program -> [Program]
toPrgList p = case p of
    Seq ps _ -> ps
    _        -> [p]

-- | Does this declaration declare the variable with the given name?
declares :: String -> Declaration -> Bool
declares n = (== n) . name . var