--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE EmptyDataDecls, TypeFamilies, FlexibleInstances #-}

module Feldspar.Compiler.Plugins.ForwardPropagation (
    ForwardPropagation(..)
    ) 
    where

import Feldspar.Compiler.PluginArchitecture
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.List as List
import Feldspar.Compiler.Plugins.PropagationUtils
import Feldspar.Compiler.Error
import Feldspar.Compiler.Options
import Feldspar.Compiler.Imperative.CodeGeneration (simpleType)

-- | Report an internal error of this plugin through 'handleError', tagged
-- with the location "PluginArch/ForwardPropagation" and 'InternalError'.
-- It is used in positions of arbitrary type, so presumably it does not return.
fwdPropError msg = handleError "PluginArch/ForwardPropagation" InternalError msg

-- ===========================================================================
-- == Copy propagation plugin (forward)
-- ===========================================================================

-- | Information recorded for a single known write of a variable:
--
--   * the assigned expression,
--   * the variables that expression depends on,
--   * a flag that 'checkFwdSequence' sets once a later program part has
--     overwritten one of those variables.
type FwdWriteRecord = (ExpressionData ForwardPropagationSemInf, [VariableData], Bool)

-- | Per-variable occurrence statistics used by forward propagation.
type VarStatFwd = VarStatistics FwdWriteRecord
-- | Occurrence record (writes and reads) of a single variable.
type OccurrencesFwd = Occurrences FwdWriteRecord

-- | Plugin tag for forward (copy) propagation; it is run through 'executePlugin'.
data ForwardPropagation = ForwardPropagation

instance Plugin ForwardPropagation where
    type ExternalInfo ForwardPropagation = DebugOption
    -- The plugin runs two passes:
    --   1. 'ForwardPropagationCollect' annotates the procedure with occurrence
    --      statistics, starting from the 'Occurrence_read' context.
    --   2. 'ForwardPropagationTransform' rewrites the procedure, starting from
    --      the statistics gathered at the root.
    -- The procedure is returned untouched when the debug option disables
    -- simplification or primitive instruction handling.
    executePlugin ForwardPropagation debugOption inputProc =
        if propagationDisabled
            then inputProc
            else fst $ executeTransformationPhase ForwardPropagationTransform rootStats annotatedProc
      where
        propagationDisabled = debugOption == NoSimplification
                           || debugOption == NoPrimitiveInstructionHandling
        collected     = executeTransformationPhase ForwardPropagationCollect Occurrence_read inputProc
        annotatedProc = fst collected
        rootStats     = fst (snd collected)

-- | Phase instance for the plugin tag itself. All associated types are unit
-- and no method is overridden. The real passes are 'ForwardPropagationCollect'
-- and 'ForwardPropagationTransform' (see 'executePlugin').
instance TransformationPhase ForwardPropagation where
    type From ForwardPropagation = ()
    type To ForwardPropagation = ()
    type Downwards ForwardPropagation = ()
    type Upwards ForwardPropagation = ()

-- ====================
--       Collect
-- ====================

-- | Type-level tag (no values) selecting the semantic information produced by
-- the collect pass and consumed by the transform pass.
data ForwardPropagationSemInf

