{-
 - Copyright (c) 2009, ERICSSON AB All rights reserved.
 - 
 - Redistribution and use in source and binary forms, with or without
 - modification, are permitted provided that the following conditions
 - are met:
 - 
 -     * Redistributions of source code must retain the above copyright
 -     notice,
 -       this list of conditions and the following disclaimer.
 -     * Redistributions in binary form must reproduce the above copyright
 -       notice, this list of conditions and the following disclaimer
 -       in the documentation and/or other materials provided with the
 -       distribution.
 -     * Neither the name of the ERICSSON AB nor the names of its
 -     contributors
 -       may be used to endorse or promote products derived from this
 -       software without specific prior written permission.
 - 
 - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 -}

module Feldspar.Compiler.Imperative.Representation where

import Data.List (intercalate)
import qualified Data.Map as Map
import Data.Maybe

------------------------------------------------
-- Data types to encode an imperative program --
------------------------------------------------

-- | Width of an integer type. Judging by the C names emitted for it
-- (@char@, @short@, @int@, @long@), the number is a bit count.
data Size
    = S4
    | S8
    | S16
    | S32
    | S64
    deriving (Eq, Show)

-- | Signedness of an integer type.
data Signedness
    = ImpSigned
    | ImpUnsigned
    deriving (Eq, Show)

-- | Types of the imperative intermediate language.
data Type
    = BoolType
    | FloatType
    | Numeric Signedness Size
    | ImpArrayType (Maybe Int) Type   -- length (if known) and element type
    | Pointer Type
    deriving (Eq, Show)

-- | A typed expression: an untyped expression tree paired with its type.
data ImpLangExpr = Expr
    { exprCore :: UntypedExpression
    , exprType :: Type
    }
    deriving (Eq, Show)

-- | A named variable together with its parameter kind and type.
data Variable
    = Var { name :: String, kind :: ParameterKind, vartype :: Type }
    deriving (Eq, Show)

-- | Expressions denoting a storage location (usable on the left of @=@).
data LeftValue
    = LVar Variable
    | ArrayElem
        LeftValue      -- array being indexed
        ImpLangExpr    -- index
    | PointedVal LeftValue   -- dereferenced location
    deriving (Eq, Show)

-- | Expression trees without type annotations at the root.
data UntypedExpression
    = LeftExpr LeftValue
    | AddressOf LeftValue
    | ConstExpr Constant
    | FunCall FunRole String [ImpLangExpr]   -- rendering role, name/operator, arguments
    deriving (Eq, Show)

-- | Literal constants.
data Constant
    = IntConst Int
    | FloatConst Float
    | BoolConst Bool
    | ArrayConst Int [Constant]   -- presumably the length (ignored by 'toC'), then the elements
    deriving (Eq, Show)

-- | How a 'FunCall' is rendered. An 'InfixOp' with exactly two arguments
-- becomes @(a op b)@; every other call is rendered as @name(args)@.
data FunRole
    = SimpleFun
    | InfixOp
    | PrefixOp
    deriving (Eq, Show)

-- | Primitive imperative statements.
data Instruction
    = Assign LeftValue ImpLangExpr     -- lhs = rhs
    | CFun String [Parameter]          -- call of a named C function
    deriving (Eq, Show)

-- | Actual parameter of a 'CFun' call.
data Parameter
    = In ImpLangExpr
    | Out (ParameterKind, ImpLangExpr)
    deriving (Eq, Show)

-- | Distinguishes ordinary variables from output parameters. 'OutKind'
-- variables of simple type are rendered with a leading @*@.
data ParameterKind
    = Normal
    | OutKind
    deriving (Eq, Show)

-- | An imperative function: its name, in- and out-parameter declarations,
-- and its body.
data ImpFunction = Fun
    { funName       :: String
    , inParameters  :: [Declaration]
    , outParameters :: [Declaration]
    , prg           :: CompleteProgram
    }
    deriving (Eq, Show)

-- | A program together with the local variables it declares.
data CompleteProgram = CompPrg
    { locals :: [Declaration]
    , body   :: Program
    }
    deriving (Eq, Show)

-- | Declaration of a variable with an optional initial value.
data Declaration = Decl
    { var       :: Variable
    , declType  :: Type
    , initVal   :: Maybe ImpLangExpr
    , semInfVar :: SemInfVar
    }
    deriving (Eq, Show)

