--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE FlexibleInstances #-}

module Feldspar.Compiler.Imperative.CodeGeneration where

import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Imperative.Semantics
import Feldspar.Compiler.Error
import Feldspar.Compiler.Options

import qualified Data.List as List (last,find)

------------------------
-- C code generation --
------------------------

-- | Error helper for this module: 'handleError' tagged with the phase name
-- "CodeGeneration". Call sites pass an error class (e.g. 'InternalError')
-- and a message, and use the result at arbitrary types.
codeGenerationError = handleError phase
    where
        phase = "CodeGeneration"

-- | Syntactic context in which a variable (or other construct) is printed.
-- The place decides whether a type is emitted and how the name is decorated.
data Place
    = Declaration_pl    -- ^ Variable declaration: type is printed, arrays in
                        --   array style (@T x[n]@).
    | MainParameter_pl  -- ^ Parameter list of the generated function: type is
                        --   printed, arrays in pointer style.
    | ValueNeed_pl      -- ^ Inside expressions: value of the variable, no type.
    | AddressNeed_pl    -- ^ Output argument of a call: address of the
                        --   variable, no type.
    | FunctionCallIn_pl -- ^ Input argument of a call: value of the variable,
                        --   no type; arrays are passed as the address of
                        --   their first element.
    deriving (Eq, Show)

-- | Render a construct as C code in declaration context ('Declaration_pl').
compToC :: ToC a => Platform -> a -> String
compToC platform construct = toC platform Declaration_pl construct

-- | Constructs that can be rendered as C source text.
class ToC a where
    toC
        :: Platform -- ^ target platform (type-name table, constant printers)
        -> Place    -- ^ syntactic context the construct appears in
        -> a        -- ^ construct to render
        -> String

