--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE EmptyDataDecls, TypeFamilies, FlexibleInstances #-}

module Feldspar.Compiler.Plugins.BackwardPropagation (
    BackwardPropagation(..)
    )
    where

import Feldspar.Compiler.PluginArchitecture
import Feldspar.Compiler.Plugins.PropagationUtils
import qualified Data.Map as Map
import qualified Data.List as List
import qualified Data.Set as Set
import Data.Maybe
import Feldspar.Compiler.Options

-- ===========================================================================
-- == Copy propagation plugin (backward)
-- ===========================================================================

-- | Variable statistics used by the backward propagation: 'VarStatistics'
-- instantiated at the unit type.
type VarStatBck = VarStatistics ()

-- | Marker value of the backward copy propagation plugin; the 'Plugin' and
-- 'TransformationPhase' instances of this module are selected through it.
data BackwardPropagation = BackwardPropagation

-- | Used as a phase of its own, 'BackwardPropagation' goes from 'InitSemInf'
-- to the unit semantic info, passes no information downwards or upwards and
-- overrides none of the class methods. 'executePlugin' runs this phase when
-- the 'NoSimplification' option is given.
instance TransformationPhase BackwardPropagation where
    type From BackwardPropagation = InitSemInf
    type To BackwardPropagation = ()
    type Downwards BackwardPropagation = ()
    type Upwards BackwardPropagation = ()

-- | Running the plugin. With the 'NoSimplification' debug option only the
-- 'BackwardPropagation' phase itself is executed on the procedure. Otherwise
-- the procedure first goes through 'BackwardPropagationCollect' (started with
-- @(Occurrence_read, False)@) and the result through 'PropagationTransform'
-- (started with an empty replacement list). Of every phase result only the
-- first component is kept.
instance Plugin BackwardPropagation where
    type ExternalInfo BackwardPropagation = DebugOption
    executePlugin BackwardPropagation externalInfo procedure =
        if externalInfo == NoSimplification
            then withoutPropagation
            else fst (executeTransformationPhase PropagationTransform [] collected)
      where
        -- the procedure after the plain 'BackwardPropagation' phase
        withoutPropagation = fst (executeTransformationPhase BackwardPropagation () procedure)
        -- the procedure annotated by the collect phase
        collected = fst (executeTransformationPhase BackwardPropagationCollect (Occurrence_read, False) procedure)

-- ====================
--       Collect
-- ====================

-- | The default list of (variable, replacement) pairs is the empty list.
instance Default [(VariableData, LeftValueData ())] where
    defaultValue = []

-- The triples are stored as (var, out, flag) for a copy "out = var": 'getNames'
-- puts the copied variable first and the target left value second. The flag
-- starts as False and is replaced in 'checkInSequence' by whether out was used
-- in the sequence before "out = var" (while var was still unused).
-- The default is the empty list.
instance Default [(VariableData, LeftValueData (),Bool)] where
    defaultValue = []

-- | Two upward results are merged component-wise: the statistics with their
-- own 'combine', the copy candidates by appending the right list to the left.
instance Combine (VarStatBck, [(VariableData, LeftValueData (),Bool)]) where
    combine (statsL, copiesL) (statsR, copiesR) = (statsL `combine` statsR, copiesL ++ copiesR)

-- | By default a primitive carries no copy information.
instance Default (Maybe (VariableData, LeftValueData (),Bool)) where
    defaultValue = Nothing

-- | Empty type that indexes the semantic information produced by
-- 'BackwardPropagationCollect' and consumed by 'PropagationTransform'.
data BackwardPropagationSemInf

-- | Semantic information attached by the collect phase. Only blocks and
-- primitives carry data; every other node kind is annotated with unit.
instance SemanticInfo BackwardPropagationSemInf where
    type ProcedureInfo             BackwardPropagationSemInf = ()
    -- Replacements (variable, left value to use instead) valid inside the block.
    type BlockInfo                 BackwardPropagationSemInf = [(VariableData, LeftValueData ())]
    type ProgramInfo               BackwardPropagationSemInf = ()
    type EmptyInfo                 BackwardPropagationSemInf = ()
    -- If the primitive is a copy assignment, the data of that assignment
    -- (filled in by 'getNames'); it is kept because it is needed when
    -- primitives are deleted in the second phase.
    type PrimitiveInfo             BackwardPropagationSemInf = Maybe (VariableData, LeftValueData (), Bool)
    type SequenceInfo              BackwardPropagationSemInf = ()
    type BranchInfo                BackwardPropagationSemInf = ()
    type SequentialLoopInfo        BackwardPropagationSemInf = ()
    type ParallelLoopInfo          BackwardPropagationSemInf = ()
    type FormalParameterInfo       BackwardPropagationSemInf = ()
    type LocalDeclarationInfo      BackwardPropagationSemInf = ()
    type ExpressionInfo            BackwardPropagationSemInf = ()
    type ConstantInfo              BackwardPropagationSemInf = ()
    type FunctionCallInfo          BackwardPropagationSemInf = ()
    type LeftValueInfo             BackwardPropagationSemInf = ()
    type ArrayElemReferenceInfo    BackwardPropagationSemInf = ()
    type InstructionInfo           BackwardPropagationSemInf = ()
    type AssignmentInfo            BackwardPropagationSemInf = ()
    type ProcedureCallInfo         BackwardPropagationSemInf = ()
    type ActualParameterInfo       BackwardPropagationSemInf = ()
    type IntConstantInfo           BackwardPropagationSemInf = ()
    type FloatConstantInfo         BackwardPropagationSemInf = ()
    type BoolConstantInfo          BackwardPropagationSemInf = ()
    type ArrayConstantInfo         BackwardPropagationSemInf = ()
    type VariableInfo              BackwardPropagationSemInf = ()

-- | Marker value of the first (collecting) phase of the backward propagation.
data BackwardPropagationCollect = BackwardPropagationCollect

-- | The collect phase annotates the program with 'BackwardPropagationSemInf'.
--
-- Downwards it passes the occurrence place of the node being visited together
-- with a flag that is True only below a local declaration that has an initial
-- value. Upwards it passes the variable statistics and the list of copy
-- candidates found so far.
instance TransformationPhase BackwardPropagationCollect where
    type From BackwardPropagationCollect = InitSemInf
    type To BackwardPropagationCollect = BackwardPropagationSemInf
    type Downwards BackwardPropagationCollect = (Occurrence_place, Bool)
    type Upwards BackwardPropagationCollect = (VarStatBck, [(VariableData, LeftValueData (),Bool)])

    -- Downward information: the place comes from 'occurrenceDownwards', the
    -- flag is only ever set by a local declaration.
    downwardsBranchProgramInProgram _ _ node = (occurrenceDownwards node, False)
    downwardsSequentialLoopProgramInProgram _ _ node = (occurrenceDownwards node, False)
    downwardsParallelLoopProgramInProgram _ _ node = (occurrenceDownwards node, False)
    downwardsFormalParameter _ _ node = (occurrenceDownwards node, False)
    downwardsLocalDeclaration _ _ decl = (occurrenceDownwards decl, isJust (localInitValue decl))
    downwardsAssignmentInstructionInInstruction _ _ node = (occurrenceDownwards node, False)
    downwardsActualParameter _ _ node = (occurrenceDownwards node, False)
    downwardsExpression _ _ node = (occurrenceDownwards node, False)

    -- A variable occurrence yields one statistics entry and no copy candidate.
    -- A declaration counts as a write only when it has an initial value.
    upwardsVariable _ (place, initialised) origVar _ = case place of
        Occurrence_declare
            | initialised -> entry (Occurrences (One Nothing) Zero)
            | otherwise -> entry (Occurrences Zero Zero)
        Occurrence_read -> entry (Occurrences Zero (One ()))
        Occurrence_write -> entry (Occurrences (One Nothing) Zero)
        Occurrence_notopt -> entry (Occurrences Multiple Multiple)
      where
        entry occurrences = (Map.singleton (variableData origVar) occurrences, [])

    -- A primitive annotated as a copy replaces the candidate list coming from
    -- its instruction by that single candidate; the statistics are kept.
    upwardsPrimitiveProgramInProgram _ _ _ u newPrimitive = case newPrimitive of
        PrimitiveProgram newPr ->
            maybe fromInstruction (\candidate -> (fst fromInstruction, [candidate])) (primitiveSemInf newPr)
        _ -> fromInstruction
      where
        fromInstruction = upwardsInfoFromPrimitiveInstruction u

    -- Variables replaced inside the block are removed from the statistics and
    -- no candidate leaves the block.
    upwardsBlock _ _ _ u newBlock = (deleteFromVarStatistics replacedVars instructionStats, [])
      where
        replacedVars = map fst (blockSemInf newBlock)
        instructionStats = fst (upwardsInfoFromBlockInstructions u)

    upwardsSequenceProgramInProgram _ _ _ u _ = checkInSequence (upwardsInfoFromSequenceProgramList u)

    -- The block is annotated with the candidates accepted by 'checkInDeclatation'.
    transformBlock _ _ origBlock u = Block
        { blockDeclarations = recursivelyTransformedBlockDeclarations u
        , blockInstructions = recursivelyTransformedBlockInstructions u
        , blockSemInf = checkInDeclatation origBlock (upwardsInfoFromBlockInstructions u)
        }

    -- The primitive is annotated with its copy data, if 'getNames' finds any.
    transformPrimitiveProgramInProgram _ _ origPrimitive u = PrimitiveProgram
        ( Primitive
            { primitiveInstruction = recursivelyTransformedPrimitiveInstruction u
            , primitiveSemInf = getNames origPrimitive
            }
        )

-- | Recognises a copy primitive. The result is @Just (var, out, False)@ when
-- the primitive is a call of the procedure named "copy" with exactly one input
-- parameter followed by one output parameter, and the input is a plain
-- variable (not an array element); @var@ is the data of that variable and
-- @out@ the output left value with its semantic info removed by
-- 'deleteSemInf'. Assignments and all other calls give 'Nothing'.
getNames :: (SemanticInfo t) => Primitive t -> Maybe (VariableData, LeftValueData (),Bool)
getNames pr = case instructionData (primitiveInstruction pr) of
    AssignmentInstruction _ -> Nothing
    ProcedureCallInstruction call
        | nameOfProcedureToCall call == "copy" ->
            copyOperands (map actualParameterData (actualParametersOfProcedureToCall call))
        | otherwise -> Nothing
  where
    -- exactly [input, output], and the input has to be a plain variable
    copyOperands [InputActualParameter src, OutputActualParameter dst] =
        case plainVariable (expressionData src) of
            Just var -> Just (var, deleteSemInf (leftValueData dst), False)
            Nothing -> Nothing
    copyOperands _ = Nothing
    -- the variable behind an expression that is nothing but a variable
    plainVariable (LeftValueExpression lv) = case leftValueData lv of
        VariableLeftValue v -> Just (variableData v)
        _ -> Nothing
    plainVariable _ = Nothing

-- | The data of the variable a left value is built on: the variable itself,
-- or, for an array element reference, the base variable of the referenced array.
getLvName :: (SemanticInfo t) => LeftValueData t -> VariableData
getLvName lv = case lv of
    VariableLeftValue v -> variableData v
    ArrayElemReferenceLeftValue aer -> getLvName (leftValueData (arrayName aer))

-- | Combines the upward results of the elements of a sequence.
--
-- The statistics of all elements are combined. The copy candidates of all
-- elements are gathered (candidates of later elements come first) and each of
-- them is checked against the whole sequence with @checkSeq@; candidates that
-- fail the check are dropped. An empty sequence gives the default value.
checkInSequence :: [(VarStatBck, [(VariableData, LeftValueData (), Bool)])]  -> (VarStatBck, [(VariableData, LeftValueData (), Bool)])
checkInSequence [] = defaultValue
checkInSequence xs = (varstat $ map fst xs, mapMaybe (checkSeq xs False False False) $ foldl (\ls (vs,s) -> s++ls) [] xs)
    where
        -- Statistics of the whole sequence: the element statistics folded with 'combine'.
        varstat :: [VarStatBck] -> VarStatBck
        varstat = foldl combine defaultValue
        -- Walks the elements from first to last for one candidate sp = (var, out, flag).
        -- State carried along:
        --   usedVar: an earlier element read or wrote var
        --   usedOut: an earlier element used out while var was still unused
        --   after:   the element holding the candidate has already been passed
        -- The guards are tried in order, so each one implies the negation of the
        -- ones before it (see the commented-out conditions):
        --   * after the copy, var must not be used any more, otherwise the candidate is rejected;
        --   * the element holding the candidate switches to "after", provided the
        --     incoming flag is False or var has not been used yet;
        --   * once var has been used, out must not be used, otherwise rejected;
        --   * the first read of var is accepted only if the same element does not use out;
        --   * the first write of var is rejected if the same element also writes out;
        --   * a use of out before any use of var is remembered in usedOut.
        -- At the end of the sequence the candidate is accepted and its flag is
        -- replaced by usedOut.
        -- NOTE(review): notUse / hasRead / hasWrite / hasUse are defined outside
        -- this file; their meaning is assumed from their names -- confirm.
        checkSeq :: [(VarStatBck, [(VariableData, LeftValueData (), Bool)])] -> Bool{-usedVar-} -> Bool{-usedOut-} -> Bool{-after-} -> (VariableData {-var-}, LeftValueData () {-out-}, Bool) -> Maybe (VariableData, LeftValueData (), Bool)
        checkSeq [] _ usedOut _  (var,outD,outUsedLower) = Just (var,outD,usedOut)
        checkSeq ((vs,s):ys) usedVar usedOut after sp@(var,outD,outUsedLower)
            | after && (vs `notUse` var)  = checkSeq ys usedVar usedOut after sp
            | after {- && (vs `hasUse` var) -} = Nothing
            | {-(not after) && -} (sp `List.elem` s) && ((not outUsedLower) || (not usedVar)) = checkSeq ys usedVar usedOut True sp
            | {-(not after) && -} usedVar && (vs `notUse` out) = checkSeq ys usedVar usedOut after sp
            | {-(not after) && -} usedVar {- && (vs `hasUse` out)-} = Nothing
            | {-(not after) && (not usedVar) && -} (vs `hasRead` var) && (vs `notUse` out) = checkSeq ys True usedOut after sp
            | {-(not after) && (not usedVar) && -} (vs `hasRead` var) {- && (vs `hasUse` out) -} = Nothing
            | {-(not after) && (not usedVar) && -} (vs `hasWrite` var) && (vs `hasWrite` out) = Nothing
            | {-(not after) && (not usedVar) && -} (vs `hasWrite` var) {- && (vs `notWrite` out)-} = checkSeq ys True usedOut after sp
            | {-(not after) && (not usedVar) && (vs `notUse` var) && -} (vs `hasUse` out) = checkSeq ys usedVar True after sp
            | {-(not after) && (not usedVar) && (vs `notUse` var) && (vs `notUse` out)-} otherwise = checkSeq ys usedVar usedOut after sp
            where
                -- base variable of the target left value
                out = getLvName outD
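{-
checkSeq checks that the sequence has the following format
(elements from first to last):
______________
|   use out   |    elements that may use out, var is not used yet
|  ___________|
|__|=         |
|   use var   |    elements that use var and do not use out
|_____________|
out = var          the copy itself
______________
| not use var |    elements after the copy, which must not use var
|_____________|
-}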

-- | Selects, from the copy candidates coming up from the instructions of a
-- block, the replacements to be done in that block. A candidate (var, out, flag)
-- is kept as (var, out) only if var is declared by the block itself and,
-- when that declaration has an initial value, the flag is False.
checkInDeclatation :: Block InitSemInf -> (VarStatBck, [(VariableData, LeftValueData (), Bool)]) -> [(VariableData, LeftValueData ())]
checkInDeclatation origBlock u = mapMaybe keepIfLocal (snd u)
  where
    localDecls = blockDeclarations origBlock
    keepIfLocal :: (VariableData, LeftValueData (), Bool) -> Maybe (VariableData, LeftValueData ())
    keepIfLocal (var, outD, outUsedLower) =
        case List.find ((var ==) . declaredVar) localDecls of
            Just ld
                -- an initialised declaration is only acceptable with a False flag
                | isJust (localInitValue ld) && outUsedLower -> Nothing
                | otherwise -> Just (var, outD)
            Nothing -> Nothing
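{-
checkDecl checks whether var gets an initial value in its declaration. The
initialisation is a write to var, so if out was already used earlier in the
sequence (the flag is True) we can't use out in place of var for "out=var",
and the candidate is dropped.
-}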

-- ====================
--  BackwardPropagation
-- ====================

-- | Marker value of the second (rewriting) phase of the backward propagation.
data PropagationTransform = PropagationTransform

-- | The transform phase performs the replacements collected in the block
-- annotations and strips the semantic info again.
--
-- Downwards it passes the list of (variable, replacement) pairs that are
-- active; nothing is passed upwards.
instance TransformationPhase PropagationTransform where
    type From PropagationTransform = BackwardPropagationSemInf
    type To PropagationTransform = ()
    type Downwards PropagationTransform = [(VariableData, LeftValueData ())]
    type Upwards PropagationTransform = ()

    -- The replacements inherited from outside are chained onto the ones of the block.
    downwardsBlock _ inherited origBlock = unChain (foldl addChain (blockSemInf origBlock) inherited)

    -- No replacement is passed into a local declaration.
    downwardsLocalDeclaration _ _ _ = []

    -- The block is rebuilt by 'delUnusedDecl', which receives the variables
    -- replaced in this block.
    transformBlock self d orig u = delUnusedDecl replacedVars orig newDeclarations newInstructions
      where
        replacedVars = map fst (downwardsBlock self d orig)
        newDeclarations = recursivelyTransformedBlockDeclarations u
        newInstructions = recursivelyTransformedBlockInstructions u

    -- A primitive that, after the replacements, copies or assigns a location
    -- to itself becomes an empty program.
    transformPrimitiveProgramInProgram _ _ _ u
        | isIdentity (instructionData newInstr) = EmptyProgram (Empty ())
        | otherwise = PrimitiveProgram (Primitive { primitiveInstruction = newInstr, primitiveSemInf = () })
        where
            newInstr = recursivelyTransformedPrimitiveInstruction u
            -- an assignment, or a two-parameter "copy" call, whose source is its target
            isIdentity (AssignmentInstruction a) = sameLocation (assignmentRhs a) (assignmentLhs a)
            isIdentity (ProcedureCallInstruction p)
                | nameOfProcedureToCall p /= "copy" = False
                | otherwise = case map actualParameterData (actualParametersOfProcedureToCall p) of
                    [InputActualParameter i, OutputActualParameter o] -> sameLocation i o
                    _ -> False
            -- the expression is the left value itself, either with its array
            -- indexes in reversed order or exactly as it is
            sameLocation e l = expressionData e `List.elem` [LeftValueExpression (reversedIndexes l), LeftValueExpression l]
            reversedIndexes :: LeftValue () -> LeftValue ()
            reversedIndexes l = case indexesOf l of
                (base, indexes) -> rebuild (base, reverse indexes)
            -- base variable and indexes of a left value, outermost index first
            indexesOf :: LeftValue () -> (Variable (), [Expression ()])
            indexesOf lv = indexesOfData (leftValueData lv)
            indexesOfData (VariableLeftValue v) = (v, [])
            indexesOfData (ArrayElemReferenceLeftValue a) =
                let (base, inner) = indexesOf (arrayName a)
                in (base, arrayIndex a : inner)
            -- inverse of indexesOf
            rebuild :: (Variable (), [Expression ()]) -> LeftValue ()
            rebuild x = LeftValue (rebuildData x) ()
            rebuildData (v, []) = VariableLeftValue v
            rebuildData (v, i : is) = ArrayElemReferenceLeftValue (ArrayElemReference (rebuild (v, is)) i ())

    -- A variable with a replacement becomes that replacement, any other
    -- variable only loses its semantic info.
    transformVariableLeftValueInLeftValue _ d origVar =
        case List.find ((== variableData origVar) . fst) d of
            Just (_, out) -> out
            Nothing -> VariableLeftValue (origVar { variableSemInf = () })
            
-- | Rebuilds the replacement list by adding its pairs one by one, with
-- 'addChain', to an empty list. If the rebuilt list equals the input, the
-- input itself is returned.
--
-- NOTE(review): both branches yield a value equal to the rebuilt list, so the
-- list is re-chained exactly once; if a fixed-point iteration was intended
-- here, this is not one -- confirm.
unChain :: [(VariableData, LeftValueData ())] -> [(VariableData, LeftValueData ())]
unChain s
    | s == rechained = s
    | otherwise = rechained
  where
    rechained = foldl addChain [] s

-- | Adds a (variable, replacement) pair to a replacement list.
--
-- The list is searched from the front for the first pair that chains with the
-- new one: either its replacement is built on the new pair's variable, or the
-- new pair's replacement is built on its variable. In the first case the old
-- pair's replacement is redirected and the new pair is inserted right after
-- it; in the second case the redirected new pair is inserted in front of the
-- old one. Without any chaining pair the new pair is appended at the end.
addChain :: [(VariableData, LeftValueData ())] -> (VariableData, LeftValueData ()) -> [(VariableData, LeftValueData ())]
addChain [] pair = [pair]
addChain (old@(oldVar, oldOut) : rest) new@(newVar, newOut)
    | getLvName oldOut == newVar = (oldVar, redirect oldOut newOut) : new : rest
    | getLvName newOut == oldVar = (newVar, redirect newOut oldOut) : old : rest
    | otherwise = old : addChain rest new
    where
        -- The result is built on the base variable of the second argument,
        -- indexed first with the indexes of the first argument and then,
        -- outermost, with the indexes of the second one. The order of the
        -- clauses matters: an indexed second argument is handled first.
        redirect :: LeftValueData () {-toChange-} -> LeftValueData () {-newName-} -> LeftValueData ()
        redirect toChange (ArrayElemReferenceLeftValue aer) =
            ArrayElemReferenceLeftValue (aer { arrayName = LeftValue (redirect toChange (leftValueData (arrayName aer))) () })
        redirect (ArrayElemReferenceLeftValue aer) newName@(VariableLeftValue _) =
            ArrayElemReferenceLeftValue (aer { arrayName = LeftValue (redirect (leftValueData (arrayName aer)) newName) () })
        redirect (VariableLeftValue _) newName@(VariableLeftValue _) = newName

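{-
NOTE(review): the pairs in the examples below appear to be written as
(out, var), i.e. reversed with respect to the (var, out) tuples used by the
code, and the order of the resulting list may differ from what addChain
actually returns -- verify before relying on them.

addChain [ (a,   b) ] (b,   c)     =    [ (a,   b), (a,       c) ]
addChain [ (a,   b) ] (b[i],c)     =    [ (a,   b), (a[i],    c) ]
addChain [ (a[m],b) ] (b[i],c)     =    [ (a[m],b), (a[m][i], c) ]
addChain [ (b,   c) ] (a,   b)     =    [ (a,   b), (a,       c) ]
addChain [ (b,   c) ] (a[i],b)     =    [ (a,   b), (a[i],    c) ]
addChain [ (b[i],c) ] (a[m],b)     =    [ (a[m],b), (a[m][i], c) ]

But arrayof(arrayof(lv,index1)index2) = lv[index2][index1],
so changeInnerArrayName first goes down through the indexes of newName and
puts these outermost, then goes down through the indexes of toChange, and
when no indexes are left it changes the name.

-}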