-- | Structured imperative programs.
data Program
    = Empty
    | Primitive Instruction SemInfPrim
    | Seq [Program] SemInfPrgSeq
    | IfThenElse
        Variable            -- condition variable
        CompleteProgram     -- then part
        CompleteProgram     -- else part
        SemInfIf            -- semantic info
    | SeqLoop
        Variable            -- condition variable
        CompleteProgram     -- condition calculation
        CompleteProgram     -- loop body
        SemInfSeqLoop       -- semantic info
    | ParLoop
        Variable            -- counter (this is expected to be an integer)
        ImpLangExpr         -- number of iterations (rendered as exclusive upper bound of the counter)
        Int                 -- step
        CompleteProgram     -- loop body
        SemInfParLoop       -- semantic info
    deriving (Eq, Show)

-- | An array-typed variable with its element type and length.
data Array = Array
    Variable    -- array typed var
    Type        -- element type
    Int         -- length of array
    deriving (Eq, Show)

-----------------------
-- C code generation --
-----------------------

-- | Values that can be rendered as C source text. The 'Int' argument is
-- the current nesting level; some instances use it (via 'tab') to indent
-- their output, others ignore it.
class ToC a where
    toC :: Int -> a -> String

-- | Render a value as C at the outermost nesting level.
compToC :: ToC a => a -> String
compToC = toC 0

-- | C integer type name for each width. C has no 4-bit integer type, so
-- 'S4' is rejected with an explicit error instead of an anonymous
-- non-exhaustive-pattern failure.
-- NOTE(review): @long@ is only guaranteed to be at least 32 bits in C;
-- @long long@ is the portable 64-bit choice. Confirm the target platform.
instance ToC Size where
    toC _ S4  = error "Feldspar.Compiler.Imperative.Representation.toC: no C type for Size S4"
    toC _ S8  = "char"
    toC _ S16 = "short"
    toC _ S32 = "int"
    toC _ S64 = "long"

-- | C signedness qualifier.
instance ToC Signedness where
    toC _ s = case s of
        ImpSigned   -> "signed"
        ImpUnsigned -> "unsigned"

-- | A typed expression is rendered by its untyped core; the type is not printed.
instance ToC ImpLangExpr where
    toC sc = toC sc . exprCore

-- | C spelling of a type. Booleans are represented as @int@.
instance ToC Type where
    toC sc ty = case ty of
        BoolType             -> "int"
        FloatType            -> "float"
        Numeric sig siz      -> toC sc sig ++ " " ++ toC sc siz
        ImpArrayType _ elemT -> toC sc elemT ++ "[]"   -- TODO: ImpArrayType Just ...
        Pointer elemT        -> toC sc elemT ++ "*"

-- | A variable is rendered by name. Output-kind variables of simple type
-- get a leading @*@, presumably because they are passed by pointer.
instance ToC Variable where
    toC _ (Var varName k t) = deref ++ varName
      where
        deref
            | simpleType t && k == OutKind = "*"
            | otherwise                    = ""
    
-- | C rendering of a location: variable, indexed element, or dereference.
instance ToC LeftValue where
    toC sc lv = case lv of
        LVar v         -> toC sc v
        ArrayElem a ix -> toC sc a ++ "[" ++ toC sc ix ++ "]"
        PointedVal p   -> "*(" ++ toC sc p ++ ")"

-- | C rendering of an expression. A binary 'InfixOp' call is parenthesised
-- as @(a op b)@; every other call is printed in @name(args)@ form.
instance ToC UntypedExpression where
    toC sc expr = case expr of
        LeftExpr lv                     -> toC sc lv
        AddressOf lv                    -> "&(" ++ toC sc lv ++ ")"
        ConstExpr c                     -> toC sc c
        FunCall InfixOp op [lhs, rhs]   -> concat ["(", toC sc lhs, " ", op, " ", toC sc rhs, ")"]
        FunCall _ fn args               -> fn ++ "(" ++ listprint (toC sc) ", " args ++ ")"

-- | C literal for a constant. Floats get an @f@ suffix, booleans become
-- @1@/@0@, and arrays become a brace initialiser (see 'toCArray').
-- NOTE(review): 'show' on a Float yields @NaN@/@Infinity@ for non-finite
-- values, which is not valid C. Confirm these never reach code generation.
instance ToC Constant where
    toC _ c = case c of
        IntConst n     -> show n
        FloatConst x   -> show x ++ "f"
        BoolConst b    -> if b then "1" else "0"
        ArrayConst _ _ -> "{" ++ toCArray c ++ "}"

