--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE FlexibleInstances, TypeFamilies #-}

module Feldspar.Compiler.Plugins.Unroll where

import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Options
import Prelude
import Feldspar.Compiler.Imperative.Semantics
import Feldspar.Compiler.PluginArchitecture


-- | Runs the unroll transformation according to the requested 'UnrollStrategy'.
--
-- 'NoUnroll' hands the program back untouched.  @'Unroll' n@ runs two passes:
-- 'Unroll_1' (which replicates eligible parallel-loop bodies via 'trParLoop1'
-- and tags each copy), followed by 'Unroll_2' (which renames variables and
-- offsets the loop counter inside each tagged copy, starting with no tag).
instance Plugin UnrollPlugin where
    type ExternalInfo UnrollPlugin = UnrollStrategy
    executePlugin UnrollPlugin NoUnroll prg = prg
    executePlugin UnrollPlugin (Unroll count) prg =
        let replicated = fst (executeTransformationPhase Unroll_1 count prg)
        in  fst (executeTransformationPhase Unroll_2 Nothing replicated)
    
-- | Tag type identifying the unroll plugin.
data UnrollPlugin = UnrollPlugin

-- | Trivial phase instance for the plugin tag: every associated type is @()@.
-- NOTE(review): 'executePlugin' never runs 'UnrollPlugin' as a phase itself;
-- presumably this instance is required by the plugin architecture - confirm.
instance TransformationPhase UnrollPlugin where
    type Upwards   UnrollPlugin = ()
    type Downwards UnrollPlugin = ()
    type To        UnrollPlugin = ()
    type From      UnrollPlugin = ()

-- | First unroll pass: replicates the bodies of eligible parallel loops
-- (see 'trParLoop1') and moves the program into the 'UnrollSemInf'
-- representation.
data Unroll_1 = Unroll_1

instance TransformationPhase Unroll_1 where
    type From      Unroll_1 = ()
    type To        Unroll_1 = UnrollSemInf
    -- Passed down: the requested unroll factor.
    type Downwards Unroll_1 = Int
    -- Passed up: whether a parallel loop occurs below (combined with '||').
    type Upwards   Unroll_1 = Bool
    -- Every parallel loop reports itself, so enclosing loops know they are
    -- not innermost.
    upwardsParallelLoopProgramInProgram _ _ _ _ _ = True
    transformParallelLoopProgramInProgram Unroll_1 factor loop parts = trParLoop1 factor loop parts

-- | Second unroll pass: inside every body copy tagged by 'Unroll_1', renames
-- locally declared variables ('trVariable') and offsets uses of the loop
-- counter ('trLVIE'), then drops the semantic info again.
data Unroll_2 = Unroll_2

instance TransformationPhase Unroll_2 where
    type From      Unroll_2 = UnrollSemInf
    type To        Unroll_2 = ()
    -- Passed down: the tag of the body copy we are currently inside, if any.
    type Downwards Unroll_2 = Maybe SemInfPrg
    type Upwards   Unroll_2 = ()
    -- A program carrying its own tag replaces the inherited one; untagged
    -- programs keep whatever came from above.
    downwardsProgram Unroll_2 inherited prg = case programSemInf prg of
        Nothing -> inherited
        own     -> own
    transformVariable Unroll_2 tag var = trVariable tag var
    transformVariableLeftValueInLeftValue Unroll_2 tag var = VariableLeftValue (trVariable tag var)
    transformLeftValueExpressionInExpression Unroll_2 tag lv parts = trLVIE tag lv parts

-- | Semantic information used between 'Unroll_1' and 'Unroll_2'.  Programs,
-- empty programs, primitives and sequences may carry a 'SemInfPrg' tag; all
-- other constructs carry nothing.
data UnrollSemInf = UnrollSemInf