-- | Semantic information attached by 'ForwardPropagationCollect':
--
--   * blocks and sequential loops carry variable statistics ('VarStatFwd');
--   * array element references carry the indexed variable, if known;
--   * variables carry their occurrence place;
--   * every other node carries unit.
instance SemanticInfo ForwardPropagationSemInf where
    type ProcedureInfo             ForwardPropagationSemInf = ()
    type BlockInfo                 ForwardPropagationSemInf = VarStatFwd
    type ProgramInfo               ForwardPropagationSemInf = ()
    type EmptyInfo                 ForwardPropagationSemInf = ()
    type PrimitiveInfo             ForwardPropagationSemInf = ()
    type SequenceInfo              ForwardPropagationSemInf = ()
    type BranchInfo                ForwardPropagationSemInf = ()
    type SequentialLoopInfo        ForwardPropagationSemInf = VarStatFwd
    type ParallelLoopInfo          ForwardPropagationSemInf = ()
    type FormalParameterInfo       ForwardPropagationSemInf = ()
    type LocalDeclarationInfo      ForwardPropagationSemInf = ()
    type ExpressionInfo            ForwardPropagationSemInf = ()
    type ConstantInfo              ForwardPropagationSemInf = ()
    type FunctionCallInfo          ForwardPropagationSemInf = ()
    type LeftValueInfo             ForwardPropagationSemInf = ()
    type ArrayElemReferenceInfo    ForwardPropagationSemInf = Maybe VariableData --name of the indexed variable
    type InstructionInfo           ForwardPropagationSemInf = ()
    type AssignmentInfo            ForwardPropagationSemInf = ()
    type ProcedureCallInfo         ForwardPropagationSemInf = ()
    type ActualParameterInfo       ForwardPropagationSemInf = ()
    type IntConstantInfo           ForwardPropagationSemInf = ()
    type FloatConstantInfo         ForwardPropagationSemInf = ()
    type BoolConstantInfo          ForwardPropagationSemInf = ()
    type ArrayConstantInfo         ForwardPropagationSemInf = ()
    type VariableInfo              ForwardPropagationSemInf = Occurrence_place

-- | Merge two upward results of the collect pass. The statistics are
-- combined, and the "referenced variable" component is always dropped.
-- The patterns are lazy, so neither argument is forced until the combined
-- statistics are demanded.
instance Combine (VarStatFwd, Maybe VariableData) where
    combine ~(leftStats, _) ~(rightStats, _) = (combine leftStats rightStats, Nothing)

-- | Tag for the first pass, which collects occurrence statistics.
data ForwardPropagationCollect = ForwardPropagationCollect