-- | Comma-separated elements of an array constant, without braces. Nested
-- array constants are flattened into the same list.
toCArray :: Constant -> String
toCArray c = case c of
    ArrayConst _ elems -> listprint toCArray "," elems
    scalar             -> toC 0 scalar

-- | C rendering of a primitive instruction, without the trailing @;@.
instance ToC Instruction where
    toC sc instr = case instr of
        Assign lhs rhs -> toC sc lhs ++ " = " ++ toC sc rhs
        CFun fn params -> fn ++ "(" ++ listprint (toC sc) ", " params ++ ")"

-- | C rendering of an actual parameter. A 'Normal' output parameter of
-- simple type is passed by address; everything else is passed as-is.
instance ToC Parameter where
    toC sc (In e) = toC sc e
    toC sc (Out (k, e))
        | byAddress = "&(" ++ rendered ++ ")"
        | otherwise = rendered
      where
        rendered  = toC sc e
        byAddress = k == Normal && simpleType (exprType e)
        
-- | Renders a function as @void name( in-params, out-params )@ followed by
-- its body, which is indented one level deeper.
instance ToC ImpFunction where
    toC sc (Fun fname ins outs program) = concat
        [ "void ", fname
        , "( ", listprint paramToC ", " (ins ++ outs), " )"   -- function parameters
        , "\n{\n", toC (sc + 1) program, "}\n\n"               -- core function
        ]
      where
        -- One equation covers every type: for non-array types 'toCPrimType'
        -- is plain 'toC' and 'arrayDepths' is empty (a 'Pointer t' renders as
        -- @t*@). Array types print the innermost element type and append one
        -- @[n]@ suffix per dimension.
        paramToC :: Declaration -> String
        paramToC (Decl v t _ _) = toCPrimType t ++ " " ++ toC 0 v ++ arrayDepths t

-- | One @[n]@ suffix per array dimension, outermost first. A dimension of
-- unknown length is emitted as @[16]@.
arrayDepths :: Type -> String
arrayDepths (ImpArrayType len elemT) = "[" ++ maybe "16" show len ++ "]" ++ arrayDepths elemT
arrayDepths _ = ""

-- | Local declarations, a blank line, then the program body, all at the
-- same nesting level. 'concatMap' avoids the quadratic cost of the
-- previous left-nested @foldl (++)@.
instance ToC CompleteProgram where
    toC sc (CompPrg decls prog) = concatMap (toC sc) decls ++ "\n" ++ toC sc prog

-- | A C variable declaration, indented to level @sc@ and terminated with
-- @;\\n@. An array declared with a non-constant initialiser becomes a
-- pointer (one @*@ per dimension). Any other array gets @[n]@ suffixes,
-- with @[16]@ for unknown lengths. The semantic info field is not printed;
-- appending @show inf@ is handy for debugging.
instance ToC Declaration where
    toC sc (Decl v t mInit _) =
        tab sc ++ declarator t "" (needsPointer mInit) ++ initializer mInit ++ ";\n"
      where
        initializer :: Maybe ImpLangExpr -> String
        initializer = maybe "" (\e -> " = " ++ toC 0 e)

        -- Type, accumulated array suffix, "declare as pointer" flag.
        declarator :: Type -> String -> Bool -> String
        declarator (ImpArrayType _ elemT) _ True =
            toCPrimType elemT ++ replicateArrayDepth elemT "*" 1 ++ " " ++ toC 0 v
        declarator (ImpArrayType len elemT) suffix False =
            declarator elemT (suffix ++ "[" ++ maybe "16" show len ++ "]") False
        declarator other suffix _ =
            toC 0 other ++ " " ++ toC 0 v ++ suffix

        -- Only a present, non-constant initialiser forces pointer form.
        needsPointer Nothing  = False
        needsPointer (Just e) = case exprCore e of
            ConstExpr _ -> False
            _           -> True

