{-# LANGUAGE FlexibleInstances, TypeFamilies #-}

module Feldspar.Compiler.Imperative.Plugin.Unroll where

import Data.List (elem)

import Feldspar.Compiler.Backend.C.Options
import Feldspar.Transformation

-- ============================
-- == Unroll's Semantic info ==
-- ============================

-- | Per-copy information that 'Unroll_1' attaches to each replicated body
-- of an unrolled loop, and that 'Unroll_2' reads back.
data SemInfPrg = SemInfPrg
    { position :: Int       -- ^ index of this copy within one unrolled step (0-based)
    , varNames :: [String]  -- ^ names of the loop's local variables (renamed per copy)
    , loopVar  :: String    -- ^ name of the loop counter variable
    }
    deriving (Eq, Show)

-- | Annotation tag for the intermediate program produced by 'Unroll_1' and
-- consumed by 'Unroll_2'. Only 'Program' nodes carry information
-- (@Maybe SemInfPrg@); every other kind of node is labelled with @()@.
data UnrollSemInf

instance Annotation UnrollSemInf Module          where type Label UnrollSemInf Module          = ()
instance Annotation UnrollSemInf Definition      where type Label UnrollSemInf Definition      = ()
instance Annotation UnrollSemInf Struct          where type Label UnrollSemInf Struct          = ()
instance Annotation UnrollSemInf StructMember    where type Label UnrollSemInf StructMember    = ()
instance Annotation UnrollSemInf Union           where type Label UnrollSemInf Union           = ()
instance Annotation UnrollSemInf UnionMember     where type Label UnrollSemInf UnionMember     = ()
instance Annotation UnrollSemInf Procedure       where type Label UnrollSemInf Procedure       = ()
instance Annotation UnrollSemInf Prototype       where type Label UnrollSemInf Prototype       = ()
instance Annotation UnrollSemInf GlobalVar       where type Label UnrollSemInf GlobalVar       = ()
instance Annotation UnrollSemInf Block           where type Label UnrollSemInf Block           = ()

-- The only node kind with a non-trivial label: the per-copy unroll info.
instance Annotation UnrollSemInf Program         where type Label UnrollSemInf Program         = Maybe SemInfPrg

instance Annotation UnrollSemInf Empty           where type Label UnrollSemInf Empty           = ()
instance Annotation UnrollSemInf Comment         where type Label UnrollSemInf Comment         = ()
instance Annotation UnrollSemInf Assign          where type Label UnrollSemInf Assign          = ()
instance Annotation UnrollSemInf ProcedureCall   where type Label UnrollSemInf ProcedureCall   = ()
instance Annotation UnrollSemInf Sequence        where type Label UnrollSemInf Sequence        = ()
instance Annotation UnrollSemInf Branch          where type Label UnrollSemInf Branch          = ()
instance Annotation UnrollSemInf Switch          where type Label UnrollSemInf Switch          = ()
instance Annotation UnrollSemInf SeqLoop         where type Label UnrollSemInf SeqLoop         = ()
instance Annotation UnrollSemInf ParLoop         where type Label UnrollSemInf ParLoop         = ()
instance Annotation UnrollSemInf SwitchCase      where type Label UnrollSemInf SwitchCase      = ()
instance Annotation UnrollSemInf ActualParameter where type Label UnrollSemInf ActualParameter = ()
instance Annotation UnrollSemInf Declaration     where type Label UnrollSemInf Declaration     = ()

instance Annotation UnrollSemInf Expression      where type Label UnrollSemInf Expression      = ()
instance Annotation UnrollSemInf FunctionCall    where type Label UnrollSemInf FunctionCall    = ()
instance Annotation UnrollSemInf Cast            where type Label UnrollSemInf Cast            = ()
instance Annotation UnrollSemInf SizeOf          where type Label UnrollSemInf SizeOf          = ()
instance Annotation UnrollSemInf ArrayElem       where type Label UnrollSemInf ArrayElem       = ()
instance Annotation UnrollSemInf StructField     where type Label UnrollSemInf StructField     = ()
instance Annotation UnrollSemInf UnionField      where type Label UnrollSemInf UnionField      = ()
instance Annotation UnrollSemInf Constant        where type Label UnrollSemInf Constant        = ()
instance Annotation UnrollSemInf IntConst        where type Label UnrollSemInf IntConst        = ()
instance Annotation UnrollSemInf FloatConst      where type Label UnrollSemInf FloatConst      = ()
instance Annotation UnrollSemInf BoolConst       where type Label UnrollSemInf BoolConst       = ()
instance Annotation UnrollSemInf ArrayConst      where type Label UnrollSemInf ArrayConst      = ()
instance Annotation UnrollSemInf ComplexConst    where type Label UnrollSemInf ComplexConst    = ()
instance Annotation UnrollSemInf Variable        where type Label UnrollSemInf Variable        = ()

-- ==
-- == Plugin
-- ==

-- NOTE(review): the Bool instances below are orphans (neither the classes
-- nor Bool are defined in this module); they could clash with another
-- module that defines the same instances.

-- | Default upward flag (used as the 'Up' type of 'Unroll_1'): no loop
-- has been unrolled.
instance Default Bool where
    def = False

-- | Two upward flags combine with logical or: the result is set if either
-- side reports an unrolled loop.
instance Combine Bool where
    combine x y = x || y

-- | By default a program node carries no per-copy unrolling information.
instance Default (Maybe SemInfPrg) where
    def = Nothing


-- | Loop unrolling runs as two passes over the program.
--
-- * 'Unroll_1' runs with the requested factor. It replicates the bodies of
--   unrollable parallel loops and tags each copy.
-- * 'Unroll_2' starts with no per-copy information. It uses the tags to
--   rename variables and offset the loop counter.
--
-- With 'NoUnroll' the program is returned untouched.
instance Plugin UnrollPlugin where
    type ExternalInfo UnrollPlugin = UnrollStrategy
    executePlugin UnrollPlugin strategy prog = case strategy of
        NoUnroll      -> prog
        Unroll factor ->
            let tagged = result (transform Unroll_1 () factor prog)
            in  result (transform Unroll_2 () Nothing tagged)
    
-- | The unrolling plugin. Its own 'Transformation' instance is trivial
-- (every associated type is @()@); 'executePlugin' delegates the actual
-- work to the 'Unroll_1' and 'Unroll_2' passes.
data UnrollPlugin = UnrollPlugin

instance Transformation UnrollPlugin where
    type From  UnrollPlugin = ()
    type To    UnrollPlugin = ()
    type Down  UnrollPlugin = ()
    type Up    UnrollPlugin = ()
    type State UnrollPlugin = ()

-- | First unrolling pass.
--
-- * The downward info is the unroll factor.
-- * The upward flag records whether a loop in the subtree has been
--   unrolled.
-- * The output is annotated with 'UnrollSemInf', so replicated loop bodies
--   can carry a 'SemInfPrg'.
data Unroll_1 = Unroll_1

instance Transformation Unroll_1 where
    type From  Unroll_1 = ()
    type To    Unroll_1 = UnrollSemInf
    type Down  Unroll_1 = Int
    type Up    Unroll_1 = Bool
    type State Unroll_1 = ()

instance Transformable Unroll_1 Program where
    -- Unroll a parallel loop by the factor @d@ (the downward info).
    --
    -- * The step becomes @d@ and the body is replaced by @d@ copies of it.
    -- * Copy number k is labelled with @SemInfPrg k localNames counter@, so
    --   that 'Unroll_2' can rename its variables and offset the counter.
    -- * The loop-local declarations are replicated @d@ times; copy k's
    --   variables get the suffix @_u<k>@.
    --
    -- A loop is left alone (only generically transformed) when:
    --
    -- * a loop below it has already been unrolled (the upward flag is set);
    -- * its bound is not an integer constant divisible by @d@;
    -- * @d@ is not positive; or
    -- * the loop does not advance by exactly one.
    transform t s d p@(ParLoop _ _ _ _ _ _)
        | not (up tr) && unrollPossible = tr'
        | otherwise                     = tr
      where
        tr = defaultTransform t s d p

        -- The unrolled loop: step d, body replaced by the tagged copies,
        -- locals replaced by the renamed copies of the original locals.
        tr' = tr
            { result = (result tr)
                { pLoopStep  = d
                , pLoopBlock = loopCore
                    { locals    = unrollDecls
                    , blockBody = Sequence prgs () Nothing
                    }
                }
            , up = True
            }

        loopCore    = pLoopBlock (result tr)
        loopBound   = pLoopBound (result tr)
        loopCounter = varName (pLoopCounter (result tr))
        varNames    = map getVarNameDecl (locals loopCore)

        -- Offset of each copy within one unrolled step: 0 .. d-1.
        offsets = [0 .. d - 1]

        copyBody = blockBody loopCore

        prgs =
            [ copyBody { programLabel = Just (SemInfPrg i varNames loopCounter) }
            | i <- offsets
            ]

        -- All locals of copy 0, then all locals of copy 1, and so on.
        unrollDecls =
            [ renameDeclaration decl (getVarNameDecl decl ++ "_u" ++ show i)
            | i <- offsets
            , decl <- locals loopCore
            ]

        -- Checked left to right, so 'mod' below never sees a zero factor.
        -- A non-positive factor would otherwise divide by zero (d == 0), or
        -- yield an empty body with a negative step (d < 0).
        -- The copy offsets 0 .. d-1 are only correct for a unit step.
        -- NOTE(review): the divisibility test assumes the loop counts up
        -- from 0 — confirm for ParLoop.
        unrollPossible = d > 0 && pLoopStep (result tr) == 1 && boundDivisible

        boundDivisible = case loopBound of
            ConstExpr (IntConst i _ _) _ -> i `mod` toInteger d == 0
            _                            -> False
    transform t s d p = defaultTransform t s d p


-- | Second unrolling pass.
--
-- * It rewrites the loop bodies tagged by 'Unroll_1', renaming per-copy
--   variables and offsetting the loop counter.
-- * The downward info is the 'SemInfPrg' of the innermost enclosing tagged
--   program, if any.
-- * The output drops the 'UnrollSemInf' annotations again.
data Unroll_2 = Unroll_2

instance Transformation Unroll_2 where
    type From  Unroll_2 = UnrollSemInf
    type To    Unroll_2 = ()
    type Down  Unroll_2 = Maybe SemInfPrg
    type Up    Unroll_2 = ()
    type State Unroll_2 = ()

instance Transformable Unroll_2 Program where
    -- A program that carries its own 'SemInfPrg' label (one copy of an
    -- unrolled loop body) hands that label down to its children.
    -- Otherwise the inherited downward info is passed on unchanged.
    transform t s d p = defaultTransform t s (maybe d Just (programLabel p)) p

instance Transformable Unroll_2 Expression where
    -- Inside copy number k of an unrolled loop body, each reference to the
    -- loop counter becomes the infix call @counter + k@. Here @counter@ is
    -- the generically transformed original reference. Every other
    -- expression, and everything outside an unrolled body, is transformed
    -- generically.
    --
    -- NOTE(review): the addition is typed as a signed 32-bit number
    -- regardless of the counter's declared type — confirm that loop
    -- counters always have that type.
    transform t s d l = case (d, l) of
        (Just info, VarExpr n _)
            | varName n == loopVar info -> tr { result = shifted info }
        _                               -> tr
      where
        tr = defaultTransform t s d l
        shifted info = FunctionCall
            { funCallName   = "+"
            , returnType    = NumType Signed S32
            , funRole       = InfixOp
            , funCallParams =
                [ result tr
                , ConstExpr (IntConst (toInteger (position info)) () ()) ()
                ]
            , funCallLabel  = ()
            , exprLabel     = ()
            }


instance Transformable Unroll_2 Variable where
    -- Handles variables inside copy number k of an unrolled loop body.
    --
    -- * A variable whose name is one of the loop's original locals is
    --   renamed to @name_u<k>@, matching the per-copy declarations created
    --   by 'Unroll_1'.
    -- * Its label is reset to @()@.
    -- * All other variables, and everything outside unrolled bodies, are
    --   transformed generically.
    transform t s d v = maybe tr renameFor d
      where
        tr = defaultTransform t s d v
        renameFor info
            | varName v `notElem` varNames info = tr
            | otherwise = tr
                { result = (result tr)
                    { varName  = varName v ++ "_u" ++ show (position info)
                    , varLabel = ()
                    }
                }


-- helper functions:

-- | 'True' exactly when the argument is a 'Just'.
isJust :: Maybe a -> Bool
isJust m = case m of
    Just _  -> True
    Nothing -> False

-- | Name of the variable introduced by a declaration.
getVarNameDecl d = varName (declVar d)

-- | Give the variable introduced by declaration @d@ the new name @n@.
renameDeclaration d n = d { declVar = renameVariable (declVar d) n }

-- | Give variable @v@ the new name @n@; all its other fields are kept.
renameVariable v n = v { varName = n }