-- Looks the type up in the platform's type table; the place is irrelevant.
-- Array types are rendered by 'show_variable', not here.
instance ToC Type where
    toC m _ t =
        case [cName | (t', cName, _) <- types m, t == t'] of
            cName : _ -> cName
            []        -> codeGenerationError InternalError $ "Unhandled type in platform " ++ name m

-- Plain variables go through 'show_variable'; this path always passes
-- 'NoRestrict'.
instance ToC (Variable PrettyPrintSemanticInfo) where
    toC m p (Variable role ty nm _) = show_variable m p role ty nm NoRestrict

-- | Render a variable as C text for the given 'Place'.
--
-- The output is the type prefix followed by the decorated name and any
-- array-dimension suffix, joined by one space. Parts that render as the
-- empty string are dropped by 'listprint'. Only 'Declaration_pl' and
-- 'MainParameter_pl' produce a type prefix.
--
-- Arguments, in order:
--
-- * the platform
-- * the place
-- * the variable role
-- * the variable type
-- * the variable name
-- * whether pointer parameters are declared @const restrict@
show_variable :: Platform -> Place -> VariableRole -> Type -> String -> IsRestrict -> String
show_variable m p r t n restr = listprint id " " [typePrefix, show_name r p t n ++ dimSuffix]
    where
        (typePrefix, dimSuffix) = show_type p t restr

        -- (type prefix, declarator suffix) for a variable of the given type.
        show_type :: Place -> Type -> IsRestrict -> (String, String)
        -- Nested array as a main parameter: pointer to the inner array type.
        -- The outermost dimension is not printed.
        show_type MainParameter_pl (ImpArrayType _ (ImpArrayType s2 t2)) rs = decl_matr_type t2 s2 rs
        -- Declared arrays keep every dimension: @T x[s1][s2]...@.
        show_type Declaration_pl (ImpArrayType s elemTy) _ = decl_arr_type elemTy s ("", "")
        -- One-dimensional array as a main parameter: plain element pointer.
        show_type MainParameter_pl (ImpArrayType _ elemTy) rs = decl_arr_type_0 elemTy rs
        show_type Declaration_pl ty _ = (toC m p ty, "")
        show_type MainParameter_pl ty _ = (toC m p ty, "")
        -- Other places never print a type.
        show_type _ _ _ = ("", "")

        -- Element-pointer type for a one-dimensional array main parameter.
        decl_arr_type_0 :: Type -> IsRestrict -> (String, String)
        decl_arr_type_0 elemTy Restrict = (toC m Declaration_pl elemTy ++ " * const restrict", "")
        decl_arr_type_0 elemTy _        = (toC m Declaration_pl elemTy ++ " *", "")

        -- Pointer-to-array declarator (@T (* x)[s2]...@) for a
        -- multi-dimensional array main parameter.
        decl_matr_type :: Type -> Length -> IsRestrict -> (String, String)
        decl_matr_type t2 s2 Restrict = decl_arr_type t2 s2 (" (* const restrict", ")")
        decl_matr_type t2 s2 _        = decl_arr_type t2 s2 (" (*", ")")

        -- Walk nested array types, appending one bracket suffix per dimension
        -- (outermost first) after the accumulated declarator suffix.
        decl_arr_type :: Type -> Length -> (String, String) -> (String, String)
        decl_arr_type (ImpArrayType s2 t2) len (pre, post) = decl_arr_type t2 s2 (pre, post ++ show_brackets len)
        decl_arr_type ty len (pre, post) = (toC m Declaration_pl ty ++ pre, post ++ show_brackets len)

        show_brackets :: Length -> String
        show_brackets Undefined   = codeGenerationError InternalError $ "Unattended unknown array size"
        show_brackets (Norm i)    = concat ["[", show i, "]"]
        show_brackets (Defined i) = concat ["[", show i, defaultArraySizeWarning, "]"]

        defaultArraySizeWarning :: String
        defaultArraySizeWarning = " /* WARNING: Default size used!! */"

        -- Decorate the name according to role and place.
        show_name :: VariableRole -> Place -> Type -> String -> String
        -- Arrays used as call inputs or by address are passed as the address
        -- of their first element, e.g. @&(x[0][0])@.
        show_name _ place ty@(ImpArrayType _ _) nm
            | place == FunctionCallIn_pl || place == AddressNeed_pl = concat ["&(", nm, genIndex ty, ")"]
            | otherwise = nm
        show_name Value place _ nm
            | place == AddressNeed_pl = "&" ++ nm
            | otherwise               = nm
        -- Output variables are declared as pointers among the main parameters
        -- ("* x"). A bare name is dereferenced ("(* x)"), while a name that
        -- already ends in an index (']') is used as is.
        show_name FunOut place _ nm
            | place == AddressNeed_pl   = if endsWithIndex then "&" ++ nm else nm
            | place == Declaration_pl   = codeGenerationError InternalError $ "You can't declare output variable of the function"
            | place == MainParameter_pl = "* " ++ nm
            | endsWithIndex             = nm
            | otherwise                 = "(* " ++ nm ++ ")"
            where
                -- Total replacement for the partial @last nm == ']'@: an empty
                -- name is reported as an internal error instead of crashing
                -- inside 'last'.
                endsWithIndex = case reverse nm of
                    ']' : _ -> True
                    []      -> codeGenerationError InternalError $ "Output variable with an empty name"
                    _       -> False

        -- One "[0]" per array dimension.
        genIndex :: Type -> String
        genIndex (ImpArrayType _ inner) = "[0]" ++ genIndex inner
        genIndex _                      = ""

-- A constant renders as its payload.
instance ToC (Constant PrettyPrintSemanticInfo) where
    toC m p = toC m p . constantData

instance ToC (ConstantData PrettyPrintSemanticInfo) where
    -- Array constants become a C brace initializer.
    toC m p arr@(ArrayConstant _) = concat ["{", toCArray m p arr, "}"]
    -- Scalars: a printer registered in the platform for the constant's type
    -- takes precedence; otherwise a built-in rendering is used.
    toC m _ c =
        case [render | (t', render) <- values m, t' == typeof c] of
            render : _ -> render c
            []         -> builtin
      where
        builtin = case c of
            IntConstant ic                      -> show (intConstantValue ic)
            FloatConstant fc                    -> show (floatConstantValue fc) ++ "f"
            BoolConstant (BoolConstantType b _) -> if b then "1" else "0"
            _ -> codeGenerationError InternalError $ "Unhandled constant in platform " ++ name m

-- | Render the contents (without outer braces) of an array constant as a
-- comma-separated list. Nested array constants are flattened into the same
-- list without inner braces; scalars are delegated to 'toC'.
toCArray :: Platform -> Place -> ConstantData PrettyPrintSemanticInfo -> String
toCArray m p cd = case cd of
    ArrayConstant l -> listprint (toCArray m p . constantData) "," (arrayConstantValue l)
    _               -> toC m p cd

-- A left value renders as its payload.
instance ToC (LeftValue PrettyPrintSemanticInfo) where
    toC m p = toC m p . leftValueData

-- Renders a left value.
--
-- An array element reference is not printed as "<array>[<index>]" directly.
-- Instead, 'insertIndex' folds an index into the *name* of the underlying
-- variable ("x" becomes "x[i]") and lowers that variable's type by one array
-- level, and the result is rendered again. 'show_variable' therefore sees an
-- already-indexed name; its FunOut handling checks for a trailing ']'.
-- The index itself is always rendered in 'ValueNeed_pl' context.
--
-- NOTE(review): the first clause of 'insertIndex' reads the index from the
-- *outer* 'leftArrayElemReference' bound in the 'toC' pattern. The second
-- clause shadows that name with the nested reference. Tracing a two-level
-- reference {arrayName = {arrayName = x, arrayIndex = i}, arrayIndex = j}
-- gives "x[j][i]", i.e. the outermost reference's index is written first.
-- Confirm this matches how the front end nests multi-dimensional indexing.
instance ToC (LeftValueData PrettyPrintSemanticInfo) where
    toC m p (VariableLeftValue v) = toC m p v
    toC m p (ArrayElemReferenceLeftValue leftArrayElemReference) = toC m p $ insertIndex (arrayName leftArrayElemReference) where
        -- Push the (outer) reference's index into the innermost variable name.
        insertIndex :: LeftValue PrettyPrintSemanticInfo -> LeftValue PrettyPrintSemanticInfo
        -- Base case: rename "x" to "x[<index>]" and drop one array level from its type.
        insertIndex (LeftValue (VariableLeftValue variable) semInf) = LeftValue (VariableLeftValue $
            variable {
                variableType = decrArrayDepth (variableType variable),
                variableName = (concat[variableName variable,"[",
                                       toC m ValueNeed_pl (arrayIndex leftArrayElemReference), "]"])
            }) semInf
        -- Nested reference: recurse into its array name, keeping its own index
        -- (this binding shadows the outer 'leftArrayElemReference').
        insertIndex (LeftValue (ArrayElemReferenceLeftValue leftArrayElemReference) semInf) = LeftValue (
            ArrayElemReferenceLeftValue $ leftArrayElemReference {
                arrayName  = (insertIndex (arrayName leftArrayElemReference)),
                arrayIndex = (arrayIndex leftArrayElemReference)
            }) semInf
-- An actual parameter renders as its payload.
instance ToC (ActualParameter PrettyPrintSemanticInfo) where
    toC m p = toC m p . actualParameterData
              
-- The incoming place is ignored. Inputs are always rendered as call inputs
-- ('FunctionCallIn_pl'); outputs are always rendered by address
-- ('AddressNeed_pl').
instance ToC (ActualParameterData PrettyPrintSemanticInfo) where
    toC m _ param = case param of
        InputActualParameter e  -> toC m FunctionCallIn_pl e
        OutputActualParameter l -> toC m AddressNeed_pl l

-- An expression renders as its payload.
instance ToC (Expression PrettyPrintSemanticInfo) where
    toC m p = toC m p . expressionData

instance ToC (ExpressionData PrettyPrintSemanticInfo) where
    toC m p (LeftValueExpression lv) = toC m p lv
    toC m p (ConstantExpression c)   = toC m p c
    toC m p (FunctionCallExpression call) = case call of
        -- Binary infix operators are printed fully parenthesised: "(a op b)".
        FunctionCall InfixOp _ op [lhs, rhs] _ ->
            concat ["(", render lhs, " ", op, " ", render rhs, ")"]
        -- Everything else, including infix operators without exactly two
        -- arguments, uses prefix call syntax: "f(a, b, ...)".
        FunctionCall _ _ fun args _ ->
            concat [fun, "(", listprint render ", " args, ")"]
      where
        render x = toC m p x

-- A procedure becomes a C function returning void. Input parameters precede
-- output parameters, and the body is indented with 'ind'.
instance ToC (Procedure PrettyPrintSemanticInfo) where
    toC m _ (Procedure procName inParams outParams body _) =
        "void " ++ procName ++ "(" ++ params ++ ")\n{\n" ++ ind (toC m Declaration_pl) body ++ "}\n"
      where
        params = listprint (toC m MainParameter_pl) ", " (inParams ++ outParams)

-- Declarations (each terminated by ";\n") followed by the block's program,
-- joined by a newline. An empty part is omitted by 'listprint'.
instance ToC (Block PrettyPrintSemanticInfo) where
    toC m p (Block decls prg _) = listprint id "\n" [declarations, toC m p prg]
      where
        declarations = concatMap (\d -> toC m Declaration_pl d ++ ";\n") decls

-- Formal parameters are rendered in parameter-list style under
-- 'MainParameter_pl' and in declaration style anywhere else. The
-- parameter's restrict flag is passed through to 'show_variable'.
instance ToC (FormalParameter PrettyPrintSemanticInfo) where
    toC m p (FormalParameter (Variable role ty nm _) restr) = show_variable m place role ty nm restr
      where
        place | p == MainParameter_pl = MainParameter_pl
              | otherwise             = Declaration_pl

-- A local declaration: the variable (in parameter-list style under
-- 'MainParameter_pl', otherwise in declaration style), optionally followed
-- by " = <initial value>". The default-array-size flag is not used here.
instance ToC (LocalDeclaration PrettyPrintSemanticInfo) where
    toC m p (LocalDeclaration v initExpr _) = toC m place v ++ initializer
      where
        place | p == MainParameter_pl = MainParameter_pl
              | otherwise             = Declaration_pl
        initializer = maybe "" (\e -> " = " ++ toC m ValueNeed_pl e) initExpr

-- An instruction renders as its payload.
instance ToC (Instruction PrettyPrintSemanticInfo) where
    toC m p = toC m p . instructionData

instance ToC (InstructionData PrettyPrintSemanticInfo) where
    -- "lhs = rhs;" with both sides in value context.
    toC m _ (AssignmentInstruction asg) =
        toC m ValueNeed_pl (assignmentLhs asg) ++ " = " ++ toC m ValueNeed_pl (assignmentRhs asg) ++ ";\n"
    -- "proc(arg1, arg2, ...);"; arguments decide their own place.
    toC m p (ProcedureCallInstruction call) =
        nameOfProcedureToCall call ++ "(" ++ args ++ ");\n"
      where
        args = listprint (toC m p) ", " (actualParametersOfProcedureToCall call)

-- Renders structured programs. Compound statements put braces on their own
-- lines and indent bodies with 'ind'. Loops are additionally wrapped in an
-- outer pair of braces; for parallel loops this scope also holds the
-- counter's declaration.
instance ToC (Program PrettyPrintSemanticInfo) where
    toC m p (Program body _) = case body of
        EmptyProgram (Empty _) -> ""
        PrimitiveProgram (Primitive instr _) -> toC m p instr
        SequenceProgram (Sequence progs _) -> listprint (toC m p) "" progs
        BranchProgram (Branch cond thenPrg elsePrg _) -> concat
            [ "if(", toC m ValueNeed_pl cond, ")\n{\n"
            , ind (toC m p) thenPrg
            , "}\nelse\n{\n"
            , ind (toC m p) elsePrg
            , "}\n"
            ]
        -- The whole condition block (declarations included) is emitted before
        -- the loop. Only its instructions are repeated at the end of the body.
        SequentialLoopProgram (SequentialLoop condVar condCalc loopBody _) -> scoped $ concat
            [ toC m p condCalc
            , "while(", toC m ValueNeed_pl condVar, ")\n"
            , "{\n"
            , ind (toC m p) loopBody
            , ind (toC m p) (blockInstructions condCalc)
            , "}\n"
            ]
        -- Emitted as a sequential C for-loop counting from 0 by 'step'.
        ParallelLoopProgram (ParallelLoop v num step prg _) ->
            let counter = toC m ValueNeed_pl v
            in scoped $ concat
                [ toC m Declaration_pl v, ";\nfor("
                , counter, " = 0; "
                , counter, " < ", toC m ValueNeed_pl num, "; "
                , counter, " += ", show step
                , ")\n{\n"
                , ind (toC m p) prg
                , "}\n"
                ]
      where
        scoped inner = concat ["{\n", ind id inner, "}\n"]

-- An absent value renders as the empty string.
instance ToC a => ToC (Maybe a) where
    toC m p = maybe "" (toC m p)

-- Elements are rendered one per line; empty renderings are skipped.
instance (ToC a) => ToC [a] where
    toC m p = listprint (toC m p) "\n"

----------------------
--   Type           --
----------------------

-- | Constructs whose imperative 'Type' can be computed.
class HasType a where
    typeof
        :: a    -- ^ construct to inspect
        -> Type -- ^ its type

-- A variable's type is the one stored in it.
instance (SemanticInfo t) => HasType (Variable t) where
    typeof (Variable _ varType _ _) = varType

instance (SemanticInfo t) => HasType (LeftValue t) where
    typeof = typeof . leftValueData

instance (SemanticInfo t) => HasType (LeftValueData t) where
    typeof lvd = case lvd of
        VariableLeftValue v -> typeof v
        -- Indexing strips one array level from the indexed value's type.
        ArrayElemReferenceLeftValue ref -> decrArrayDepth . typeof $ arrayName ref

instance (SemanticInfo t) => HasType (Constant t) where
    typeof = typeof . constantData

-- Scalar constants have fixed types; integer constants are signed 32-bit.
-- An array constant's type is an array of its element count whose element
-- type is that of the first element. All elements must share it, and the
-- array must be non-empty; both checks fire lazily, when the element type
-- is inspected.
instance (SemanticInfo t) => HasType (ConstantData t) where
    typeof (IntConstant _)   = Numeric ImpSigned S32
    typeof (FloatConstant _) = FloatType
    typeof (BoolConstant _)  = BoolType
    typeof arr@(ArrayConstant l) = ImpArrayType (Norm (length elems)) elemType
      where
        elems     = arrayConstantValue l
        elemTypes = map typeof elems
        elemType  = case elemTypes of
            [] -> codeGenerationError InternalError $ "Const array with 0 elements: " ++ show arr
            firstTy : _
                | all (firstTy ==) elemTypes -> firstTy
                | otherwise -> codeGenerationError InternalError $ "Different element types in constant array: " ++ show arr

instance (SemanticInfo t) => HasType (Expression t) where
    typeof = typeof . expressionData

instance (SemanticInfo t) => HasType (ExpressionData t) where
    typeof ed = case ed of
        LeftValueExpression lv    -> typeof lv
        ConstantExpression c      -> typeof c
        -- Function calls carry their result type explicitly.
        FunctionCallExpression fc -> typeOfFunctionToCall fc

instance (SemanticInfo t) => HasType (ActualParameter t) where
    typeof = typeof . actualParameterData

instance (SemanticInfo t) => HasType (ActualParameterData t) where
    typeof param = case param of
        InputActualParameter e  -> typeof e
        OutputActualParameter l -> typeof l

----------------------
-- Helper functions --
----------------------

-- | Apply a renderer and indent every resulting line by four spaces.
-- The result is newline-terminated (via 'unlines'). Blank lines also
-- receive the indentation, and an empty rendering stays empty.
ind :: (a -> String) -> a -> String
ind render = unlines . map indentLine . lines . render
  where
    indentLine l = "    " ++ l

-- | Render every element, drop renderings that are empty, and join the rest
-- with the given separator.
listprint :: (a -> String) -> String -> [a] -> String
listprint f sep = joinWith . filter (not . null) . map f
  where
    joinWith []       = ""
    joinWith (y : ys) = y ++ concatMap (sep ++) ys

-- | View an actual parameter as an expression.
--
-- Input parameters are returned unchanged. Output parameters are wrapped as
-- a left-value expression. No expression-level semantic info is available
-- here, so that field is a placeholder 'InternalError'. It fires with a
-- descriptive message if the field is ever evaluated, instead of a bare
-- 'undefined'.
parameterToExpression :: (SemanticInfo t) => ActualParameter t -> Expression t
parameterToExpression (ActualParameter (InputActualParameter e) _) = e
parameterToExpression (ActualParameter (OutputActualParameter lv) _) =
    -- TODO: derive real semantic info for the wrapped expression.
    Expression (LeftValueExpression lv) $
        codeGenerationError InternalError "parameterToExpression: no semantic info for an output parameter"

-- | Element type of an array type; an internal error for non-array types.
decrArrayDepth :: Type -> Type
decrArrayDepth ty = case ty of
    ImpArrayType _ elemTy -> elemTy
    _ -> codeGenerationError InternalError $ "A variable is indexed, but not array!"

-- | 'True' for the scalar types listed below, 'False' for arrays.
simpleType :: Type -> Bool
simpleType ty = case ty of
    ImpArrayType _ _ -> False
    BoolType         -> True
    Numeric _ _      -> True
    FloatType        -> True
    UserType _       -> True

-- | Extract the left value of a left-value expression; any other expression
-- is an internal error.
toLeftValue :: (SemanticInfo t) => Expression t -> LeftValue t
toLeftValue e = case e of
    Expression (LeftValueExpression lv) _ -> lv
    _ -> codeGenerationError InternalError $ show e ++ " is not a left value."

-- | Does the expression mention a variable with the given name? Array index
-- expressions and function-call arguments are searched as well.
contains :: (SemanticInfo t) => String -> Expression t -> Bool
contains n (Expression exprData _) = case exprData of
    LeftValueExpression lv -> mentionedIn (leftValueData lv)
    ConstantExpression _ -> False
    FunctionCallExpression call -> any (contains n) (actualParametersOfFunctionToCall call)
  where
    mentionedIn (VariableLeftValue (Variable _ _ n' _)) = n == n'
    mentionedIn (ArrayElemReferenceLeftValue ref) =
        mentionedIn (leftValueData (arrayName ref)) || contains n (arrayIndex ref)

-- | Name of the variable underlying a left value, looking through any
-- number of array element references.
getVarName :: (SemanticInfo t) => LeftValue t -> String
getVarName (LeftValue lvd _) = case lvd of
    VariableLeftValue (Variable _ _ nm _) -> nm
    ArrayElemReferenceLeftValue ref -> getVarName (arrayName ref)