-- | C rendering of a program at nesting level @sc@. Each statement line is
-- indented with 'tab', and nested blocks are rendered one or two levels
-- deeper.
instance ToC Program where
    toC sc Empty = ""
    -- A primitive instruction on its own line; the semantic info is not printed.
    toC sc (Primitive i seminf)
        = (tab sc) ++ (toC sc i) ++ ";\n"                               -- without seminf
        -- = (tab sc) ++ (toC sc i) ++ ";\n" ++ toC (sc+1) seminf ++ "\n"  -- with seminf
    -- A sequence is the concatenation of its parts at the same level.
    toC sc (Seq ps _) = foldr (++) "" $ map (toC sc) ps
    -- if(cond) { then } else { else }, with both branches one level deeper.
    toC sc (IfThenElse con tPrg ePrg _) 
        = (tab sc) ++ "if(" ++ (toC sc con) ++ ")\n"++ (tab sc) ++"{\n" ++ (toC (sc+1) tPrg) ++ (tab sc) ++ "}\n"
             ++ (tab sc) ++ "else\n" ++ (tab sc) ++ "{\n" ++ (toC (sc+1) ePrg) ++ (tab sc) ++ "}\n"
    -- Wrapped in its own block. The condition calculation (locals + body) is
    -- emitted once before the while loop. Only its body is re-emitted after
    -- the loop body to recompute the condition variable for the next test.
    -- NOTE(review): initialisers of the condition's local declarations are not
    -- re-evaluated on each iteration; confirm that is intended.
    toC sc (SeqLoop condVar condCalc loopBody _) 
        = (tab sc) ++ "{\n" ++ (toC (sc+1) condCalc) ++ (tab $ sc+1)
          ++ "while(" ++ (toC 0 condVar) ++ ")\n" ++ tab (sc+1) ++ "{\n" 
          ++ (toC (sc+2) loopBody) ++ (toC (sc+2) (body condCalc)) ++ (tab $ sc+1) ++ "}\n" ++ (tab sc) ++ "}\n"
    -- Rendered as a sequential for loop inside its own block. The counter is
    -- always declared as "int", regardless of the variable's recorded type,
    -- and runs from 0 while it is below the iteration expression.
    toC sc (ParLoop (Var cv _ _) num step prg _) = (tab sc) ++ "{\n" ++ toCPar (sc+1) ++ (tab sc) ++ "}\n"
        -- toCPar's parameter deliberately shadows the outer level (it receives sc+1).
        where toCPar sc =
                 (tab sc) ++ "int " ++ cv ++ ";\n"
                 ++ (tab sc) ++ "for( " ++ cv ++ " = 0; " ++ cv ++ " < " ++ (toC 0 num) ++ "; " ++ cv ++ " += " ++ (show step) ++")\n"
                 ++ (tab sc) ++ "{\n" ++ (toC (sc+1) prg) ++ (tab sc) ++ "}\n"

-- | Semantic info as C line comments: one @// name in this instruction: ...@
-- line per variable. When the 'output' flag is set, an extra @// !!!@ marker
-- line comes first.
instance ToC SemInfPrim where
    toC sc seminf = marker ++ stat
      where
        marker
            | output seminf = tab sc ++ "// !!!\n"
            | otherwise     = ""
        stat = tab sc ++ "// " ++ listprint describe ("\n" ++ tab sc ++ "// ") (Map.toList (varMap seminf))
        describe (v, s) = v ++ " in this instruction: " ++ show s
                                     
-- | An absent value renders as nothing; a present one renders normally.
instance ToC a => ToC (Maybe a) where
    toC sc = maybe "" (toC sc)

-- | A list renders as the concatenation of its elements' renderings, with no separator.
instance ToC a => ToC [a] where
    toC sc = concatMap (toC sc)

-- | An array renders as its variable; the element type and length are not printed.
instance ToC Array where
    toC sc (Array arrVar _ _) = toC sc arrVar

----------------------
-- Helper functions --
----------------------

-- | True for scalar types (booleans, floats, integers); False for arrays and pointers.
simpleType :: Type -> Bool
simpleType t = case t of
    BoolType         -> True
    FloatType        -> True
    Numeric _ _      -> True
    ImpArrayType _ _ -> False
    Pointer _        -> False

-- | C spelling of the innermost element type: array layers are stripped.
toCPrimType :: Type -> String
toCPrimType t = case t of
    ImpArrayType _ elemT -> toCPrimType elemT
    _                    -> toC 0 t

-- | Despite its name, returns a C qualifier string: @"* const"@ for an
-- array type and @""@ for anything else.
isArrayType :: Type -> String
isArrayType t = case t of
    ImpArrayType _ _ -> "* const"
    _                -> ""

