{-
 - Copyright (c) 2009, ERICSSON AB All rights reserved.
 - 
 - Redistribution and use in source and binary forms, with or without
 - modification, are permitted provided that the following conditions
 - are met:
 - 
 -     * Redistributions of source code must retain the above copyright
 -     notice,
 -       this list of conditions and the following disclaimer.
 -     * Redistributions in binary form must reproduce the above copyright
 -       notice, this list of conditions and the following disclaimer
 -       in the documentation and/or other materials provided with the
 -       distribution.
 -     * Neither the name of the ERICSSON AB nor the names of its
 -     contributors
 -       may be used to endorse or promote products derived from this
 -       software without specific prior written permission.
 - 
 - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 -}

module Feldspar.Compiler.Transformation.GraphToImperative where

import Feldspar.Core.Graph
import Feldspar.Core.Types hiding (typeOf)
import Feldspar.Compiler.Imperative.Representation hiding (Array)
import Feldspar.Compiler.Transformation.GraphUtils
import Data.List
import qualified Data.Map as Map

-- | Transforms a hierarchical graph to a list of imperative functions.
--
-- The first function is compiled from the whole graph and carries the given
-- name; it is followed by one function for every 'NoInline' node that
-- 'collectSources' finds in the graph.
graphToImperative :: String -> HierarchicalGraph -> [ImpFunction]
graphToImperative s g =
    [ transformSourceToImpFunction src | src <- mainSource : collectSources g ]
  where
    -- source of the top-level function: the interface and hierarchy of the
    -- whole graph
    mainSource = ImpFunctionSource
        { functionName = s
        , interface    = hierGraphInterface g
        , hierarchy    = graphHierarchy g
        }

-- | All data needed to transform a (sub)graph into one imperative function.
-- The field order matters: values are also built and matched positionally.
data ImpFunctionSource = ImpFunctionSource
    { functionName :: String     -- ^ name of the generated function
    , interface    :: Interface  -- ^ interface of the (sub)graph
    , hierarchy    :: Hierarchy  -- ^ node hierarchy of the (sub)graph
    }

-- Debugging aid only: a source is shown as its function name.
instance Show ImpFunctionSource where
    show = functionName

-- | Structures that may contain non-inlined functions.
-- 'collectSources' walks through the graph and collects the interfaces
-- and hierarchies of 'NoInline' nodes.
class Collect t where
    -- | One 'ImpFunctionSource' per 'NoInline' node contained in the argument.
    collectSources :: t -> [ImpFunctionSource]

-- A hierarchical graph is searched through its top-level hierarchy.
instance Collect HierarchicalGraph where
    collectSources = collectSources . graphHierarchy

-- A hierarchy is searched through its list of (node, sub-hierarchies) pairs.
instance Collect Hierarchy where
    collectSources h = case h of
        Hierarchy nodePairs -> collectSources nodePairs

-- A list contributes the sources of all its elements, in order.
instance (Collect t) => Collect [t] where
    collectSources = concatMap collectSources

-- A node contributes its own source when it is a 'NoInline' node (which must
-- carry exactly one sub-hierarchy), followed by the sources found inside its
-- sub-hierarchies.
instance Collect (Node,[Hierarchy]) where
    collectSources (node, subs) = fromNode ++ collectSources subs
      where
        fromNode = case (function node, subs) of
            (NoInline fname ifc, [h]) -> [ImpFunctionSource fname ifc h]
            (NoInline _ _, _)         -> error $ "Graph error: malformed hierarchy list in the 'NoInline' node with id " ++ show (nodeId node)
            _                         -> []

