{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-|
    Module      :  Control.ERNet.Deployment.Local.Channel
    Description :  channel implementation using STM
    Copyright   :  (c) Michal Konecny
    License     :  BSD3

    Maintainer  :  mik@konecny.aow.cz
    Stability   :  experimental
    Portability :  non-portable (GHC extensions: existential types, multi-parameter type classes)

    A simple channel implementation using STM-protected variables (@TVar@s).
-}
module Control.ERNet.Deployment.Local.Channel 
(
    ChannelLocal,
    ChannelLocalAnyProt
)
where

import Control.ERNet.Deployment.Local.Logger

import Control.ERNet.Foundations.Protocol
import Control.ERNet.Foundations.Event
import qualified Control.ERNet.Foundations.Event.Logger as LG
import qualified Control.ERNet.Foundations.Channel as CH

import Control.Concurrent as Concurrent
import Control.Concurrent.STM
import Data.Number.ER.Misc.STM

--import System.Time
import Data.Time.Clock

import qualified Data.Map as Map
import Data.Maybe
import Data.Typeable

-- | Every channel operation is delegated to the same-named function in this module.
instance CH.Channel ChannelLocal ChannelLocal ChannelLocalAnyProt ChannelLocalAnyProt
    where
    -- recovering concrete channels from the type-erased wrapper
    castIn = castChannel
    castInIO = castChannelIO
    castOut = castChannel
    castOutIO = castChannelIO
    -- caller side: issuing queries and collecting their answers
    makeQuery = makeQuery
    makeQueryAnyProt = makeQueryAnyProt
    waitForAnswer = waitForAnswer
    waitForAnswerMulti = waitForAnswerMulti
    -- responder side: receiving queries and posting answers
    waitForQuery = waitForQuery
    waitForQueryMulti = waitForQueryMulti
    answerQuery = answerQuery
    answerQueryAnyProt = answerQueryAnyProt
    
-- | Channel creation is delegated to 'newChannel'.
instance CH.ChannelForScheduler LoggerLocal ChannelLocal ChannelLocal ChannelLocalAnyProt ChannelLocalAnyProt
    where
    new logger procName cID chType =
        newChannel logger procName cID chType

{-|
    Union of channel types over instances of the 'QAProtocol' class.
    
    (existential type) 

    Because of the existential quantification, 'chanyCH' cannot be used
    as a selector function; pattern-match on the constructor instead.
-}
data ChannelLocalAnyProt =
    forall q a. (QAProtocol q a, Show q, Show a) =>
    ChannelLocalAnyProt
    {
        chanyCH :: (ChannelLocal q a)
            -- ^ the wrapped channel at its concrete query/answer types
    ,
        chanyChT :: ChannelType
            -- ^ channel type stored with the channel (set by 'newChannel');
            -- reported in cast error messages
    }

{-|
    A channel carrying queries of type @q@ and answers of type @a@.
    The only mutable part is the 'TVar' holding the channel state.
-}
data ChannelLocal q a =
    ChannelLocal
    { chTV :: TVar (ChannelState q a)
        -- ^ the whole mutable state of the channel
    , chLogger :: LoggerLocal
        -- ^ logger receiving query-made and answer-received events
    , chDestination :: String
        -- ^ name of the channel responder process
    , chID :: Int
        -- ^ rank within its answering process (0..)
    }
    deriving (Typeable)
            
-- | Channels are considered equal when their 'chID's are equal.
instance Eq (ChannelLocal q a) where
    x == y = chID x == chID y
-- | Channels are ordered by their 'chID'.
instance Ord (ChannelLocal q a) where
    compare x y = chID x `compare` chID y
-- | A channel is shown as @CH@ followed by its 'chID'.
instance Show (ChannelLocal q a) where
    show = ("CH" ++) . show . chID

-- | Human-readable channel name, as built by 'makeChName'; used in log events and cast error messages.
type ChannelName = String

{-|
    Name of a channel for logs and error messages: the name of its
    responder process, suffixed with the channel id unless that id is 0.
-}
makeChName ::
    ChannelLocal q a ->
    ChannelName