-- | Indentation for nesting level @sc@: four spaces per level. A
-- non-positive level yields the empty string.
tab :: Int -> String
tab sc = replicate (sc * 4) ' '

-- | Render each element with @f@ and join the results with separator @sep@.
-- Equivalent to the standard 'intercalate' after a 'map'.
listprint :: (a -> String) -> String -> [a] -> String
listprint f sep = intercalate sep . map f

-- | Extract the location from an expression that is a bare left value.
-- Any other expression is an error.
toLeftValue :: ImpLangExpr -> LeftValue
toLeftValue e = case exprCore e of
    LeftExpr lv -> lv
    _           -> error ("Error: " ++ toC 0 e ++ " is not a left value.")

-- | Repeat @piece@ once per array dimension of @t@ plus @extra@ more times,
-- and drop every space from the result.
replicateArrayDepth :: Type -> String -> Int -> String
replicateArrayDepth t piece extra = filter (/= ' ') (concat (replicate (arrayDepth t + extra) piece))
-- | Number of nested 'ImpArrayType' layers (0 for non-array types).
arrayDepth :: Type -> Int
arrayDepth (ImpArrayType _ elemT) = succ (arrayDepth elemT)
arrayDepth _ = 0

-- | The variable, if the expression is exactly a plain variable reference.
getVariable :: ImpLangExpr -> Maybe Variable
getVariable e = case exprCore e of
    LeftExpr (LVar v) -> Just v
    _                 -> Nothing

-- | Whether a variable named @n@ occurs anywhere in the expression,
-- including array indices and function arguments.
contains :: String -> ImpLangExpr -> Bool
contains n = inExpr . exprCore
  where
    inExpr (LeftExpr lv)     = inLeft lv
    inExpr (AddressOf lv)    = inLeft lv
    inExpr (ConstExpr _)     = False
    inExpr (FunCall _ _ es)  = any (contains n) es

    inLeft (LVar v)          = name v == n
    inLeft (ArrayElem lv ix) = inLeft lv || contains n ix
    inLeft (PointedVal lv)   = inLeft lv

-- | Name of the variable underlying a location, looking through indexing
-- and dereferencing.
getVarName :: LeftValue -> String
getVarName lv = case lv of
    LVar v        -> name v
    ArrayElem a _ -> getVarName a
    PointedVal p  -> getVarName p

-- | Extract the location from a left-value expression. Any other
-- expression is an error that includes its C rendering.
getLeftValue :: ImpLangExpr -> LeftValue
getLeftValue e = case exprCore e of
    LeftExpr lv -> lv
    _           -> error ("Error in Compiler.Imperative.Representation.getLeftValue:\n" ++ toC 0 e)

{-
isInParam :: Parameter -> Bool
isInParam (In _) = True
isInParam _ = False
-}

--------------------------------------
-- Semantics of imperative programs --
--------------------------------------

-- | Usage statistics per variable, keyed by variable name.
type VariableMap = Map.Map String SemInfVar

-- | Semantic information attached to a primitive instruction. When
-- 'output' is set, the C rendering prefixes the info with a @// !!!@ marker.
data SemInfPrim = SemInfPrim
    { varMap :: VariableMap
    , output :: Bool
    }
    deriving (Eq, Show)

-- | How one variable is written ('usedLeft') and read ('usedRight').
data SemInfVar = SemInfVar
    { usedLeft  :: LeftUse
    , usedRight :: RightUse
    }
    deriving (Eq)

-- | Human-readable summary: the left use, then the right use.
instance Show SemInfVar where
    show (SemInfVar leftUse rightUse) = concat [show leftUse, ", ", show rightUse]

-- | Usage info with nothing known about either reads or writes.
unknownSemInfVar :: SemInfVar
unknownSemInfVar = SemInfVar UnknownL UnknownR

-- | How often a variable is assigned.
data LeftUse
    = UnknownL
    | None
    | Single (Maybe ImpLangExpr)   -- assigned once, optionally with the assigned value
    | MultipleL
    deriving (Eq)

-- | How often a variable is read.
data RightUse
    = UnknownR
    | Times Int
    | MultipleR
    deriving (Eq)

-- | The value assigned to a variable that is written exactly once with a
-- known expression. Any other usage is an error.
getValue :: SemInfVar -> ImpLangExpr
getValue s = case usedLeft s of
    Single (Just expr) -> expr
    -- '_' rather than 'otherwise': the latter binds a fresh variable that
    -- shadows Prelude.otherwise.
    _                  -> error $ "Error in Representation.getValue for the semantic information:\n" ++ show s