-- Transforms an interface and a hierarchy to an imperative function.
--   * every top-level node of the hierarchy is turned into declarations
--   * declarations of the interface's input node become the input
--     parameters; all other declarations become locals
--   * output parameters are generated from the interface output type
--   * every top-level node is compiled to a 'Program'; the body ends with
--     copies from the interface output into the output parameters
transformSourceToImpFunction :: ImpFunctionSource -> ImpFunction
transformSourceToImpFunction (ImpFunctionSource n ifc (Hierarchy pairs))
    = Fun
    { funName       = n
    , inParameters  = inputDecls
    , outParameters = outputDecls
    , prg
        = CompPrg
        { locals    = localDecls
        , body      = Seq ( map transformNodeToProgram pairs
                         ++ copyToOutput (interfaceOutput ifc) (interfaceOutputType ifc) True) []
        }
    } where
        (inputDecls, localDecls) = partition isInputDecl declarations
        -- Declaration names are the node prefix ("var" ++ node id) optionally
        -- followed by a '_'-separated tuple path. Match the input node's
        -- prefix exactly or followed by '_': a plain prefix test would also
        -- claim e.g. "var12" as an input when the input node is 1.
        isInputDecl d = declName == inPrefix || (inPrefix ++ "_") `isPrefixOf` declName
            where
                declName = name $ var d
        inPrefix = varPrefix $ interfaceInput ifc
        -- one output parameter per leaf of the interface output type
        outputDecls = tupleWalk transformSourceToDecl $ interfaceOutputType ifc
        transformSourceToDecl path typ
            = Decl
            { var       = Var (outName path) OutKind ctyp
            , declType  = ctyp
            , initVal   = Nothing
            , semInfVar = unknownSemInfVar
            } where
                ctyp = compileStorableType typ
        declarations = concatMap transformNodeToDeclaration topLevelNodes
        topLevelNodes = map fst pairs

-- Transforms a node to declarations. One declaration is generated for every
-- leaf of the tuple structure of the node's output type:
--   * variable name: 'varName' of the node id and the leaf's tuple path
--   * variable type: the compiled type of the leaf
--   * initial value: the array constant for 'Array' nodes, none otherwise
transformNodeToDeclaration :: Node -> [Declaration]
transformNodeToDeclaration n = tupleWalk genDecl $ tupleZip (outTyps, initVals) where
    genDecl path (typ, ini)
        = Decl
        { var       = Var (varName (nodeId n) path) Normal ctyp
        , declType  = ctyp
        , initVal   = ini
        , semInfVar = unknownSemInfVar
        } where
            ctyp = compileStorableType typ
    outTyps = outputType n
    initVals = case function n of
        -- an array node holds a single constant, used as the initial value
        Array d -> case outTyps of
            One t -> One $ Just $ compileStorableData d t
            _     -> error "Error: malformed output type of array node."
{-      Disabled: initial values for 'While' nodes.
        While ifc1 ifc2 -> fmap (\(d,t) -> Just $ transformSourceToExpr d t) $ tupleZip (input n, outTyps)
            initPart = case input n of
                Tup [cond,ini] -> ini
                _ -> error "Error in while loop: malformed input."
-}
        _ -> genNothingTuple outTyps
    -- same tuple structure as the type, with no initial value at any leaf
    genNothingTuple (One _)  = One Nothing
    genNothingTuple (Tup xs) = Tup $ map genNothingTuple xs


-- | Declarations of all the given nodes, in node order.
transformNodeListToDeclarations :: [Node] -> [Declaration]
transformNodeListToDeclarations = concatMap transformNodeToDeclaration