makeChName channel =
    case chID channel of
        0 -> chDestination channel
        rank -> chDestination channel ++ show rank

{-|
    The complete mutable state of one channel.  It lives in a single
    'TVar', so each channel operation reads and rewrites it atomically.
-}
data ChannelState q a =
    ChannelState
        { chStNextId :: QueryId
            -- ^ unique query Id to use for the next query
        , chStQueriesNew :: [(QueryId, q)]
            -- ^ queries that have not been picked up yet, oldest first
        , chStQueriesCache :: Map.Map QueryId (QAState q)
            -- ^ queries that are being served or whose answers have been cached
        , chStQueriesIndex :: Map.Map q QueryId
            -- ^ reverse index of 'chStQueriesCache', used to recognise a repeated query
        , chStAnswers :: Map.Map QueryId (a, Bool)
            -- ^ answers that have not been picked up yet or have been cached,
            -- each paired with a flag saying whether to cache it
        }

{-|
    State of a freshly created channel: no queries or answers, with
    query ids starting at 1.
-}
initialChState :: ChannelState q a
initialChState = 
    ChannelState 1 [] Map.empty Map.empty Map.empty

{-|
    Bookkeeping kept for every query registered on a channel.
-}
data QAState q =
    QAState
    { qaStQry :: q
        -- ^ the query itself
    , qaStWaiting :: Int
        {-^
            How many times this query has been made with its answer not
            yet picked up.

            When this count is used up, this 'QAState' and the corresponding
            answer can be removed from the channel state, unless the answer
            is cached.
        -}
    }
        
-- | Bookkeeping for a query made for the first time: one answer pick-up pending.
initialQAState :: q -> QAState q
initialQAState qry = QAState qry 1

{-|
    Create a channel with a fresh state 'TVar', typed according to the
    protocol embedded in the given 'ChannelType'.  The same wrapped channel
    is returned in both components of the pair, so both refer to the same
    underlying state.
-}
newChannel ::
    LoggerLocal {-^ for logging events -} ->
    String  {-^ name of channel responder process -} ->
    Int {-^ channel id -} ->
    ChannelType {-^ used to determine the embedded instance of QAProtocol -} ->
    IO (ChannelLocalAnyProt, ChannelLocalAnyProt)
newChannel logger procName cID chType =
    fmap (\anyCh -> (anyCh, anyCh)) (wrapFresh chType)
    where
    -- The pattern signatures bind the protocol's query/answer types so the
    -- fresh channel can be given exactly those types.
    wrapFresh ct@(ChannelType (_ :: q) (_ :: a)) =
        do
        stateVar <- newTVarIO initialChState
        return (ChannelLocalAnyProt (mkLocal stateVar :: ChannelLocal q a) ct)
    mkLocal stateVar =
        ChannelLocal
            { chID = cID
            , chTV = stateVar
            , chLogger = logger
            , chDestination = procName
            }


{-|
    Recover a channel at a concrete protocol type from the type-erased
    wrapper.  If the wrapped channel has a different type, this fails with an
    error naming @locationDescr@, the channel and its stored 'ChannelType'.
    The check is lazy: nothing fails until the result is demanded.
-}
castChannel ::
    (QAProtocol q a) =>
    String {-^ place where function used; for error messages -} ->
    ChannelLocalAnyProt ->
    (ChannelLocal q a)
castChannel locationDescr (ChannelLocalAnyProt ch chtp) =
    fromMaybe (channelCastError locationDescr (makeChName ch) chtp) (cast ch)
{-|
    Like 'castChannel', but the cast is performed inside 'IO'.  A failed
    cast therefore raises its error when the action is run, not when the
    resulting channel is first used.

    Besides the given channel, a throw-away channel is built from the stored
    'ChannelType' (via 'newChannel'), and it must cast to the requested type
    as well.
    NOTE(review): the only constructor sites visible in this module
    ('newChannel', and the re-wrap in 'waitForAnswerMulti') always pair a
    channel with its own 'ChannelType', so the second check looks redundant.
    It costs a fresh 'TVar' per call; confirm before simplifying.
-}
castChannelIO ::
    (QAProtocol q a) =>
    String {-^ place where function used; for error messages -} ->
    ChannelLocalAnyProt ->
    IO (ChannelLocal q a)