instance SemanticInfo UnrollSemInf where
    -- Constructs that can be tagged with the body copy they belong to.
    type ProgramInfo            UnrollSemInf = Maybe SemInfPrg
    type EmptyInfo              UnrollSemInf = Maybe SemInfPrg
    type PrimitiveInfo          UnrollSemInf = Maybe SemInfPrg
    type SequenceInfo           UnrollSemInf = Maybe SemInfPrg
    -- Everything else carries no extra information.
    type ProcedureInfo          UnrollSemInf = ()
    type BlockInfo              UnrollSemInf = ()
    type BranchInfo             UnrollSemInf = ()
    type SequentialLoopInfo     UnrollSemInf = ()
    type ParallelLoopInfo       UnrollSemInf = ()
    type FormalParameterInfo    UnrollSemInf = ()
    type LocalDeclarationInfo   UnrollSemInf = ()
    type ExpressionInfo         UnrollSemInf = ()
    type ConstantInfo           UnrollSemInf = ()
    type FunctionCallInfo       UnrollSemInf = ()
    type LeftValueInfo          UnrollSemInf = ()
    type ArrayElemReferenceInfo UnrollSemInf = ()
    type InstructionInfo        UnrollSemInf = ()
    type AssignmentInfo         UnrollSemInf = ()
    type ProcedureCallInfo      UnrollSemInf = ()
    type ActualParameterInfo    UnrollSemInf = ()
    type IntConstantInfo        UnrollSemInf = ()
    type FloatConstantInfo      UnrollSemInf = ()
    type BoolConstantInfo       UnrollSemInf = ()
    type ArrayConstantInfo      UnrollSemInf = ()
    type VariableInfo           UnrollSemInf = ()

-- | Upward 'Bool' information (used by 'Unroll_1') is merged with logical OR,
-- so a parallel loop anywhere in a subtree is reported to the enclosing loop.
-- NOTE(review): orphan instance - neither 'Combine' nor 'Bool' is defined in
-- this module.
instance Combine Bool where
    combine x y = x || y

-- | Tag that 'trParLoop1' attaches to each copy of an unrolled loop body;
-- 'trVariable' and 'trLVIE' read it during the 'Unroll_2' pass.
data SemInfPrg = SemInfPrg
    { position :: Int       -- ^ index of this copy, counting from 0
    , varNames :: [String]  -- ^ names of the variables declared locally in the loop body
    , loopVar  :: String    -- ^ name of the loop counter variable
    }
    deriving (Eq, Show)

-- | Constructs start out untagged.
instance Default (Maybe SemInfPrg) where
    defaultValue = Nothing

-- | 'Unroll_2' rewrite of a left-value used as an expression.
--
-- Inside copy @i@ of an unrolled loop body (i.e. when the inherited tag is
-- @Just info@), a read of the loop counter @c@ becomes the expression
-- @c + i@.  Any other left-value expression, or any expression outside a
-- tagged copy, is rebuilt unchanged from its already-transformed parts.
--
-- NOTE(review): the addition is typed as signed 32-bit regardless of the
-- counter variable's declared type - confirm this matches the loop counter.
trLVIE :: Downwards Unroll_2 -> LeftValue UnrollSemInf -> InfoFromLeftValueParts Unroll_2 -> ExpressionData ()
trLVIE d (LeftValue leftValue _) u = case d of
    Just info -> result info
    Nothing   -> orig
    where
        -- Name of the referenced variable, if the left-value is a plain variable.
        name = case leftValue of
            VariableLeftValue var -> Just $ getVarName var
            _                     -> Nothing
        -- A failing guard falls through to the catch-all alternative.
        result info = case name of
            Just n
                | n == loopVar info -> FunctionCallExpression $ FunctionCall InfixOp (Numeric ImpSigned S32) "+" [loopVarPar, plusPar] ()
            _ -> orig
            where
                loopVarPar = Expression orig ()
                plusPar    = Expression (ConstantExpression $ Constant (IntConstant $ IntConstantType (position info) ()) ()) ()
        -- The left-value rebuilt from its recursively transformed data.
        orig = LeftValueExpression $ LeftValue (recursivelyTransformedLeftValueData u) ()
    
-- | 'Unroll_2' rewrite of a variable.
--
-- Inside copy @i@ of an unrolled loop body, a variable whose name is one of
-- the body's local declarations gets the suffix @_u<i>@, matching the
-- declarations renamed by 'trParLoop1'.  All other variables keep their
-- name.  In every case the variable's semantic info is reset to @()@.
trVariable d v = case d of
    Just info
        | getVarName v `elem` varNames info ->
            v { variableName = variableName v ++ "_u" ++ show (position info), variableSemInf = () }
    _ -> v { variableSemInf = () }