-- Transforms a node and its subgraphs (if any) to an imperative program.
--
-- Dispatches on the node's 'function':
--   * 'Input' and 'Array' nodes generate no code ('Empty'). The constant of
--     an 'Array' node is emitted as the initial value of its declaration by
--     'transformNodeToDeclaration' instead.
--   * 'Function' and 'NoInline' nodes become a 'CFun' call. The call takes
--     the node input as 'In' arguments and writes the node's own variables
--     through 'Out' arguments.
--   * 'IfThenElse', 'While' and 'Parallel' nodes compile their embedded
--     hierarchies ('hs') into nested 'CompleteProgram's.
-- Unexpected node shapes abort with 'error'. These include a wrong input
-- structure or a wrong number of hierarchies.
transformNodeToProgram :: (Node, [Hierarchy]) -> Program
transformNodeToProgram (n,hs) = case function n of
    -- input nodes are filled from outside (function parameters, or copies
    -- made by the enclosing construct)
    Input           -> Empty
    -- array constants become declaration initialisers
    Array _         -> Empty
    -- primitive function node:
        -- inputs are passed as 'In' arguments
        -- results are written to this node's variables via 'Out' arguments
    Function s      -> Primitive
                            (CFun s $ passInArgs (input n) (inputType n) ++ passOutArgs (nodeId n) (outputType n))
                            (SemInfPrim Map.empty False)
    -- non-inlined function node:
        -- call the non-inlined function
        -- actual arguments come from the node input and the node id
    NoInline s ifc  -> Primitive
                            (CFun s $ passInArgs (input n) (inputType n) ++ passOutArgs (nodeId n) (outputType n))
                            (SemInfPrim Map.empty False)
    -- conditional node:
        -- condition: first element of the input tuple
        -- then branch: compiled from the first interface and the first hierarchy
        -- else branch: compiled from the second interface and the second hierarchy
    Feldspar.Core.Graph.IfThenElse thenIfc elseIfc -> case hs of
        [thenH, elseH] -> case (input n, inputType n) of
            (Tup [cond, inp], Tup [One condTyp, inTyp])
                -- both branches must accept the data part of the node input
                | interfaceInputType thenIfc /= inTyp || interfaceInputType elseIfc /= inTyp
                    -> error "Error in 'ifThenElse' node: incorrect interface input type."
                -- NOTE(review): this guard checks the type of the condition
                -- input, although the message mentions the node output
                | compileStorableType condTyp /= Feldspar.Compiler.Imperative.Representation.BoolType
                    -> error "Error in 'ifThenElse' node: node output is expected to be 'Bool'."
                | otherwise -> Feldspar.Compiler.Imperative.Representation.IfThenElse
                    condVar                         -- condition variable
                    (mkBranch n thenIfc thenH)      -- then part
                    (mkBranch n elseIfc elseH)      -- else part
                    []                              -- semantic info
                        where
                            -- A branch declares its own nodes, copies 'inp'
                            -- into its interface input node, runs its nodes,
                            -- then copies its interface output into the
                            -- variables of the conditional node (marked as
                            -- output copying).
                            mkBranch :: Node -> Interface -> Hierarchy -> CompleteProgram
                            mkBranch n ifc h@(Hierarchy pairs) = CompPrg
                                (transformNodeListToDeclarations $ map fst pairs)
                                (Seq (copyResult inp (interfaceInput ifc) inTyp False
                                     ++ transformNodeListToPrograms pairs
                                     ++ copyResult (interfaceOutput ifc) (nodeId n) (outputType n) True)
                                 [])
                            -- only a variable condition is supported
                            condVar = case cond of
                                One (Variable (id,path)) -> Var (varName id path) Normal Feldspar.Compiler.Imperative.Representation.BoolType
                                _ -> error "Error in 'ifThenElse' node: condition is not a variable."
                                    -- TODO: it seems that in case of constant condition the program is already simplified on the graph level
            otherwise -> error $ "Error in 'ifThenElse' node: incorrect node input or node input type"
        otherwise -> error $ "Error in 'ifThenElse' node: two hierarchies expected, found " ++ show (length hs)
    -- while node:
        -- state variables: id of the while node
        -- condition calculation: first interface and hierarchy
            -- input gets the state
        -- condition: output of condition calculation
        -- body: second interface and hierarchy
            -- input gets the state
            -- output is written back to the state
    While condIfc bodyIfc   -> Seq
        -- initialise the state (this node's variables) from the node input
        (copyResult (input n) (nodeId n) (outputType n) False ++
        [SeqLoop
            -- condition variable:
            (case interfaceOutput condIfc of
                One (Variable (id,path)) -> Var (varName id path) Normal Feldspar.Compiler.Imperative.Representation.BoolType
                _ -> error "Error in a while loop: Malformed interface output of condition calculation." 
                    -- TODO: should this hold?
            )
            -- condition calculation (CompleteProgram)
            (CompPrg
                (transformNodeListToDeclarations condNodes)
                (Seq (copyStateToCond ++ calculationCond) [])
            )
            -- loop body (CompleteProgram)
            (CompPrg
                (transformNodeListToDeclarations bodyNodes)
                (Seq (copyStateToBody ++ calculationBody ++ copyResultToState) [])
            )
            -- semantic info (SemInfSeqLoop)
            []
        ]) [] where
                (Hierarchy condHier, Hierarchy bodyHier) = case hs of
                    [c,b]   -> (c,b)
                    _       -> error $ "Error in a while node: expected 2 hierarchies, but found " ++ show (length hs)
                condNodes = map fst condHier
                bodyNodes = map fst bodyHier
                copyStateToCond = copyNode (nodeId n) (interfaceInput condIfc) (outputType n) False
                calculationCond = transformNodeListToPrograms condHier
                copyStateToBody = copyNode (nodeId n) (interfaceInput bodyIfc) (outputType n) False
                calculationBody = transformNodeListToPrograms bodyHier
                copyResultToState = copyResult (interfaceOutput bodyIfc) (nodeId n) (outputType n) True
                -- initState = tupleWalk genInitCopy tupleZip (input n, outputType n)
                -- genInitCopy path (i,t) =
    -- parallel node:
        -- number of iterations: taken from the (single) node input; the
            -- first parameter of the 'Parallel' constructor is currently ignored
        -- index variable: input node of the embedded graph
        -- body: the other nodes of the embedded graph, followed by a copy of
            -- the interface output into element <index> of this node's variable
    Parallel _ ifc  ->
        -- arguments: index variable, iteration count, the literal 1
        -- (NOTE(review): presumably the step; confirm against 'ParLoop'),
        -- loop body, semantic info
        ParLoop (Var (varName inpId []) Normal $ Numeric ImpSigned S32) num 1 prg []  where
            num = case (input n, inputType n) of
                (One inp, One intyp)    -> transformSourceToExpr inp intyp
                otherwise               -> error "Invalid input of a Parallel node."
            -- NOTE(review): also reached when there is no hierarchy at all,
            -- although the message says "More than one"
            hist = case hs of
                [(Hierarchy hist)] -> hist
                _                  -> error "More than one Hierarchy in a Parallel construct"  
            isInp (node,hs) = case (function node) of
                Input -> True
                _     -> False
            -- the single input node provides the loop index; every other
            -- node forms the loop body
            (inps,notInps) = partition isInp hist
            -- NOTE(review): also reached when there is no input node, although
            -- the message says "More than one"
            inpId = case inps of
                [(node,hs)] -> nodeId node
                _           -> error "More than one input node inside the Hierarchy of a Parallel construct" 
            topLevelNodes = map fst notInps 
            declarations = concatMap transformNodeToDeclaration topLevelNodes
            outSrc = case interfaceOutput ifc of
                One src -> src
                _       -> error "The interfaceOutput of a Parallel is not (One ...) "
            outTyp = case interfaceOutputType ifc of
                One typ -> typ
                _       -> error "The interfaceOutputType of a Parallel is not (One ...) "
            prg = CompPrg
                { locals = declarations
                , body   = Seq  ( map transformNodeToProgram notInps ++
                                  -- copy the body result into element <index>
                                  -- of this node's (array) variable
                                  [ Primitive ( makeCopyFromExprs
                                                    (transformSourceToExpr outSrc outTyp)
                                                    (Expr (LeftExpr $ ArrayElem (LVar (Var (varName (nodeId n) []) Normal intType)) (Expr (genVar inpId [] intType) intType)) intType) -- TODO: fix the type
                                              )
                                    (SemInfPrim Map.empty True)
                                  ]
                                ) []
{-
                                  [ Primitive 
                                    (Assign
                                        (ArrayElem (LVar (Var (varName (nodeId n) []))) (Expr (genVar inpId []) intType)) 
                                        (transformSourceToExpr outSrc outTyp)
                                    ) (SemInfPrim Map.empty True)
                                  ]
                                ) []
-}
                }