-- | Names of the variables that may be written, i.e. whose left use is not
-- 'None'. Unknown usage counts as written. Keys come back in ascending order.
leftVars :: VariableMap -> [String]
leftVars = Map.keys . Map.filter ((/= None) . usedLeft)

-- | Human-readable description of how often a variable is set.
instance Show LeftUse where
    show l = "set: " ++ describe l
      where
        describe UnknownL        = "no information"
        describe None            = "never"
        describe (Single Nothing)  = "once"
        describe (Single (Just e)) = "once (" ++ toC 0 e ++ ")"
        describe MultipleL       = "multiple times"

-- | Human-readable description of how often a variable is read.
instance Show RightUse where
    show r = "used: " ++ describe r
      where
        describe UnknownR  = "no information"
        describe (Times i) = show i ++ " times"
        describe MultipleR = "multiple times"

-- Placeholder semantic-info annotations for compound program nodes.
type SemInfPrgSeq  = [String]
type SemInfBr      = [String]
type SemInfParLoop = [String]
type SemInfIf      = [String]
type SemInfSeqLoop = [String]
type SemInfSeq     = [String]

--------------------------------------------------------
-- Computing statistics of variables in an expression --
-- on the right- and left-hand sides of an assignment --
--------------------------------------------------------

-- | Usage statistics for variables read by a value: each variable reference
-- counts as one read.
class RightVarMap a where
    rightVarMap :: a -> VariableMap

-- | Reads of a typed expression are those of its untyped core.
instance RightVarMap ImpLangExpr where
    rightVarMap = rightVarMap . exprCore

-- | Reads of an expression; function-call arguments are combined with 'addVarMap'.
instance RightVarMap UntypedExpression where
    rightVarMap expr = case expr of
        LeftExpr lv      -> rightVarMap lv
        AddressOf lv     -> rightVarMap lv
        ConstExpr _      -> Map.empty
        FunCall _ _ args -> foldr (addVarMap . rightVarMap) Map.empty args

-- | Reading a location reads its base variable once, plus whatever its
-- index expressions read.
instance RightVarMap LeftValue where
    rightVarMap lv = case lv of
        LVar v         -> Map.singleton (name v) (SemInfVar None (Times 1))
        ArrayElem a ix -> rightVarMap a `addVarMap` rightVarMap ix
        PointedVal p   -> rightVarMap p

-- | Usage statistics for assigning to a location. A plain variable records
-- the assigned expression. An indexed or dereferenced location records a
-- single write of unknown value to its base variable, and indices count
-- as reads.
leftVarMap :: LeftValue -> Maybe ImpLangExpr -> VariableMap
leftVarMap lv assigned = case lv of
    LVar v         -> Map.singleton (name v) (SemInfVar (Single assigned) (Times 0))
    ArrayElem a ix -> leftVarMap a Nothing `addVarMap` rightVarMap ix
    PointedVal p   -> leftVarMap p Nothing

-- | Merge two usage maps; variables present in both have their statistics
-- combined with 'addSemInfVar'.
addVarMap :: VariableMap -> VariableMap -> VariableMap
addVarMap = Map.unionWith addSemInfVar

-- | Combine two usage records for the same variable.
--
-- * Left use: 'UnknownL' absorbs everything, 'None' is neutral, and any
--   two other uses give 'MultipleL'.
-- * Right use: 'UnknownR' absorbs everything, 'Times' counts are added,
--   and any other combination gives 'MultipleR'.
addSemInfVar :: SemInfVar -> SemInfVar -> SemInfVar
addSemInfVar s1 s2
    = SemInfVar
    { usedLeft = combineLeft (usedLeft s1) (usedLeft s2)
    , usedRight = combineRight (usedRight s1) (usedRight s2)
    } where
        combineLeft :: LeftUse -> LeftUse -> LeftUse
        combineLeft UnknownL _ = UnknownL
        combineLeft _ UnknownL = UnknownL
        combineLeft None x = x
        combineLeft x None = x
        combineLeft _ _ = MultipleL

        combineRight :: RightUse -> RightUse -> RightUse
        combineRight UnknownR _ = UnknownR
        combineRight _ UnknownR = UnknownR
        combineRight (Times x) (Times y) = Times (x + y)
        combineRight _ _ = MultipleR