--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE TypeFamilies #-}

module Feldspar.Compiler.Plugins.HandlePrimitives
    ( HandlePrimitives(..)
    , makeAssignment
    , makePrimitive
    ) where


import Data.List (find)

import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Imperative.Semantics (SemanticInfo)
import Feldspar.Compiler.Imperative.CodeGeneration (simpleType, typeof, toLeftValue)
import Feldspar.Compiler.PluginArchitecture (TransformationPhase(..), Plugin(..), InfoFromPrimitiveParts(..), InfoFromProcedureParts(..))
import Feldspar.Compiler.PluginArchitecture.DefaultConvert (Combine(..))
import Feldspar.Compiler.Options
import Feldspar.Compiler.Error


-- | Report an 'InternalError' originating from this plugin, via 'handleError',
-- with the given message. Usable at any result type.
handlePrimitivesError :: String -> a
handlePrimitivesError msg = handleError "PluginArch/HandlePrimitives" InternalError msg


-- | Transformation phase that looks for calls to the @trace@ primitive and,
-- in every procedure whose body contains one, inserts @traceStart@ /
-- @traceEnd@ calls around the body (see @transformProcedure'@).
data HandleTraceFunctions = HandleTraceFunctions


-- Nothing is passed downwards; upwards a 'Bool' reports whether a @trace@
-- call was seen. 'From' and 'To' are both (), matching the @Procedure ()@
-- values the phase functions operate on.
instance TransformationPhase HandleTraceFunctions where
    type Downwards HandleTraceFunctions = ()
    type Upwards HandleTraceFunctions = Bool
    type From HandleTraceFunctions = ()
    type To HandleTraceFunctions = ()
    transformProcedure phase down p info =
        transformProcedure' phase down p info
    upwardsPrimitiveProgramInProgram phase down prim info construction =
        upwardsPrimitiveProgramInProgram' phase down prim info construction


-- | Transformation phase that lowers primitive procedure calls into
-- platform-specific code (see @transformPrimitive'@).
data HandlePrimitives = HandlePrimitives


-- Downwards: (default array size, target platform). Nothing flows upwards.
-- 'From' and 'To' are both (), matching the @Primitive ()@ values handled.
instance TransformationPhase HandlePrimitives where
    type Downwards HandlePrimitives = (Int,Platform)
    type Upwards HandlePrimitives = ()
    type From HandlePrimitives = ()
    type To HandlePrimitives = ()
    transformPrimitiveProgramInProgram phase down prim info =
        transformPrimitive' phase down prim info


-- Runs the plugin on a procedure. With 'NoPrimitiveInstructionHandling' the
-- procedure is returned untouched. Otherwise 'HandleTraceFunctions' runs
-- first, and then 'HandlePrimitives' runs with the default array size and
-- the platform. The upward information of both phases is discarded.
instance Plugin HandlePrimitives where
    type ExternalInfo HandlePrimitives = (Int, DebugOption, Platform)
    executePlugin _ (defArrSize, debugOpt, platform) procedure =
        case debugOpt of
            NoPrimitiveInstructionHandling -> procedure
            _ -> fst $ executeTransformationPhase HandlePrimitives (defArrSize, platform) traced
      where
        traced = fst $ executeTransformationPhase HandleTraceFunctions () procedure



-- Upward 'Bool' information is merged with logical or.
-- NOTE(review): orphan instance. Neither 'Combine' (imported from
-- DefaultConvert) nor 'Bool' is defined in this module; consider moving it
-- next to the 'Combine' class.
instance Combine Bool where
    combine = (||)
        
        

-- | Upward pass of 'HandleTraceFunctions': is this primitive a call to the
-- @trace@ procedure?
--
-- Results are presumably merged across the procedure body with the
-- 'Combine' 'Bool' instance (logical or), so the procedure-level value tells
-- whether any @trace@ call occurs (confirm against the plugin framework).
--
-- A primitive whose instruction is not a procedure call (e.g. an
-- assignment) cannot be a @trace@ call and yields 'False'. Previously a
-- partial lambda crashed on such primitives.
upwardsPrimitiveProgramInProgram' :: HandleTraceFunctions -> () -> Primitive () -> InfoFromPrimitiveParts HandleTraceFunctions -> ProgramConstruction () -> Bool
upwardsPrimitiveProgramInProgram' _ _ old _ _ =
    case instructionData (primitiveInstruction old) of
        ProcedureCallInstruction call -> nameOfProcedureToCall call == "trace"
        _                             -> False



-- | Procedure-level step of 'HandleTraceFunctions'.
--
-- If the upward information says the body contains a @trace@ call, the
-- body's program is bracketed by @traceStart@ and @traceEnd@ procedure calls:
--
-- * when the program is a sequence, @traceStart@ is prepended and
--   @traceEnd@ appended to it;
-- * otherwise the program is wrapped in a new sequence
--   @[traceStart, program, traceEnd]@.
--
-- Procedures without @trace@ calls are returned unchanged.
transformProcedure' :: HandleTraceFunctions -> () -> Procedure () -> InfoFromProcedureParts HandleTraceFunctions -> Procedure ()
transformProcedure' _ _ old upwardInfos
    | not (upwardsInfoFromProcedureBody upwardInfos) = old
    | otherwise = old { procedureBody = body { blockInstructions = prog { programConstruction = bracketed } } }
  where
    body = procedureBody old
    prog = blockInstructions body
    bracketed = case programConstruction prog of
        SequenceProgram sq -> SequenceProgram sq { sequenceProgramList = traceStart : sequenceProgramList sq ++ [traceEnd] }
        _                  -> SequenceProgram (Sequence [traceStart, prog, traceEnd] ())
    traceStart = traceCall "traceStart"
    traceEnd   = traceCall "traceEnd"
    -- A primitive program calling the named procedure with no arguments.
    traceCall fn = Program (PrimitiveProgram $ Primitive (Instruction (ProcedureCallInstruction $ ProcedureCall fn [] ()) ()) ()) ()



-- | Primitive-level step of 'HandlePrimitives': lower one primitive call into
-- code for the platform @pfm@.
--
-- Built-in primitives:
--
-- * @(!)@ @[arr, idx] -> [out]@: assign the element @arr[idx]@ to @out@.
-- * @setIx@ @[original, idx, val] -> [result]@: assign @original@ to
--   @result@, then assign @val@ to @result[idx]@.
-- * @copy@ @[in1] -> [out]@: assign @in1@ to @out@ (see 'makeAssignment').
-- * @trace@ @[label, original] -> [result]@: assign @original@ to @result@,
--   then call the @trace@ primitive procedure with @result@ and @label@.
--
-- Any other call is looked up in @primitives pfm@ by name and input types.
-- A @Left@ entry is a C primitive description used directly. A @Right@
-- entry is a function that expands the call into a list of primitives. If
-- nothing matches, the recursively transformed instruction is kept as it is.
--
-- @defArrSize@ is the length assumed for array dimensions of undefined
-- length whenever an assignment has to become an array copy.
transformPrimitive' :: HandlePrimitives -> (Int,Platform) -> Primitive () -> InfoFromPrimitiveParts HandlePrimitives -> ProgramConstruction ()
transformPrimitive' _ (defArrSize,pfm) old modified'
    = case (nameS, inps, outs) of

        ("(!)", [arr, idx], [out])
            -> mkPrg $ makeAssignment pfm
                (lToe $ LeftValue
                    (ArrayElemReferenceLeftValue $ ArrayElemReference
                        (toLeftValue arr) idx ()
                    ) ()
                ) out defArrSize

        ("setIx", [original, idx, val], [result])
            -> SequenceProgram $ Sequence
                [ Program (mkPrg $ makeAssignment pfm original result defArrSize) ()
                , Program (mkPrg $ makeAssignment pfm val
                        (LeftValue (ArrayElemReferenceLeftValue $ ArrayElemReference result idx ()) ())
                        defArrSize
                    ) ()
                ] ()

        ("copy", [in1], [out]) -> mkPrg $ makeAssignment pfm in1 out defArrSize

        ("trace", [label, original], [result])
            -> SequenceProgram $ Sequence
                [ Program (mkPrg $ makeAssignment pfm original result defArrSize) ()
                , Program (mkPrg $ makePrimitive pfm (Proc "trace" firstInFP) [lToe result, label] [] 0) ()
                ] ()

        _ -> case find matchPrimitive (primitives pfm) of
              -- Platform-specific expansion into several primitives.
              Just (fd, Right expand) -> SequenceProgram (Sequence plist ())
                where
                  -- Pass defArrSize, as every other branch does, so that an
                  -- expanded 'Assig' of an array with undefined length copies
                  -- the right number of elements. Passing 0 here made such
                  -- copies empty.
                  plist = [ Program (mkPrg $ makePrimitive pfm cd' inps' outs' defArrSize) ()
                          | (cd', inps', outs') <- expand fd inps outs ]
              -- Direct mapping onto a single C primitive.
              Just (_, Left cd)       -> mkPrg $ makePrimitive pfm cd inps outs defArrSize
              -- Unknown primitive: keep the recursively transformed instruction.
              Nothing                 -> mkPrg modified

  where
    nameS    = nameOfProcedureToCall (callOf (primitiveInstruction old))
    params   = actualParametersOfProcedureToCall (callOf modified)
    modified = recursivelyTransformedPrimitiveInstruction modified'
    mkPrg x  = PrimitiveProgram (Primitive x ())

    inps = map aToE $ filter isInparam params
    outs = map aToL $ filter (not . isInparam) params

    -- Primitives reaching this phase are expected to be procedure calls;
    -- anything else is an internal error with a descriptive message rather
    -- than a bare pattern-match failure.
    callOf instr = case instructionData instr of
        ProcedureCallInstruction c -> c
        _ -> handlePrimitivesError "Primitive instruction is not a procedure call"

    matchPrimitive (fd, _) = fName fd == nameS && matchTypes' (inputs fd) inps

    -- Same number of arguments, and each argument's type matches its descriptor.
    matchTypes' :: [TypeDesc] -> [Expression ()] -> Bool
    matchTypes' descs exprs =
        length descs == length exprs
            && and (zipWith (\d e -> machTypes d (typeof e)) descs exprs)



-- | Build an instruction assigning an expression to a left value through the
-- 'Assig' primitive. Simple types give a plain assignment. Array types give
-- a call to the @copy@ procedure, where the 'Int' (default array size) is
-- used for dimensions of undefined length (see 'makePrimitive').
makeAssignment :: Platform -> Expression () -> LeftValue () -> Int -> Instruction ()
makeAssignment pfm src dst = makePrimitive pfm Assig [src] [dst]



-- | Build the instruction implementing a C primitive description.
--
-- * 'Assig' with exactly one input and one output becomes a plain
--   assignment for simple types. For array types it becomes a call to the
--   @copy@ procedure, which gets the total element count; @defArrSize@
--   stands in for each dimension of undefined length. Any other arity is
--   an internal error (parallel assignment is not supported).
-- * A non-'Proc' description whose first output has a simple type becomes
--   an assignment of a function-call expression to that output.
-- * Everything else becomes a procedure call whose arguments are the inputs
--   followed by the outputs.
--
-- When the description's 'funPf' is not 'noneFP', names get "_<type>"
-- suffixes for the selected arguments (see 'toFunName'), and function names
-- additionally get "_fun". A description that does not fit the number of
-- parameters yields an internal error when the offending part of the
-- returned instruction is evaluated.
makePrimitive :: Platform -> CPrimDesc -> [Expression ()] -> [LeftValue ()] -> Int -> Instruction ()

makePrimitive pfm Assig [in1] [out] defArrSize
    | simpleType inType = Instruction (AssignmentInstruction $ Assignment out in1 ()) ()
    | otherwise = case inType of
        -- Arrays go through the "copy" procedure, which also receives the
        -- total number of elements.
        ImpArrayType _ _ -> makePrimitive pfm (Proc "copy" firstInFP) [in1, intToCe $ arraySize inType defArrSize] [out] 0
        _                -> handlePrimitivesError $ "Unknown type in makePrimitive:\n" ++ show inType
  where
    inType = typeof in1

makePrimitive _ Assig _ _ _ = handlePrimitivesError "Wrong number of parameters for an assignment. (Parallel assignment not allowed.)"

makePrimitive pfm desc inps outs _
    -- Pattern guard instead of 'head outs': a non-Proc description with no
    -- outputs now reaches the descriptive error below instead of crashing
    -- in 'head'.
    | isNotProc desc, (out:_) <- outs, simpleType (typeof out)
                = Instruction (AssignmentInstruction $ Assignment out (Expression (FunctionCallExpression funCall) ()) ()) ()
    | otherwise = Instruction (ProcedureCallInstruction procCall) ()
  where
    -- Expression form, used when the result is assigned to a simple output.
    funCall = case (desc, inps, outs) of
        (Op1 op, [_], [o])    -> FunctionCall PrefixOp (typeof o) op inps ()
        (Op2 op, [_, _], [o]) -> FunctionCall InfixOp (typeof o) op inps ()
        (Fun _ _, _, [o])     -> FunctionCall SimpleFun (typeof o) completeFunName inps ()
        _                     -> errorMessage

    -- Statement form: inputs first, then outputs.
    procCall = case (desc, outs) of
        (Fun _ _, [_]) -> procCall'
        (Proc _ _, _)  -> procCall'
        _              -> errorMessage

    procCall' = ProcedureCall completeProcName (inps' ++ outs') ()
    inps' = map eToA inps
    outs' = map lToA outs

    completeFunName  | funPf desc == noneFP = cName desc
                     | otherwise            = cName desc ++ "_fun" ++ apsToName
    completeProcName | funPf desc == noneFP = cName desc
                     | otherwise            = cName desc ++ apsToName
    apsToName     = concatMap (("_" ++) . toFunName pfm . typeof) apsToNameList
    apsToNameList = take (useInputs $ funPf desc) inps' ++ take (useOutputs $ funPf desc) outs'

    isNotProc (Proc _ _) = False
    isNotProc _          = True

    -- Used both as a FunctionCall and as a ProcedureCall. TypeFamilies
    -- implies MonoLocalBinds, so without this signature the binding would
    -- not be generalised and the module would fail to type-check.
    errorMessage :: a
    errorMessage = handlePrimitivesError $ "Wrong C primitive description or different number of parameters:\n"
                                         ++ show desc ++ "\n" ++ concatMap ((", " ++) . show . typeof) inps



-- | Name fragment used to suffix primitive names by argument type.
--
-- Nested array types collapse to the innermost array, and an array becomes
-- @"arrayOf_"@ followed by its element's fragment. Any other type is looked
-- up in the platform's type table, with spaces replaced by underscores.
-- Types missing from the table are an internal error.
toFunName :: Platform -> Type -> String
toFunName pfm ty = case ty of
    ImpArrayType _ inner@(ImpArrayType _ _) -> toFunName pfm inner
    ImpArrayType _ elemTy                   -> "arrayOf_" ++ toFunName pfm elemTy
    _ -> maybe unhandled (\(_, _, cStr) -> map spaceToUnderscore cStr)
               (find (\(ty', _, _) -> ty == ty') (types pfm))
  where
    unhandled = handlePrimitivesError $ "Unhandled type in platform " ++ name pfm
    spaceToUnderscore ' ' = '_'
    spaceToUnderscore c   = c



-- | Total number of elements described by a (possibly nested) array type.
--
-- Each array dimension multiplies the count by its length. @Norm n@ and
-- @Defined n@ contribute @n@, and @Undefined@ contributes
-- @defaultArraySize@. Anything that is not an array counts as one element.
-- This applies to the innermost element type and also to a scalar passed
-- directly, which previously failed with a pattern-match error.
arraySize :: Type -> Int -> Int
arraySize ty defaultArraySize = go ty
  where
    go (ImpArrayType (Norm n) inner)    = n * go inner
    go (ImpArrayType (Defined n) inner) = n * go inner
    go (ImpArrayType Undefined inner)   = defaultArraySize * go inner
    go _                                = 1



-- | 'True' for input actual parameters, 'False' for output ones.
isInparam (ActualParameter kind _) = case kind of
    InputActualParameter _  -> True
    OutputActualParameter _ -> False



-- | Unwrap an input actual parameter (with unit annotation) into its
-- expression. Partial: only call it on parameters accepted by 'isInparam'.
aToE (ActualParameter (InputActualParameter expr) ())  = expr
-- | Unwrap an output actual parameter (with unit annotation) into its left
-- value. Partial: only call it on parameters rejected by 'isInparam'.
aToL (ActualParameter (OutputActualParameter lval) ()) = lval

-- adToE (InputActualParameter x)  = x
-- adToL (OutputActualParameter x) = x

-- | Wrap an expression as an input actual parameter with a unit annotation.
eToA expr = ActualParameter (InputActualParameter expr) ()
-- | Wrap a left value as an output actual parameter with a unit annotation.
lToA lval = ActualParameter (OutputActualParameter lval) ()

-- adToA x = ActualParameter x ()

-- ceToInt (Expression (ConstantExpression (Constant (IntConstant (IntConstantType x _)) _)) _) = x
-- | Integer constant expression (all annotations are unit).
intToCe n = Expression constExpr ()
  where
    constExpr = ConstantExpression (Constant (IntConstant (IntConstantType n ())) ())

-- | Use a left value as an expression (unit annotation).
lToe lval = Expression (LeftValueExpression lval) ()