-- | 'Unroll_1' rewrite of a parallel loop: unroll it by the factor @d@.
--
-- The loop is only moved into the 'UnrollSemInf' representation, and not
-- unrolled, when any of the following holds:
--
--   * @d@ is not positive.  Unrolling by 0 or a negative count would give an
--     empty body and a non-advancing step, and @mod n 0@ would crash.
--   * the loop body itself contains a parallel loop (only innermost loops are
--     unrolled), or
--   * the iteration count is a literal that is not a multiple of @d@.
--
-- Otherwise the body is replicated @d@ times and the loop step becomes @d@.
-- The local declarations of copy @i@ get the suffix @_u<i>@, and copy @i@ is
-- tagged with a 'SemInfPrg' so that 'Unroll_2' can rename variable uses and
-- offset the loop counter by @i@.
trParLoop1 :: Downwards Unroll_1 -> ParallelLoop () -> InfoFromParallelLoopParts Unroll_1 -> ProgramConstruction UnrollSemInf
trParLoop1 d pl u
    | shouldUnroll = ParallelLoopProgram newParLoop
    | otherwise    = ParallelLoopProgram trPl
    where
        unrollNum = d

        -- True when a parallel loop was reported from inside the body
        -- (parallel loops report True upwards and results are OR-ed);
        -- presumably False otherwise via the Bool default.
        bodyHasParLoop = upwardsInfoFromParallelLoopCore u

        -- NOTE(review): when the iteration count is not a literal, the loop is
        -- unrolled without a divisibility check and no remainder loop is
        -- emitted - confirm the generated code tolerates this.
        countFits = case iterNumFromExpr iterExpr of
            Just n  -> n `mod` unrollNum == 0
            Nothing -> True

        -- The positivity check comes first, so 'mod' never sees a zero divisor.
        shouldUnroll = unrollNum > 0 && not bodyHasParLoop && countFits

        -- Pieces of the already-transformed loop.
        iterExpr     = recursivelyTransformedNumberOfIterations u
        counterVar   = recursivelyTransformedParallelLoopConditionVariable u
        origLoopCore = recursivelyTransformedParallelLoopCore u
        origPrg      = blockInstructions origLoopCore
        origDecls    = blockDeclarations origLoopCore
        bodyVarNames = map getVarNameDecl origDecls

        -- The loop in the new representation, otherwise unchanged.
        trPl = ParallelLoop counterVar iterExpr (parallelLoopStep pl) origLoopCore ()

        -- NOTE(review): the new step is the unroll factor regardless of the
        -- original step, which presumes the original step is 1 - confirm.
        newParLoop = trPl
            { parallelLoopStep   = unrollNum
            , parallelLoopCore   = newLoopCore
            , parallelLoopSemInf = ()
            }
        newLoopCore = origLoopCore
            { blockDeclarations = unrollDecls
            , blockInstructions = unrollPrg
            , blockSemInf       = ()
            }

        -- Indices of the body copies: 0 .. d-1.
        copies = [0 .. unrollNum - 1]

        -- One tagged copy of the body per index, run in sequence.
        unrollPrg = Program (SequenceProgram $ Sequence prgs Nothing) Nothing
        prgs = [ origPrg { programSemInf = Just (SemInfPrg i bodyVarNames (getVarName counterVar)) }
               | i <- copies ]

        -- One renamed copy of every local declaration per index.
        unrollDecls = concatMap (\i -> map (renameDecl i) origDecls) copies
        renameDecl i decl = renameDeclaration decl (getVarNameDecl decl ++ "_u" ++ show i)

-- Helper functions:
-- | The iteration count when it is an integer literal, otherwise 'Nothing'.
iterNumFromExpr expr = case expr of
    Expression (ConstantExpression (Constant (IntConstant (IntConstantType n _)) _)) _ -> Just n
    _ -> Nothing
-- | 'True' for any 'Just', 'False' for 'Nothing'.
isJust :: Maybe a -> Bool
isJust m = case m of
    Nothing -> False
    Just _  -> True
-- | Name of the variable introduced by a local declaration.
getVarNameDecl decl = variableName (localVariable decl)
-- | Name of a variable.
getVarName var = variableName var
-- | Partial: the payload of a 'Just'; calls 'error' on 'Nothing'.
valueFromJust :: Maybe a -> a
valueFromJust m = case m of
    Just v  -> v
    Nothing -> error "This was Nothing"
-- | Give the variable declared by @decl@ the name @newName@.
renameDeclaration decl newName = decl { localVariable = renameVariable (localVariable decl) newName }
-- | Give a variable the name @newName@, keeping its other fields.
renameVariable var newName = var { variableName = newName }
-- | @elementOf xs x@ is 'True' iff @x@ occurs in @xs@ (the argument order is
-- flipped relative to 'elem').  It stops at the first match instead of
-- counting every match, so it neither walks nor forces the rest of the list.
elementOf :: Eq a => [a] -> a -> Bool
elementOf ss s = s `elem` ss