castChannelIO locationDescr chA =
    do
    -- throw-away channel typed by chtp; its undefined logger is never forced here
    (chA_, _) <- newChannel undefined "" 0 chtp    
    case (chA, chA_) of
        (ChannelLocalAnyProt ch _, ChannelLocalAnyProt ch_ _) -> 
            -- putting both casts in one list forces them to the same target
            -- type, namely the one demanded by the caller
            case [cast ch, cast ch_] of
                [Just ch, Just _] ->
                    return ch
                _ -> 
                    channelCastError locationDescr (makeChName ch) chtp
    where
    chtp = chanyChT chA

{-|
    Abort with an error saying that the channel named @chnm@ could not be
    cast to the channel type @chtp@.  @locationDescr@ says where the cast
    was attempted.
-}
channelCastError ::
    String ->
    ChannelName ->
    ChannelType ->
    a
channelCastError locationDescr chnm chtp =
    error $ concat
        [ locationDescr
        , " failed casting channel "
        , chnm
        , " to "
        , show chtp
        ]
    
    
{-|
    Issue query @qry@ on @channel@ on behalf of @callingCh@, which is itself
    serving its query @callingQryId@.  Returns the id under which the answer
    can later be collected.

    If an identical query is already registered on the channel (being served,
    or with a cached answer), its id is reused and its waiting count is
    incremented; otherwise the query gets the next fresh id and is queued.
    In both cases an 'ERNetEvQryMade' event is logged.
-}
makeQuery callingCh callingQryId channel qry =
    do
    qryId <- atomically register
    timeNow <- getCurrentTime
    LG.addEvent (chLogger channel)
        ERNetEvQryMade
        {
            ernetevTime = timeNow,
            ernetevQryId = qryId,
            ernetevFromId = makeChName callingCh,
            ernetevFromQryId = callingQryId,
            ernetevToId = makeChName channel,
            ernetevQry = qry
        }
    return qryId
    where
    stateVar = chTV channel
    register =
        do
        chSt <- readTVar stateVar
        case Map.lookup qry (chStQueriesIndex chSt) of
            Just knownId ->
                do
                -- repeated query: one more caller waiting for the same answer
                writeTVar stateVar (bumpWaiting knownId chSt)
                return knownId
            Nothing ->
                do
                writeTVar stateVar (enqueueFresh chSt)
                return (chStNextId chSt)
    enqueueFresh chSt =
        chSt
        {
            chStNextId = freshId + 1,
            chStQueriesNew = chStQueriesNew chSt ++ [(freshId, qry)],
            -- keep the query so that its answer can later be logged with it
            chStQueriesCache =
                Map.insert freshId (initialQAState qry) (chStQueriesCache chSt),
            -- index the query so that a repetition of it is recognised
            chStQueriesIndex =
                Map.insert qry freshId (chStQueriesIndex chSt)
        }
        where
        freshId = chStNextId chSt
    bumpWaiting knownId chSt =
        chSt
        {
            chStQueriesCache =
                Map.adjust
                    (\qaSt -> qaSt { qaStWaiting = qaStWaiting qaSt + 1 })
                    knownId
                    (chStQueriesCache chSt)
        }

{-|
    Protocol-agnostic variant of 'makeQuery'.  It unwraps the calling channel
    and the query, then casts @chA@ to the query's protocol with
    'castChannelIO'.  A type mismatch raises an error mentioning
    @locationDescr@; otherwise the query is issued.
-}
makeQueryAnyProt locationDescr (ChannelLocalAnyProt callingCH _) callingQryId chA (QueryAnyProt q) =
    castChannelIO locationDescr chA >>= \ch ->
        makeQuery callingCH callingQryId ch q
                