-- | Programs of all the given (node, sub-hierarchies) pairs, in order.
transformNodeListToPrograms :: [(Node, [Hierarchy])] -> [Program]
transformNodeListToPrograms = map transformNodeToProgram

-- | Common prefix of all variables belonging to the given node id:
-- @"var"@ followed by the shown id.
varPrefix :: NodeId -> String
varPrefix = ("var" ++) . show

-- | Encodes a variable's location inside the tuple structure of its node:
-- every path index is rendered as @'_'@ followed by the index, e.g.
-- @[0,2]@ becomes @"_0_2"@ and @[]@ becomes @""@.
varPath :: [Int] -> String
varPath = concatMap (('_' :) . show)

-- | Name of the variable of node @nid@ at tuple path @path@.
varName :: NodeId -> [Int] -> String
varName nid = (varPrefix nid ++) . varPath

-- | Untyped left-value expression for the 'Normal' variable of node @nid@
-- at tuple path @path@, with imperative type @typ@.
genVar :: NodeId -> [Int] -> Type -> UntypedExpression
genVar nid path typ = LeftExpr (LVar (Var (varName nid path) Normal typ))

-- | Name prefix shared by all output parameters of a generated function.
outPrefix :: String
outPrefix = "out"

-- | Generates the name of the output parameter at the given tuple path.
outName :: [Int] -> String
outName = (outPrefix ++) . varPath