-- | First pass: annotate every variable with its occurrence place and
-- compute per-variable read/write statistics ('VarStatFwd') bottom-up.
--
-- Downwards info is the 'Occurrence_place' of the current position, derived
-- from the parent node by 'occurrenceDownwards'.
--
-- Upwards info pairs the statistics of the subtree with:
--
--   * the referenced variable, when the subtree is a variable or an array
--     element of one;
--   * 'Nothing' otherwise.
--
-- That variable is what ends up in 'ArrayElemReferenceInfo'.
instance TransformationPhase ForwardPropagationCollect where
    type From ForwardPropagationCollect = ()
    type To ForwardPropagationCollect = ForwardPropagationSemInf
    type Downwards ForwardPropagationCollect = Occurrence_place
    type Upwards ForwardPropagationCollect = (VarStatFwd, Maybe VariableData)
    -- The occurrence place of the children is computed from the original node.
    downwardsBranchProgramInProgram self d orig = occurrenceDownwards orig
    downwardsSequentialLoopProgramInProgram self d orig = occurrenceDownwards orig
    downwardsParallelLoopProgramInProgram self d orig = occurrenceDownwards orig
    downwardsFormalParameter self d orig = occurrenceDownwards orig
    downwardsLocalDeclaration self d orig = occurrenceDownwards orig
    downwardsAssignmentInstructionInInstruction self d orig = occurrenceDownwards orig
    downwardsActualParameter self d orig = occurrenceDownwards orig
    downwardsInputActualParameterInActualParameter self d orig = occurrenceDownwards orig
    downwardsExpression self d orig = occurrenceDownwards orig
    -- A block stores the statistics of its own declared variables. They are
    -- computed by checking the declarations followed by the body as one
    -- sequence.
    transformBlock self d origBlock u = Block {
        blockDeclarations = recursivelyTransformedBlockDeclarations u,
        blockInstructions = recursivelyTransformedBlockInstructions u,
        blockSemInf = selectFromVarStatistics (declaredVars origBlock) belowStatistics
    } where
        belowStatistics = checkFwdDeclaration (map fst $ upwardsInfoFromBlockDeclarations u) (fst $ upwardsInfoFromBlockInstructions u)
    -- Record the occurrence place of each variable as its semantic info.
    transformVariable self d origVar = origVar {
        variableSemInf = d
    }
    -- A single variable occurrence yields one of:
    --   * a read (declaration: nothing);
    --   * a write whose expression is unknown ('One Nothing').
    upwardsVariable self d origVar newVar = case d of
        Occurrence_declare  -> (Map.singleton (variableData origVar) $ Occurrences Zero Zero, Just $ variableData origVar)
        Occurrence_read -> (Map.singleton (variableData origVar) $ Occurrences Zero (One ()), Just $ variableData origVar)
        Occurrence_write ->  (Map.singleton (variableData origVar) $ Occurrences (One Nothing) Zero, Just $ variableData origVar)
        -- Deliberately report multiple writes and reads. Such a variable never
        -- qualifies for propagation (see 'toVarWrite'), so it is kept.
        Occurrence_notopt -> (Map.singleton (variableData origVar) $ Occurrences Multiple Multiple, Just $ variableData origVar)
    -- The statistics of a sequence are those of its parts, checked in order.
    upwardsSequenceProgramInProgram self d origSeq u transSeq = (checkFwdSequence $ map fst $ upwardsInfoFromSequenceProgramList u, Nothing)
    -- Statistics of the block's own declared variables do not escape it.
    upwardsBlock self d origBlock u newBlock = (deleteFromVarStatistics (declaredVars origBlock) belowStatistics, Nothing) where
        belowStatistics = foldl combine (fst $ upwardsInfoFromBlockInstructions u) $ map fst $ upwardsInfoFromBlockDeclarations u
    -- Occurrences inside a parallel loop are passed through
    -- 'multipleVarStatistics' (presumably marking them as Multiple).
    upwardsParallelLoopProgramInProgram self d origParLoop u transParLoop = (multipleVarStatistics $
        foldl combine (fst $ upwardsInfoFromParallelLoopConditionVariable u)
                    [fst $ upwardsInfoFromNumberOfIterations u, fst $ upwardsInfoFromParallelLoopCore u], Nothing)
    -- Assigning to a plain variable records a single write, together with the
    -- transformed right-hand side and the variables that side mentions.
    -- Writes to array elements only merge the statistics of both sides.
    upwardsAssignmentInstructionInInstruction self d origAssign u transAssig = case leftValueData $ assignmentLhs origAssign of
        VariableLeftValue vlv -> (Map.insert var occ $ fst $ upwardsInfoFromAssignmentRhs u, Nothing)
            where
                var = variableData vlv
                occ = Occurrences (One $ Just (assRs, Map.keys $ fst $ upwardsInfoFromAssignmentRhs u, False)) Zero
                assRs = case transAssig of
                    AssignmentInstruction newAssign -> expressionData $ assignmentRhs newAssign
                    _ -> fwdPropError $ "Internal error: ForwardPropagation/1!"
        ArrayElemReferenceLeftValue _ -> (combine (fst $ upwardsInfoFromAssignmentLhs u) (fst $ upwardsInfoFromAssignmentRhs u), Nothing)
    -- A local declaration with a non-array-constant initial value counts as a
    -- single write of that value. Otherwise only the declared variable's own
    -- statistics are reported.
    upwardsLocalDeclaration self d origDecl u newDecl = case localInitValue newDecl of
        Nothing -> defaultCase
        Just initValue -> case expressionData initValue of
            ConstantExpression (Constant (ArrayConstant _) ()) -> defaultCase
            initExp -> case upwardsInfoFromLocalInitValue u of
                Nothing -> defaultCase
                Just justUpFromLocalInitValue -> (Map.insert var (occ initExp $ fst justUpFromLocalInitValue) $ fst justUpFromLocalInitValue, Nothing)
        where
            var = variableData $ localVariable origDecl
            occ initExp initStats = Occurrences (One $ Just (initExp, Map.keys initStats, False)) Zero
            defaultCase = (fst $ upwardsInfoFromLocalVariable u, Nothing)
    -- A call to a procedure whose name starts with "copy", with parameters
    -- (input, size, output), counts as a single write of the input expression
    -- when the output is a plain variable. Every other call merges the
    -- statistics of its parameters.
    upwardsProcedureCallInstructionInInstruction self d origProcCall u transProcCall
        | List.isPrefixOf "copy" $ nameOfProcedureToCall origProcCall = case map actualParameterData actParams of -- TODO: eliminate string constant
            [InputActualParameter inArr, InputActualParameter _arrSize, OutputActualParameter outArr] ->
                case leftValueData outArr of
                    -- NOTE(review): only 'head ul' (presumably the input
                    -- parameter's statistics) is kept here, so reads inside
                    -- the size argument are not counted. Confirm this is intended.
                    VariableLeftValue vlv -> (Map.insert (var vlv) (occ inArr) $ fst $ head ul, Nothing)
                    ArrayElemReferenceLeftValue _ -> defaultTr
            _ -> defaultTr
        | otherwise = defaultTr
        where
            -- Use a real pattern here: a bare 'otherwise' in a case
            -- alternative binds a new variable that shadows Prelude's.
            defaultTr = case ul of
                [] -> defaultValue
                (firstUp : restUp) -> foldl combine firstUp restUp
            ul = upwardsInfoFromActualParametersOfProcedureToCall u
            actParams = case transProcCall of
                ProcedureCallInstruction pc -> actualParametersOfProcedureToCall pc
                _ -> fwdPropError $ "Internal error: ForwardPropagation/2!"
            var vlv = variableData vlv
            occ inArr = Occurrences (One $ Just (expressionData inArr, Map.keys $ fst $ head ul, False)) Zero
    -- Move the statistics stored by the loop's condition-calculation block onto
    -- the loop itself. The block is left with an empty map.
    transformSequentialLoopProgramInProgram self d origSeqLoop u = SequentialLoopProgram $ origSeqLoop {
        sequentialLoopCondition = recursivelyTransformedSequentialLoopCondition u,
        conditionCalculation = (recursivelyTransformedConditionCalculation u) {
                blockSemInf = Map.empty
            },
        sequentialLoopCore = recursivelyTransformedSequentialLoopCore u,
        sequentialLoopSemInf = blockSemInf $ recursivelyTransformedConditionCalculation u
    }

    -- Loop statistics are passed through 'multipleVarStatistics'. The
    -- condition variable itself is excluded from the condition's statistics.
    upwardsSequentialLoopProgramInProgram self d origSeqLoop u newSeqLoop = (multipleVarStatistics $
        combine (deleteFromVarStatistics [condVar] $ fst $ upwardsInfoFromSequentialLoopCondition u) $ fst $ upwardsInfoFromSequentialLoopCore u, Nothing)
        where
            -- The condition is expected to mention a variable. Fail with a
            -- descriptive internal error instead of a bare 'head' failure.
            condVar = case Map.keys $ fst $ upwardsInfoFromSequentialLoopCondition u of
                (v : _) -> v
                [] -> fwdPropError $ "Internal error: ForwardPropagation/5!"
    -- Remember which variable is being indexed.
    transformArrayElemReferenceLeftValueInLeftValue self d origArrRef u = ArrayElemReferenceLeftValue $ ArrayElemReference {
        arrayName = recursivelyTransformedArrayName u,
        arrayIndex = recursivelyTransformedArrayIndex u,
        arrayElemReferenceSemInf = snd $ upwardsInfoFromArrayName u
    }
    upwardsArrayElemReferenceLeftValueInLeftValue self d origArrayRef u transArrayRefe =
        (combine (fst $ upwardsInfoFromArrayName u) (fst $ upwardsInfoFromArrayIndex u), snd $ upwardsInfoFromArrayName u)
    -- Variable left values are handled exactly like plain variables.
    upwardsVariableLeftValueInLeftValue self d origVar transVar = upwardsVariable self d origVar $ transformVariable self d origVar
    transformVariableLeftValueInLeftValue self d origVar = VariableLeftValue $ transformVariable self d origVar

-- | Combine the statistics of consecutive program parts, given in execution
-- order, into the statistics of the whole sequence.
--
-- The parts are folded left to right, starting from an empty map. Before a
-- part @curr@ is merged in, every occurrence record of the prefix seen so far
-- is re-checked against @curr@ by @updatePreSeq@. An empty list yields
-- 'defaultValue'.
checkFwdSequence :: [VarStatFwd]  -> VarStatFwd
checkFwdSequence [] = defaultValue
checkFwdSequence xs = List.foldl checkInSeq Map.empty xs
    where
        -- Re-check the prefix against the next part, then merge that part in.
        checkInSeq :: VarStatFwd -> VarStatFwd -> VarStatFwd
        checkInSeq preSeq curr = combine curr $ Map.mapWithKey (updatePreSeq curr) preSeq
        -- Update one prefix variable whose single write has a known expression
        -- (other records are returned unchanged). Cases, in order:
        --   * The "dependency overwritten" flag is already set and @curr@
        --     reads the variable: the write becomes unknown ('One Nothing').
        --   * @curr@ writes one of the write's dependencies:
        --       - If @curr@ also reads the variable (unless it has a simple
        --         type and is not read 'Multiple' times), the write becomes
        --         unknown.
        --       - Otherwise the write is kept, flagged, and extended with
        --         @curr@'s own dependencies for the variable.
        --   * @curr@ writes the same expression to the variable again: the
        --     prefix write count drops to 'Zero'.
        --   * Anything else: unchanged.
        -- Read counts are always carried over unchanged.
        updatePreSeq :: VarStatFwd -> VariableData -> OccurrencesFwd -> OccurrencesFwd
        updatePreSeq curr preSeqVar preSeqOcc = case writeVar preSeqOcc of
            One (Just (preSeqExp,preSeqVars,preSeqVarsWritten))
                | preSeqVarsWritten && curr `hasRead` preSeqVar -> Occurrences (One Nothing) $ readVar preSeqOcc
                | any (hasWrite curr) preSeqVars -> case (curr `hasRead` preSeqVar)  && not ((simpleType $ variableDataType preSeqVar) && readVar preSeqOcc /= Multiple) of
                    True -> Occurrences (One Nothing) $ readVar preSeqOcc
                    False -> Occurrences (One (Just (preSeqExp,preSeqVars ++ (addDep curr preSeqVar),True))) $ readVar preSeqOcc
                | otherwise -> case curr `getWrite` preSeqVar of
                    Nothing -> preSeqOcc
                    Just (exp,vars,varsWritten)
                        | exp == preSeqExp -> Occurrences Zero $ readVar preSeqOcc
                        | otherwise -> preSeqOcc
            _ -> preSeqOcc
        -- Dependencies recorded by @curr@'s own write of the variable, if any.
        addDep curr preSeqVar = case curr `getWrite` preSeqVar of
            Nothing -> []
            Just (exp,vars,varsWritten) -> vars

-- | Statistics of a block: its declarations (in order) followed by its body,
-- checked as one sequence by 'checkFwdSequence'. When there are no
-- declarations, the body's statistics are returned unchanged.
checkFwdDeclaration :: [VarStatFwd] -> VarStatFwd -> VarStatFwd
checkFwdDeclaration declStat blockStat = case declStat of
    [] -> blockStat
    _  -> checkFwdSequence (declStat ++ [blockStat])

-- ====================
--  ForwardPropagation
-- ====================

-- | Propagatable variables, each paired with the expression to substitute for it.
type VarWrite t = [(VariableData,ExpressionData t)]

-- | Select the variables that can be forward-propagated, each paired with the
-- expression to substitute for it.
--
-- Only variables with exactly one write whose expression is known
-- (@One (Just ...)@) are candidates. A candidate is selected when either:
--
--   * it is not read 'Multiple' times and its expression is not an array
--     constant; or
--   * its expression is simple: a non-array constant, a variable, or an
--     array element whose index is itself simple.
toVarWrite :: VarStatFwd -> VarWrite ForwardPropagationSemInf
toVarWrite stats = Map.foldrWithKey collect [] stats
  where
    -- 'Map.foldrWithKey' replaces 'Map.foldWithKey'. The latter is deprecated
    -- since containers 0.5 and removed in 0.6; both fold right with the same
    -- argument order, so the result is identical.
    collect :: VariableData -> OccurrencesFwd -> VarWrite ForwardPropagationSemInf -> VarWrite ForwardPropagationSemInf
    collect name (Occurrences (One (Just (expr, _, _))) readCount) acc
        | readCount /= Multiple && notConstArray expr = (name, expr) : acc -- read at most once
        | simpleExpr expr = (name, expr) : acc -- read several times, but cheap
        | otherwise = acc
    collect _ _ acc = acc

    notConstArray (ConstantExpression (Constant c _)) = simplConst c
    notConstArray _ = True

    simpleExpr (ConstantExpression (Constant c _)) = simplConst c
    simpleExpr (LeftValueExpression lv) = case leftValueData lv of
        VariableLeftValue _ -> True
        ArrayElemReferenceLeftValue aer -> simpleExpr $ expressionData $ arrayIndex aer
    simpleExpr _ = False

    simplConst (ArrayConstant _) = False
    simplConst _ = True

-- | Tag for the second pass, which substitutes and removes propagated writes.
data ForwardPropagationTransform = ForwardPropagationTransform

-- | Second pass: substitute propagatable variables ('toVarWrite') at their
-- use sites and remove the writes that became redundant.
--
-- Downwards info is the statistics in scope. They are accumulated from the
-- 'blockSemInf' / 'sequentialLoopSemInf' stored by the collect pass.
--
-- Upwards info is the set of variables written ('Occurrence_write') in the
-- subtree, excluding variables local to an enclosed block.
instance TransformationPhase ForwardPropagationTransform where
    type From ForwardPropagationTransform = ForwardPropagationSemInf
    type To ForwardPropagationTransform = ()
    type Downwards ForwardPropagationTransform = VarStatFwd
    type Upwards ForwardPropagationTransform = Set.Set VariableData
    -- Bring the statistics recorded for this scope into view.
    downwardsBlock self d origBlock = combine d $ blockSemInf origBlock
    downwardsSequentialLoopProgramInProgram self d origSeqLoop = combine d $ sequentialLoopSemInf origSeqLoop
    -- A read of a propagatable variable is replaced by its defining
    -- expression, which is itself transformed recursively.
    transformLeftValueExpressionInExpression self d origLV u = case leftValueData origLV of
            VariableLeftValue origVar -> case List.lookup (variableData origVar) varwrite of
                    Nothing -> defaultTr
                    Just replacement -> expressionData $ fst $ walkExpression self d $ Expression replacement ()
            ArrayElemReferenceLeftValue _ -> defaultTr
        where
            varwrite = toVarWrite d
            defaultTr = LeftValueExpression $ LeftValue {
                leftValueData = recursivelyTransformedLeftValueData u,
                leftValueSemInf = ()
            }
    -- In left-value position, a variable is replaced only when its defining
    -- expression is itself a left value.
    transformVariableLeftValueInLeftValue self d origVar = case List.lookup var varwrite of
            Just (LeftValueExpression lv) -> leftValueData $ fst $ walkLeftValue self d lv
            _ -> defaultTr
        where
            var = variableData origVar
            varwrite = toVarWrite d
            defaultTr = VariableLeftValue $ origVar {
                variableSemInf = ()
            }
    -- Indexing a variable whose defining expression is itself an array
    -- element is rewritten as follows:
    --   * the array name is re-walked with statistics in which that
    --     variable's write is replaced via @swapArrayIndex@;
    --   * the index becomes the (transformed) index of the defining element.
    transformArrayElemReferenceLeftValueInLeftValue self d origArrayRef u = case List.lookup var varwrite of
            Just (LeftValueExpression lv) -> case leftValueData lv of
                VariableLeftValue _ -> defaultTr
                ArrayElemReferenceLeftValue aer -> ArrayElemReferenceLeftValue $ ArrayElemReference {
                    arrayName = fst $ walkLeftValue self (swapArrayIndex d var aer origArrayRef) $ arrayName origArrayRef,
                    arrayIndex = fst $ walkExpression self d $ arrayIndex aer,
                    arrayElemReferenceSemInf = ()
                }
            _ -> defaultTr
        where
            -- Overwrite @target@'s recorded write with
            -- @arrayName rep@ indexed by @arrayIndex orig@; reads are kept.
            swapArrayIndex :: VarStatFwd -> VariableData -> ArrayElemReference ForwardPropagationSemInf -> ArrayElemReference ForwardPropagationSemInf -> VarStatFwd
            swapArrayIndex stats target rep orig = Map.adjust (swapArrayIndex2 target rep orig) target stats
            swapArrayIndex2 target rep orig occ = occ {
                writeVar = One $ Just ( LeftValueExpression $ LeftValue {
                    leftValueData = ArrayElemReferenceLeftValue $ ArrayElemReference {
                        arrayName = arrayName rep,
                        arrayIndex = arrayIndex orig,
                        arrayElemReferenceSemInf = Just target
                    },
                    leftValueSemInf = ()
                },[],False)
            }
            -- The indexed variable recorded by the collect pass.
            var = getJust $ arrayElemReferenceSemInf origArrayRef
            getJust (Just a) = a
            getJust _ = fwdPropError $ "Internal error: ForwardPropagation/3!"
            varwrite = toVarWrite d
            defaultTr = ArrayElemReferenceLeftValue $ ArrayElemReference {
                arrayName = recursivelyTransformedArrayName u,
                arrayIndex = recursivelyTransformedArrayIndex u,
                arrayElemReferenceSemInf = convert $ arrayElemReferenceSemInf origArrayRef
            }
    -- Only written variables are reported upwards.
    upwardsVariable self d origVar newVar = case variableSemInf origVar of
        Occurrence_declare  -> Set.empty
        Occurrence_read -> Set.empty
        Occurrence_write -> Set.singleton (variableData origVar)
        Occurrence_notopt -> Set.empty
    -- Drop the block's own locals from the written set. This is not strictly
    -- needed; it only avoids trying to delete locals outside their block.
    upwardsBlock self d origBlock u transformedBlock = Set.difference (upwardsInfoFromBlockInstructions u) (Set.fromList $ declaredVars origBlock)
    -- Presumably removes the declarations of propagated variables (see
    -- 'delUnusedDecl').
    transformBlock self d origBlock u = delUnusedDecl (map fst $ toVarWrite $ combine d $ blockSemInf origBlock) origBlock (recursivelyTransformedBlockDeclarations u) (recursivelyTransformedBlockInstructions u)
    -- An assignment, or a call to a procedure whose name starts with "copy",
    -- is removed when every variable it writes is being propagated.
    transformPrimitiveProgramInProgram self d originalPrimitive u
            | canDelete && deletablePrimitive  = EmptyProgram $ Empty ()
            | otherwise = PrimitiveProgram $ Primitive {
                    primitiveInstruction = recursivelyTransformedPrimitiveInstruction u,
                    primitiveSemInf = ()
                }
        where
            -- NOTE(review): this is also True when the instruction reports no
            -- written variable at all (empty set). Confirm that cannot happen
            -- for instructions that must be kept.
            canDelete = Set.isSubsetOf (upwardsInfoFromPrimitiveInstruction u) (Set.fromList $ map fst $ toVarWrite d)
            deletablePrimitive = case instructionData $ primitiveInstruction originalPrimitive of
                ProcedureCallInstruction pc -> List.isPrefixOf "copy" $ nameOfProcedureToCall pc
                AssignmentInstruction _ -> True
    -- Needed because of the walk structure of the new plugin architecture.
    upwardsVariableLeftValueInLeftValue self d origVar transVar = upwardsVariable self d origVar $ transformVariable self d origVar
    -- A left value is transformed like a left-value expression. The result
    -- must still be a left value.
    transformLeftValue self d origLV u = case transformLeftValueExpressionInExpression self d origLV u of
        LeftValueExpression lv -> lv
        _  -> fwdPropError $ "Internal error: ForwardPropagation/4!"