{-|
    Block until @channel@ has a query that has not been picked up yet, then
    remove the oldest such query from the queue and return its id together
    with the query.  Receipt of the query is not logged.
-}
waitForQuery channel =
    do
    (qryId, qry) <- atomically popOldest
    return (qryId, qry)
    where
    stateVar = chTV channel
    popOldest =
        readTVar stateVar >>= \chSt ->
            case chStQueriesNew chSt of
                [] -> retry -- queue empty: block until the channel state changes
                (oldest : rest) ->
                    do
                    writeTVar stateVar (chSt { chStQueriesNew = rest })
                    return oldest
                
{-|
    Block until at least one of @channels@ has a query waiting.  Channels are
    scanned in list order, and the oldest query is taken from the first one
    that has any.  Returns that channel's position in the list, together with
    the query id and the query wrapped in 'QueryAnyProt'.  Receipt of the
    query is not logged.
-}
waitForQueryMulti channels =
    do
    (chN, qryData) <- atomically (firstPending (zip [0..] channels))
    return (chN, qryData)
    where
    firstPending [] = retry -- no channel has a query: block until something changes
    firstPending ((chN, ChannelLocalAnyProt ch _) : rest) =
        popQuery (chTV ch)
            >>= maybe (firstPending rest) (\qryData -> return (chN, qryData))
    popQuery stateVar =
        do
        chSt <- readTVar stateVar
        case chStQueriesNew chSt of
            [] -> return Nothing
            ((qryId, qry) : remaining) ->
                do
                writeTVar stateVar (chSt { chStQueriesNew = remaining })
                return (Just (qryId, QueryAnyProt qry))