-- | Untyped left-value expression for the output parameter ('OutKind') at
-- tuple path @path@, with imperative type @typ@.
genOut :: [Int] -> Type -> UntypedExpression
genOut path typ = LeftExpr (LVar (Var (outName path) OutKind typ))

-- | Generates the input parameters of a function call from the node input:
-- one 'In' argument per leaf of the input tuple. Each leaf is compiled with
-- 'transformSourceToExpr', so constants and variables are translated exactly
-- as everywhere else in this module.
passInArgs :: Tuple Source -> Tuple StorableType -> [Parameter]
passInArgs tup typs = tupleWalk genArg $ tupleZip (tup, typs) where
    genArg _ (src, typ) = In $ transformSourceToExpr src typ

-- | Generates the output parameters of a function call: one 'Out' argument
-- per leaf of the output type, referring to node @nid@'s variable at that
-- leaf's tuple path.
passOutArgs :: NodeId -> Tuple StorableType -> [Parameter]
passOutArgs nid = tupleWalk outArg
  where
    outArg path t =
        let ctyp = compileStorableType t
        in Out (Normal, Expr (genVar nid path ctyp) ctyp)

-------------------------------------------------
-- Compilation of type and data representation --
-------------------------------------------------

-- | Transforms a 'StorableType' to an imperative 'Type'. Every dimension
-- wraps the element type in one 'ImpArrayType' layer, with the first
-- dimension outermost.
compileStorableType :: StorableType -> Type
compileStorableType (StorableType dims elemTyp) =
    foldr (ImpArrayType . Just) (compilePrimitiveType elemTyp) dims

-- | Transforms a 'PrimitiveType' to an imperative 'Type'.
-- The unit type is represented by the imperative boolean type.
compilePrimitiveType :: PrimitiveType -> Type
compilePrimitiveType UnitType                      = Feldspar.Compiler.Imperative.Representation.BoolType
compilePrimitiveType Feldspar.Core.Types.BoolType  = Feldspar.Compiler.Imperative.Representation.BoolType
compilePrimitiveType IntType                       = Numeric ImpSigned S32
-- TODO: think about the imperative typesystem!
compilePrimitiveType Feldspar.Core.Types.FloatType = Feldspar.Compiler.Imperative.Representation.FloatType

-- | Transforms an array or primitive data to an imperative constant.
-- Arrays are converted element by element into an 'ArrayConst'.
compileStorableDataToConst :: StorableData -> Constant
compileStorableDataToConst sd = case sd of
    PrimitiveData pd      -> compilePrimDataToConst pd
    StorableData len elems -> ArrayConst len (map compileStorableDataToConst elems)

-- | Transforms a primitive data to an imperative constant.
compilePrimDataToConst :: PrimitiveData -> Constant
compilePrimDataToConst pd = case pd of
    UnitData    -> BoolConst False  -- unit is represented as a boolean
    BoolData b  -> BoolConst b
    IntData i   -> IntConst i
    FloatData f -> FloatConst f     -- TODO

-- | Transforms an array or primitive data to an imperative typed expression.
-- A primitive value is typed by the element type of the storable type; an
-- array becomes an array constant typed by the full storable type.
compileStorableData :: StorableData -> StorableType -> ImpLangExpr
compileStorableData sd typ = case sd of
    PrimitiveData pd -> case typ of
        StorableType _ elemTyp -> compilePrimData pd elemTyp
    StorableData{}   -> Expr (ConstExpr (compileStorableDataToConst sd)) (compileStorableType typ)

