{-# LANGUAGE FlexibleInstances, TypeFamilies #-}

module Feldspar.Compiler.Backend.C.Plugin.HandlePrimitives
    ( HandlePrimitives(..)
    , completeFunProcName
    ) where

import Data.List (find)
import Data.Maybe (fromJust)

import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Backend.C.CodeGeneration (typeof, defaultMemberName)
import Feldspar.Transformation
import Feldspar.Compiler.Backend.C.Options
import Feldspar.Compiler.Error

-- | Aborts with an internal error tagged with this plugin's name.
handlePrimitivesError msg = handleError "PluginArch/HandlePrimitives" InternalError msg

-- | Helper pass run before 'HandlePrimitives': it detects @trace@ calls and,
-- in procedures that contain one, brackets the body with @traceStart@ /
-- @traceEnd@ procedure calls.
data HandleTraceFunctions = HandleTraceFunctions

-- | 'False' is the neutral upward value of 'HandleTraceFunctions'
-- ("no trace call seen").
-- NOTE(review): orphan instance on 'Bool'; a newtype would avoid clashes with
-- other modules that might define the same instance.
instance Default Bool where
    def = False

-- | Upward 'Bool' results are merged with logical or: a trace call anywhere
-- below marks the whole subtree.
instance Combine Bool where
    combine = (||)

-- | The trace-detection pass keeps the node type unchanged ('()' labels).
-- It carries no state and no downward information. Its upward value is a
-- 'Bool' that is 'True' when a @trace@ call was found below.
instance Transformation HandleTraceFunctions where
    type From HandleTraceFunctions      = ()
    type To HandleTraceFunctions        = ()
    type Down HandleTraceFunctions      = ()
    type Up HandleTraceFunctions        = Bool
    type State HandleTraceFunctions     = ()

-- | Reports upwards whether this expression, or any of its sub-expressions,
-- is a call to @trace@.
instance Transformable HandleTraceFunctions Expression where
    transform t s d p = tr { up = combine (isTraceCall p) (up tr) }
      where
        tr = defaultTransform t s d p
        isTraceCall (FunctionCall "trace" _ _ _ _ _) = True
        isTraceCall _                                = False

-- | For a procedure whose body contains a @trace@ call, wraps the body with
-- @traceStart()@ at the beginning and @traceEnd()@ at the end. If the body is
-- already a 'Sequence', the calls are added to it; otherwise a new 'Sequence'
-- is built around it. Other definitions are traversed unchanged.
instance Transformable HandleTraceFunctions Definition where
    transform t s d p@(Procedure _ _ _ _ _ _)
        | up tr     = tr { result = procDef { procBody = body { blockBody = bracketWithTrace (blockBody body) } } }
        | otherwise = tr
      where
        tr      = defaultTransform t s d p
        procDef = result tr
        body    = procBody procDef
        bracketWithTrace sq@(Sequence _ _ _) = sq { sequenceProgs = traceStart : (sequenceProgs sq ++ [traceEnd]) }
        bracketWithTrace other               = Sequence [traceStart, other, traceEnd] () ()
        traceStart = ProcedureCall "traceStart" [] () ()
        traceEnd   = ProcedureCall "traceEnd" [] () ()
    transform t s d p = defaultTransform t s d p
-- | Plugin that lowers Feldspar primitive calls into C-level expressions,
-- declarations and statements. It handles @pair@, @getFst@, @getSnd@, @(!)@,
-- @setIx@, @trace@ and the entries of the platform's primitive table.
data HandlePrimitives = HandlePrimitives


-- | Types used by the primitive-lowering pass.
--
-- * Down: (default array size, target platform, optional destination
--   expression that @pair@ / @trace@ results are written into).
-- * Up: declarations and statements generated while lowering. The statements
--   are placed before the current statement by the statement-list instance,
--   and the declarations end up as locals of the enclosing block.
-- * State: counter appended to generated variable names to keep them unique.
instance Transformation HandlePrimitives where
    type From HandlePrimitives      = ()
    type To HandlePrimitives        = ()
    type Down HandlePrimitives = (Int, Platform, Maybe (Expression ()))
    type Up HandlePrimitives   = ([Declaration ()], [Program ()])
    type State HandlePrimitives = Int


-- | Runs 'HandleTraceFunctions' and then 'HandlePrimitives' over a procedure.
-- The variable counter starts at 0 and there is no initial destination. When
-- primitive handling is switched off with 'NoPrimitiveInstructionHandling',
-- the procedure is returned untouched.
instance Plugin HandlePrimitives where
    type ExternalInfo HandlePrimitives = (Int, DebugOption, Platform)
    executePlugin _ (_, NoPrimitiveInstructionHandling, _) procedure = procedure
    executePlugin _ (defArrSize, _, platform) procedure = result lowered
      where
        traced  = result $ transform HandleTraceFunctions () () procedure
        lowered = transform HandlePrimitives 0 (defArrSize, platform, Nothing) traced

-- | Upward results are merged by concatenating the declaration lists and the
-- statement lists, left operand first.
instance Combine ([Declaration ()], [Program ()]) where
    combine (declsA, progsA) (declsB, progsB) = (declsA ++ declsB, progsA ++ progsB)

-- | Empty declaration list: neutral element of the upward result.
instance Default [Declaration ()] where
    def = []

-- | Empty statement list: neutral element of the upward result.
instance Default [Program ()] where
    def = []



-- | After the default traversal of a block, the upward results of its
-- children are merged into the block itself with 'addToBlock', and nothing is
-- propagated past the block: its 'up' is reset to empty.
--
-- Upward statements are not expected at this point. If the default traversal
-- returns any, an internal error is raised.
instance Transformable HandlePrimitives Block where
    transform t s d b = checked
        { result = addToBlock (result checked) (up checked)
        , up     = ([], [])
        }
      where
        raw = defaultTransform t s d b
        checked
            | null (snd (up raw)) = raw
            | otherwise           = handlePrimitivesError "transform Block: upwards program list is not empty."


-- | Statement-level rewrites.
--
-- * @copy(out, in)@ becomes a @copyArray(out, in)@ call when the destination
--   has an 'ArrayType', and a plain 'Assign' otherwise. Each alternative
--   lowers its own operands, threading the state from output to input.
-- * 'SeqLoop': declarations/statements produced while lowering the loop
--   condition are added to the condition-computation block via 'addToBlock'.
--   The loop's own upward result is empty.
-- * Everything else uses the default traversal.
instance Transformable HandlePrimitives Program where
    transform t s d (ProcedureCall "copy" [outParam@(Out out _), inParam@(In inp _)] _ _)
        | isArrayType (typeof out) = Result (ProcedureCall "copyArray" [outP, inP] () ()) arrayState (outUp `combine` inUp)
        | otherwise                = Result (Assign lhs rhs () ()) assignState (lhsUp `combine` rhsUp)
      where
        isArrayType (ArrayType _ _) = True
        isArrayType _               = False
        -- Operands lowered as actual parameters (array destination).
        (Result outP afterOut outUp)    = transform t s d outParam
        (Result inP arrayState inUp)    = transform t afterOut d inParam
        -- Operands lowered as expressions (scalar destination).
        (Result lhs afterLhs lhsUp)     = transform t s d out
        (Result rhs assignState rhsUp)  = transform t afterLhs d inp
    transform t s d (SeqLoop cond condCalc body inf1 inf2) =
        Result (SeqLoop (result condTr) condCalc' (result bodyTr) (convert inf1) (convert inf2))
               (state bodyTr)
               ([], [])
      where
        condTr     = transform t s d cond
        condCalcTr = transform t (state condTr) d condCalc
        bodyTr     = transform t (state condCalcTr) d body
        condCalc'  = addToBlock (result condCalcTr) (up condTr)
    transform t s d p = defaultTransform t s d p


-- | Lowers a statement list. The statements each element sends upwards are
-- spliced in directly before that element. Only declarations continue
-- upwards, so the list's upward statement list is always empty.
instance Transformable1 HandlePrimitives [] Program where
    transform1 _ s _ [] = Result1 [] s def
    transform1 t s d (prog:rest) = Result1 progs (state1 restTr) (decls, [])
      where
        headTr               = transform t s d prog
        restTr               = transform1 t (state headTr) d rest
        (headDecls, headPre) = up headTr
        progs                = headPre ++ result headTr : result1 restTr
        decls                = headDecls ++ fst (up1 restTr)


-- | Lowers a declaration and its optional initializer.
--
-- If lowering the initializer produces no statements, the lowered initializer
-- stays in the declaration. Otherwise the initializer is removed, and the
-- generated statements followed by an assignment of the lowered value to the
-- declared variable are sent upwards. In both cases the variable's own
-- upward result is combined in front.
instance Transformable HandlePrimitives Declaration where
    transform t s d@(das, pfm, _) (Declaration v i inf) =
        Result (Declaration (result varTr) initExpr $ convert inf) (state1 initTr) (combine (up varTr) initUp)
      where
        varTr  = transform t s d v
        initTr = transform1 t (state varTr) d i
        (initExpr, initUp) = case (up1 initTr, result1 initTr) of
            (u@(_, []), e)     -> (e, u)
            ((ls, is), Just e) ->
                (Nothing, (ls, is ++ [makeAssignment (das, pfm) e (vToE $ result varTr)]))
            -- Statements can only come from an initializer, so this is unreachable.
            (_, Nothing)       ->
                handlePrimitivesError "transform Declaration: statements generated for an absent initializer."

-- | Lowers primitive function calls. The result is a plain C expression; any
-- declarations/statements it needs are sent upwards as (declarations,
-- statements), after those produced by the arguments.
--
-- * @getFst@ / @getSnd@ applied directly to a @pair@ call short-circuit to the
--   selected component. The other component is not transformed at all.
-- * Arguments are lowered first (with no destination). The resulting
--   @inps@ drive the cases below.
-- * @(!)@ becomes an 'ArrayElem'. @getFst@ / @getSnd@ become 'StructField'
--   accesses of the member @defaultMemberName ++ "1"@ / @"2"@.
-- * @setIx@ emits an assignment into the indexed element and yields the array.
-- * @pair@ / @trace@ write into the destination from the down value when one
--   is given. Otherwise they write into a fresh variable (prefix "stc" /
--   "trc"), which is declared upwards and advances the state counter.
-- * Any other call is looked up in the platform's primitive table by name and
--   by 'matchTypes'' on the lowered argument types. A match is expanded with
--   'transformPrgDesc' (program description) or 'transformCPrimDesc'
--   (C primitive). With no match, the default-transformed call is kept.
--
-- Non-call expressions use the default traversal with the destination cleared.
instance Transformable HandlePrimitives Expression where
        transform t s d@(das, pfm, me) f@(FunctionCall nameS ot origRole origInps _ _) = res
          where
                res = case (nameS, origInps) of
                    ("getFst", [FunctionCall "pair" _ _ [fs,sn] _ _]) -> transform t s (das, pfm, Nothing) fs
                    ("getSnd", [FunctionCall "pair" _ _ [fs,sn] _ _]) -> transform t s (das, pfm, Nothing) sn
                    _ -> Result e' s' $ combine (up tr) (l',p')
                -- Children are lowered without a destination.
                tr = defaultTransform t s (das, pfm, Nothing) f
                s2 = state tr
                -- (next state, new declarations, new statements, resulting expression)
                (s',l',p',e') = case (nameS, inps, me) of
                    ("(!)", [arr, idx], _)    -> (s2, [], [], ArrayElem arr idx () ())
                    -- NOTE(review): writes into arr's element and returns arr itself,
                    -- i.e. an in-place update; presumably arr is not reused afterwards — confirm.
                    ("setIx", [arr, idx, val], _) -> (s2
                                                , []
                                                , [ makeAssignment d' val (ArrayElem arr idx () ()) ]
                                                , arr
                                                )
                    ("getFst", [l], _)        -> (s2, [], [], StructField l (defaultMemberName ++ "1") () ())
                    ("getSnd", [l], _)        -> (s2, [], [], StructField l (defaultMemberName ++ "2") () ())
                    -- pair with a destination: fill the destination's members directly.
                    ("pair", [a,b], Just e)   -> (s2
                                                , []
                                                , [ makeAssignment d' a (StructField e (defaultMemberName ++ "1") () ())
                                                  , makeAssignment d' b (StructField e (defaultMemberName ++ "2") () ())
                                                  ]
                                                , e
                                                )
                    -- pair without a destination: declare a fresh struct variable and fill it.
                    ("pair", [a,b], Nothing)  -> (s3
                                                , [ makeDeclaration stc Nothing ]
                                                , [ makeAssignment d' a (StructField (VarExpr stc ()) (defaultMemberName ++ "1") () ())
                                                  , makeAssignment d' b (StructField (VarExpr stc ()) (defaultMemberName ++ "2") () ())
                                                  ]
                                                , VarExpr stc ()
                                                ) where (s3, stc) = makeVariable ot "stc" s2
                    -- trace: store the traced value, then call the platform's trace procedure on it.
                    ("trace", [lab, orig], Just e) -> (s2
                                                , []
                                                , [ makeAssignment d' orig e
                                                  , makeProcedureCall pfm (Proc "trace" firstInFP) [e, lab] []
                                                  ]
                                                , e
                                                )
                    ("trace", [lab, orig], Nothing) -> (s3
                                                , [ makeDeclaration trc Nothing ]
                                                , [ makeAssignment d' orig (VarExpr trc ())
                                                  , makeProcedureCall pfm (Proc "trace" firstInFP) [VarExpr trc (), lab] []
                                                  ]
                                                , VarExpr trc ()
                                                ) where (s3, trc) = makeVariable ot "trc" s2
                    -- Fall back to the platform's primitive table.
                    _                         -> case (find matchPrimitive $ primitives pfm) of
                                                    Just (fd,Right tp)  -> transformPrgDesc d' s2 (tp fd inps ot)
                                                    Just (fd,Left cd)   -> transformCPrimDesc d' s2 cd inps ot
                                                    Nothing             -> (s2, [], [], result tr)
                matchPrimitive (fd,_) = (fName fd == nameS) && (matchTypes' (inputs fd) inps)
                -- Arguments after lowering.
                inps = funCallParams $ result tr
                d' = (das, pfm)
        transform t s d@(das, pfm, _) p = defaultTransform t s (das, pfm, Nothing) p


-- | Appends upward results to a block. Declarations go to the end of its
-- locals and statements to the end of its body. A body that is not already
-- a 'Sequence' is wrapped in one.
addToBlock :: Block () -> ([Declaration ()], [Program ()]) -> Block ()
addToBlock blk (decls, progs) = blk
    { locals    = locals blk ++ decls
    , blockBody = extend (blockBody blk)
    }
  where
    extend (Sequence ps () ()) = Sequence (ps ++ progs) () ()
    extend other               = Sequence (other : progs) () ()



-- | Expands a C primitive description for the given (already lowered) inputs.
--
-- Unary/binary operators with matching arity, functions, single-input casts
-- and single-input plain copies become an expression directly (no
-- declarations, statements or new serial). Anything else is treated as a
-- procedure: a fresh "vhp" output variable is declared, the procedure is
-- called with it as the output, and the variable is the result.
--
-- Returns (next serial, declarations, statements, result expression).
transformCPrimDesc :: (Int,Platform) -> Int -> CPrimDesc -> [Expression ()] -> Type -> (Int, [Declaration ()], [Program ()], Expression ())
transformCPrimDesc (_, pfm) serial cd inps ot = case (cd, inps) of
    (Op1 op, [_])    -> exprOnly (FunctionCall op ot PrefixOp inps () ())
    (Op2 op, [_, _]) -> exprOnly (FunctionCall op ot InfixOp inps () ())
    (Fun _ _, _)     -> exprOnly (FunctionCall completedName ot SimpleFun inps () ())
    (Cas, [x])       -> exprOnly (Cast ot x () ())
    (Assig, [x])     -> exprOnly x
    _                -> (serial', [makeDeclaration ov Nothing], [makeProcedureCall pfm cd inps [vToE ov]], vToE ov)
  where
    exprOnly e    = (serial, [], [], e)
    completedName = completeFunProcName pfm cd (map typeof inps) [ot]
    (serial', ov) = makeVariable ot "vhp" serial



-- | Instantiates a platform program description ('PrgDesc') for one call site.
--
-- * Every creator ('Crt') gets a fresh variable, numbered from @serial@.
-- * Every line ('Asg' / 'Prc') becomes an assignment or a procedure call.
-- * The result expression is built from the description's right-hand side.
--
-- Returns (next free serial, declarations, statements, result expression).
--
-- Raises an internal error when two creators declare the same variable label,
-- or when a line or expression refers to a label that was never declared.
transformPrgDesc :: (Int,Platform) -> Int -> PrgDesc -> (Int, [Declaration ()], [Program ()], Expression ())
transformPrgDesc down@(_,pfm) serial (PrgDesc crts lns rgt)
    = (serial', map (\(_,_,v,me) -> makeDeclaration v me) vars, ins, transformRgt vars rgt)
  where
    (serial', vars') = foldl transformCrtFold (serial, []) (map searchDuplicateLabels crts)
    (vars, ins) = foldl transformLineFold (vars', []) lns

    -- The label a creator declares.
    crtLabel (Crt _ v _) = v

    -- Compare labels, not whole creators. Otherwise two declarations of one
    -- label with different types or initializers would slip through, and the
    -- lookups below would silently use the first of them.
    searchDuplicateLabels c
        | length (filter ((== crtLabel c) . crtLabel) crts) > 1
                    = handlePrimitivesError $ "multiple declaration: " ++ show c
        | otherwise = c

    -- Variable table entries: (label, holds-a-value flag, generated variable, initializer).
    transformCrtFold (n ,vs) (Crt t v@(Var s) (Just r)) = (n', vs ++ [(v, True, mv, Just $ transformRgt vs r)])
      where
        (n', mv) = makeVariable t s n
    transformCrtFold (n ,vs) (Crt t v@(Var s) Nothing)  = (n', vs ++ [(v, False, mv, Nothing)])
      where
        (n', mv) = makeVariable t s n

    transformLineFold (vs, is) ln = case ln of
            (Asg v r)           -> (updateVars [v],  is ++ [makeAssignment down (transformRgt' r) (transformVarL' v)])
            (Prc cd inps outs)  -> (updateVars outs, is ++ [makeProcedureCall pfm cd (map transformRgt' inps) (map transformVarL' outs)])
      where
        -- Labels written by this line now hold a value.
        updateVars xs   = map (\y@(v',_,vv,mr) -> if elem v' xs then (v',True,vv,mr) else y) vs
        transformRgt'   = transformRgt vs
        transformVarL'  = vToE . transformVarL vs

    transformRgt _  (Exp e)           = e
    transformRgt vs (Fnc cd rgts ot)  = makeFunctionCallOrCast down cd (map (transformRgt vs) rgts) ot
    transformRgt vs (VarR v)          = vToE $ transformVarR vs v

    transformVarL vs v@(Var _) = case find (\(v',_,_,_) -> v' == v) vs of
            Just (_,_,vv,_) -> vv
            Nothing         -> handlePrimitivesError $ "Not declared: " ++ show v
    transformVarR vs v@(Var _) = case find (\(v',_,_,_) -> v' == v) vs of
            Just (_,True,vv,_)  -> vv
            Just (_,False,vv,_) -> vv     -- Not checked whether the variable was assigned yet (quick fix for the pair / set_pair macros)
            Nothing             -> handlePrimitivesError $ "Not declared: " ++ show v




-- | Expands a C primitive that must reduce to a single expression (an
-- operator, function call, cast or plain copy). Raises an internal error if
-- the primitive would need declarations or statements, i.e. it is a
-- procedure or the arity does not fit.
makeFunctionCallOrCast :: (Int,Platform) -> CPrimDesc -> [Expression ()] -> Type -> Expression ()
makeFunctionCallOrCast down cd inps ot
    -- The serial is irrelevant: only branches that do not allocate a variable are accepted.
    = case transformCPrimDesc down (-1) cd inps ot of
        (_, [], [], ed) -> ed
        _               -> handlePrimitivesError $ "it's not a FunctionCall: " ++ show cd ++ ", number of inputs: " ++ show (length inps)


-- | Creates a value variable of type @t@ whose name is the prefix @s@ followed
-- by the counter @n@. Returns the incremented counter together with it.
makeVariable :: Type -> String -> Int -> (Int, Variable ())
makeVariable t s n = (next, fresh)
  where
    next  = n + 1
    fresh = Variable (s ++ show n) t Value ()


-- | Declares a variable with an optional initializer.
makeDeclaration :: Variable () -> Maybe (Expression ()) -> Declaration ()
makeDeclaration var initializer = Declaration var initializer ()


-- | Builds the statement that stores @inp@ into @out@:
--
-- * 'Empty' when both are syntactically the same variable or the same array
--   element,
-- * a @copyArray(out, inp)@ call when @inp@ has an array type,
-- * a plain assignment otherwise.
--
-- The (default array size, platform) argument is currently unused.
makeAssignment :: (Int,Platform) -> Expression () -> Expression ()  -> Program ()
makeAssignment (_, _) inp out
    | sameVariable inp out = Empty () ()
    | isArrayType (typeof inp) = ProcedureCall "copyArray" [eToOut out, eToIn inp] () ()
    | otherwise            = Assign out inp () ()
  where
    isArrayType (ArrayType _ _) = True
    isArrayType _               = False
    sameVariable (VarExpr v1 _) (VarExpr v2 _)               = v1 == v2
    sameVariable (ArrayElem a1 i1 _ _) (ArrayElem a2 i2 _ _) = a1 == a2 && i1 == i2
    sameVariable _ _                                         = False

-- | Calls a 'Proc' C primitive. Inputs are passed as 'In' parameters followed
-- by outputs as 'Out' parameters. The C name is completed from the input and
-- output types by 'completeFunProcName'. Any other primitive kind raises an
-- internal error.
makeProcedureCall :: Platform -> CPrimDesc -> [Expression ()] -> [Expression ()] -> Program ()
makeProcedureCall pfm cd@(Proc _ _) inps outs = ProcedureCall (completeFunProcName pfm cd its ots) (inps' ++ outs') () ()
  where
    inps' = map eToIn inps
    outs' = map eToOut outs
    its = map typeof inps
    ots = map typeof outs
makeProcedureCall _ cd _ _ = handlePrimitivesError $ "Wrong C primitive description in makeProcedureCall:\n" ++ show cd




-- | True when both lists have the same length and every expression's type is
-- accepted (by 'machTypes') by the type descriptor at the same position.
matchTypes' :: [TypeDesc] -> [Expression ()] -> Bool
matchTypes' (desc:descs) (expr:exprs) = machTypes desc (typeof expr) && matchTypes' descs exprs
matchTypes' []           []           = True
matchTypes' _            _            = False


-- | Builds the C name of a function/procedure primitive. With the 'noneFP'
-- postfix this is just the base name. Otherwise the name is followed by
-- "_fun" (for 'Fun' only) and then "_<type>" for the first 'useInputs' input
-- types and the first 'useOutputs' output types.
completeFunProcName :: Platform -> CPrimDesc -> [Type] -> [Type] -> String
completeFunProcName pfm desc its ots
    | postfix == noneFP = cName desc
    | otherwise         = concat (cName desc : kindSuffix : map (('_' :) . toFunName pfm) namedTypes)
  where
    postfix = funPf desc
    kindSuffix = case desc of
        Fun _ _  -> "_fun"
        Proc _ _ -> ""
    namedTypes = take (useInputs postfix) its ++ take (useOutputs postfix) ots


-- | Name fragment for a type, used when completing primitive names.
--
-- * Nested array types collapse to their innermost array level.
-- * An array gets the prefix "arrayOf_".
-- * Other types use the platform's name for the type, with spaces replaced by
--   underscores.
--
-- Raises an internal error, naming the offending type, for types the platform
-- does not list.
toFunName :: Platform -> Type -> String
toFunName pfm (ArrayType _ t@(ArrayType _ _)) = toFunName pfm t
toFunName pfm (ArrayType _ t)                 = "arrayOf_" ++ toFunName pfm t
toFunName pfm t = case find (\(t', _, _) -> t == t') (types pfm) of
    Just (_, _, s) -> map underscore s
    Nothing        -> handlePrimitivesError $ "Unhandled type in platform " ++ name pfm ++ ": " ++ show t
  where
    underscore ' ' = '_'
    underscore c   = c


-- arraySize :: Type -> Int -> Expression ()
-- arraySize a@(ArrayType _ t) defaultArraySize = toExp $ arraySize' a
  -- where
    -- arraySize' :: Type -> (Int,Int)
    -- arraySize' (ArrayType (LiteralLen n) t) = (n * fst at, snd at) where
        -- at = arraySize' t
    -- arraySize' (ArrayType UndefinedLen t) = (fst at, 1 + snd at) where
        -- at = arraySize' t
    -- arraySize' _ = (1,0)
    -- toExp :: (Int,Int) -> Expression ()
    -- toExp (c, 0) =  intToCe $ toInteger c
    -- toExp (c, i) = prod_const (toExp (c, i-1)) (vToE $ Variable defaultArraySizeConstantName (NumType Unsigned S32) Value ())


-- | Infix multiplication @a * b@ with an unsigned 32-bit result type.
-- NOTE(review): only referenced from commented-out code in this module.
prod_const a b = FunctionCall "*" (NumType Unsigned S32) InfixOp [a,b] () ()

-- | True for an input ('In') actual parameter, False for an output ('Out').
isInparam param = case param of
    In _ _  -> True
    Out _ _ -> False


-- | The expression carried by an actual parameter, whether input or output.
aToE param = case param of
    In x ()  -> x
    Out x () -> x

-- | Wraps an expression as an input actual parameter.
eToIn x  = In x ()
-- | Wraps an expression as an output actual parameter.
eToOut x = Out x ()


-- ceToInt (Expression (ConstantExpression (Constant (IntConstant (IntConstantType x _)) _)) _) = x
-- | Integer constant expression.
-- NOTE(review): only referenced from commented-out code in this module.
intToCe x = ConstExpr (IntConst x () ()) ()

-- | Expression that reads a variable.
vToE v = VarExpr v ()