{-|
    Post the answer @ans@ to query @qryId@ on @channel@, where a caller
    blocked in 'waitForAnswer' can pick it up.  @useCache@ is stored with the
    answer; it decides whether the answer (and the query's bookkeeping) is
    kept after the last waiting caller has picked it up.

    An answer already stored under the same id is replaced.  Making the
    answer is not logged.
-}
answerQuery useCache channel (qryId, ans) =
    atomically $
        do
        -- NOTE: explicit read/write instead of an unqualified modifyTVar.
        -- That name is exported by Control.Concurrent.STM (stm >= 2.3) and,
        -- presumably, by Data.Number.ER.Misc.STM, so it would be ambiguous.
        chSt <- readTVar cTV
        writeTVar cTV $
            chSt { chStAnswers = Map.insert qryId (ans, useCache) (chStAnswers chSt) }
    where
    cTV = chTV channel

{-|
    Protocol-agnostic variant of 'answerQuery'.  It unwraps the answer and
    posts it on @chA@, cast (lazily, via 'castChannel') to the answer's
    protocol.
-}
answerQueryAnyProt locationDescr useCache chA (qryId, AnswerAnyProt a) =
    answerQuery useCache (castChannel locationDescr chA) (qryId, a)
                
{-|
    Block until the answer to query @qryId@ on @channel@ is available, then
    take it (updating the query bookkeeping via 'waitForAnswerAUX').  Logs an
    'ERNetEvAnsReceived' event and returns the answer.

    @waitingCh@ and @waitingQryId@ identify the caller; they are used only in
    the log event.
-}
waitForAnswer waitingCh waitingQryId channel qryId =
    do
    (qry, ans) <- atomically takeAnswer
    timeNow <- getCurrentTime
    LG.addEvent (chLogger channel)
        ERNetEvAnsReceived
        {
            ernetevTime = timeNow,
            ernetevQryId = qryId,
            ernetevFromId = makeChName waitingCh,
            ernetevFromQryId = waitingQryId,
            ernetevToId = makeChName channel,
            ernetevAns = ans,
            ernetevQry = qry
        }
    return ans
    where
    stateVar = chTV channel
    -- retry (block) until an answer for qryId has been posted
    takeAnswer =
        do
        chSt <- readTVar stateVar
        maybe retry (consume chSt) (Map.lookup qryId (chStAnswers chSt))
    consume chSt (ans, isCached) =
        do
        let (chSt', qry) = waitForAnswerAUX qryId chSt isCached
        writeTVar stateVar chSt'
        return (qry, ans)

{-|
    Pure bookkeeping for picking up the answer to query @qryId@ from channel
    state @chSt@, where @isCached@ is the caching flag stored with the answer.
    Returns the new state and the original query.

    If other callers are still waiting for this answer, or the answer is
    cached, the query's waiting count is decremented and the answer is kept.
    Otherwise every trace of the query (record, index entry, answer) is
    removed.

    Callers only use this after finding an answer for @qryId@, and
    'makeQuery' registers the query record before the query can be answered,
    so the record is expected to exist.  If it does not, forcing the result
    fails with an explicit error rather than an irrefutable-pattern failure.
-}
waitForAnswerAUX ::
    (Ord q) =>
    QueryId ->
    ChannelState q a ->
    Bool ->
    (ChannelState q a, q)
waitForAnswerAUX qryId chSt isCached =
    (chSt', qry)
    where
    -- lazy binding: the lookup happens only once qry, count or chSt' is forced
    (qry, count) =
        case Map.lookup qryId (chStQueriesCache chSt) of
            Just (QAState q c) -> (q, c)
            Nothing ->
                error "waitForAnswerAUX: found an answer but no record of its query in chStQueriesCache"
    chSt' 
        | count > 1 || isCached = 
            -- keep the query and its answer; one caller fewer is waiting
            chSt
            {
                chStQueriesCache =
                    Map.insert qryId (QAState qry (count - 1)) 
                        (chStQueriesCache chSt)
            }
        | otherwise = 
            -- last waiting caller and not cached: delete all references to this query
            chSt
            {
                chStQueriesCache =
                    Map.delete qryId (chStQueriesCache chSt),
                chStQueriesIndex =
                    Map.delete qry (chStQueriesIndex chSt),
                chStAnswers =
                    Map.delete qryId (chStAnswers chSt)
            }

{-|
    Block until an answer is available for at least one of the
    (channel, query id) pairs in @channelIds@, scanning them in list order.
    The first answer found is taken, with the bookkeeping updated via
    'waitForAnswerAUX'.  An 'ERNetEvAnsReceived' event is logged on that
    channel's logger, and the pair's position in the list is returned with
    the answer wrapped in 'AnswerAnyProt'.

    The log action is built while the channel's concrete query and answer
    types are still in scope.  This avoids re-associating a type-erased
    query with its answer through a runtime 'cast' and a partial 'fromJust'.
-}
waitForAnswerMulti waitingCHA waitingQryId channelIds =
    do
    (chN, ans, logReceipt) <- atomically $ takeFirstAnswer $ zip [0..] channelIds
    timeNow <- getCurrentTime
    logReceipt timeNow
    return (chN, ans)
    where
    waitingName =
        case waitingCHA of
            ChannelLocalAnyProt waitingCH _ -> makeChName waitingCH
    takeFirstAnswer [] = retry -- no answer on any channel yet: block
    takeFirstAnswer ((chN, (channel, qryId)) : otherQueryInfos) =
        case channel of
            ChannelLocalAnyProt ch _ ->
                do
                chSt <- readTVar (chTV ch)
                case Map.lookup qryId (chStAnswers chSt) of
                    Nothing -> takeFirstAnswer otherQueryInfos
                    Just (ans, isCached) ->
                        do
                        let (chSt', qry) = waitForAnswerAUX qryId chSt isCached
                        writeTVar (chTV ch) chSt'
                        return
                            ( chN
                            , AnswerAnyProt ans
                            , logAnswer ch qryId qry ans
                            )
    -- the event logged for a picked-up answer; timeNow is supplied after the transaction
    logAnswer ch qryId qry ans timeNow =
        LG.addEvent (chLogger ch)
            ERNetEvAnsReceived
            {
                ernetevTime = timeNow,
                ernetevQryId = qryId,
                ernetevFromId = waitingName,
                ernetevFromQryId = waitingQryId,
                ernetevToId = makeChName ch,
                ernetevAns = ans,
                ernetevQry = qry
            }