-- | Transforms a primitive data to an imperative constant expression typed
-- by the compiled primitive type.
compilePrimData :: PrimitiveData -> PrimitiveType -> ImpLangExpr
compilePrimData d t = Expr constant ctyp
  where
    constant = ConstExpr (compilePrimDataToConst d)
    ctyp     = compilePrimitiveType t

-- | Signed numeric imperative type of size 'S8'.
charType :: Type
charType = Numeric ImpSigned S8

-- | Signed numeric imperative type of size 'S32'; the same type that
-- 'IntType' compiles to in 'compilePrimitiveType'.
intType :: Type
intType = Numeric ImpSigned S32

-- | Transforms a 'Source' to an imperative expression. A constant is typed
-- by the element type of the storable type. A variable refers to the
-- 'Normal' variable of its node at its tuple path, typed by the full
-- storable type.
transformSourceToExpr :: Source -> StorableType -> ImpLangExpr
transformSourceToExpr src typ = case src of
    Constant primData -> case typ of
        StorableType _ elemTyp -> compilePrimData primData elemTyp
    Variable (nid, path) ->
        let ctyp = compileStorableType typ
        in Expr (genVar nid path ctyp) ctyp

-- | Generates a copy call between two node variables, each given as
-- (node id, tuple path, storable type).
makeCopyFromIds :: (NodeId,[Int],StorableType) -> (NodeId,[Int],StorableType) -> Instruction
makeCopyFromIds (idFrom, pathFrom, typeFrom) (idTo, pathTo, typeTo) =
    makeCopyFromExprs (varExpr idFrom pathFrom typeFrom) (varExpr idTo pathTo typeTo)
  where
    varExpr nid path styp =
        let ctyp = compileStorableType styp
        in Expr (genVar nid path ctyp) ctyp

-- | Generates a call to @copy@ with the source as 'In' argument and the
-- destination as a 'Normal' 'Out' argument.
makeCopyFromExprs :: ImpLangExpr -> ImpLangExpr -> Instruction
makeCopyFromExprs src dst = CFun "copy" args
  where
    args = [In src, Out (Normal, dst)]

-- | Generates copies for all variables of a node to all variables of another
-- node: one copy per leaf of the type structure, between the variables at
-- the same tuple path. @isOutputCopying@ is recorded in each copy's
-- semantic info.
copyNode :: NodeId -> NodeId -> Tuple StorableType -> Bool -> [Program]
copyNode fromId toId typeStructure isOutputCopying = tupleWalk copyLeaf typeStructure
  where
    copyLeaf path typ = Primitive copyCall semInf
      where
        copyCall = makeCopyFromIds (fromId, path, typ) (toId, path, typ)
        semInf   = SemInfPrim Map.empty isOutputCopying

-- | Generates copies from sources to all variables of a node.
--
-- Every leaf of the source tuple is copied into node @nid@'s variable at the
-- same tuple path. @isOutputCopying@ is recorded in each copy's semantic info.
copyResult :: Tuple Source -> NodeId -> Tuple StorableType -> Bool -> [Program]
copyResult ifcOut nid outTyp isOutputCopying =
    tupleWalk
        (\path (out,typ) ->
            -- compile the leaf type once and use it for the destination
            let ctyp = compileStorableType typ
            in Primitive
                (makeCopyFromExprs (transformSourceToExpr out typ) (Expr (genVar nid path ctyp) ctyp))
                (SemInfPrim Map.empty isOutputCopying)
        )
        (tupleZip (ifcOut, outTyp))

-- | Generates copies from sources to output variables.
--
-- Every leaf of the source tuple is copied into the output parameter at the
-- same tuple path (see 'genOut'). @isOutputCopying@ is recorded in each
-- copy's semantic info.
copyToOutput :: Tuple Source -> Tuple StorableType -> Bool -> [Program]
copyToOutput ifcOut outTyp isOutputCopying =
    tupleWalk
        (\path (out,typ) ->
            -- compile the leaf type once and use it for the destination
            let ctyp = compileStorableType typ
            in Primitive
                (makeCopyFromExprs (transformSourceToExpr out typ) (Expr (genOut path ctyp) ctyp))
                (SemInfPrim Map.empty isOutputCopying)
        )
        (tupleZip (ifcOut, outTyp))