{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Metaverse.Circuit (
    Circuit,
    circuitAgentID,
    circuitSessionID,
    circuitCode,
    circuitSend,
    circuitSendSync,
    circuitSource,
    connectToSim
    )
    where

import Control.Concurrent

import Control.Arrow (first)

import Control.Monad
import Control.Monad.Trans
import Control.Monad.State hiding (get, put)
import qualified Control.Monad.State as S

import Data.Char
import Data.Digest.MD5
import Data.Word
import Data.Int
import Data.Bits
import Data.List

import qualified Data.Map as M
import Data.Map (Map)

import Data.Binary
import Data.Binary.Put
import Data.Binary.Get
import Data.Binary.IEEE754

import Data.Time.Clock

import Network.XmlRpc.Client
import Network.XmlRpc.Internals

import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString

import Network.Metaverse.Login
import Network.Metaverse.Utils
import Network.Metaverse.PacketTypes
import Network.Metaverse.Packets

import System.Random
import System.Info.MAC

import System.IO.Unsafe -- for logging

import qualified Data.ByteString      as B
import qualified Data.ByteString.Lazy as L

------------------------------------------------------------------------

-- | A connection (circuit) to a single simulator over UDP.
--
-- The identity fields are exposed to the next layer up because their
-- values have to be embedded into various message fields.  The remaining
-- fields are internal plumbing for this module.
data Circuit = Circuit
    { circuitAgentID    :: UUID
      -- ^ Agent ID taken from the login token.
    , circuitSessionID  :: UUID
      -- ^ Session ID taken from the login token.
    , circuitCode       :: Word32
      -- ^ Circuit code from the login token; sent in the UseCircuitCode
      --   handshake by 'connectToSim'.
    , circuitSocket     :: Socket
      -- ^ IPv4 datagram socket used for all traffic on this circuit.
    , circuitAddr       :: SockAddr
      -- ^ Simulator address; datagrams from any other address are ignored.
    , circuitIncoming   :: Chan PacketBody
      -- ^ Bodies delivered by the receiver thread; read via 'circuitSource'.
    , circuitAccounting :: MVar Accounting
      -- ^ Sequencing, ack and retransmission state, guarded by this MVar.
    }

{-
    Packet accounting.  This layer of the communication system handles
    sequencing outgoing packets, tracking and resending dropped reliable
    packets, acknowledging reliable packets from the server, and pruning
    duplicate packets that the server resends when our acks are lost.
-}

-- | Mutable per-circuit bookkeeping.  In this module it is only read or
-- changed through the 'circuitAccounting' MVar (via 'runWithMVar').
data Accounting = Accounting
    { acctClosed        :: Bool
      -- ^ Whether the circuit has been closed; initialised to 'False'.
    , acctSequence      :: SequenceNum
      -- ^ Sequence number to assign to the next outgoing packet.
    , acctRecentPackets :: [SequenceNum]
      -- ^ Sequence numbers of the most recent reliable packets received
      --   (newest first, at most 100), used to drop retransmitted duplicates.
    , acctPendingAcks   :: [(UTCTime, SequenceNum)]
      -- ^ Received reliable packets not yet acked, oldest first, tagged
      --   with their arrival time.
    , acctReliableQueue :: TaskQueue SequenceNum
      -- ^ Scheduled retransmissions, keyed by sequence number.
    , acctConfirmations :: Map SequenceNum (MVar Bool)
      -- ^ Waiters for synchronous sends, keyed by sequence number.
    , acctPendingPings  :: [Word8]
      -- ^ Initialised empty.  NOTE(review): not otherwise used while the
      --   ping sender is commented out.
    }

-- | Run a 'StateT' computation against the contents of an 'MVar' and write
-- the final state back.
--
-- The MVar is held for the whole computation, so this is the way to
-- compose actions that must atomically read and modify the connection
-- accounting information.  If the computation throws, 'modifyMVar'
-- restores the original contents.
runWithMVar :: MVar a -> StateT a IO b -> IO b
runWithMVar var action = modifyMVar var $ \before -> do
    (result, after) <- runStateT action before
    return (after, result)

-- | Hand out the next outgoing sequence number and advance the counter.
nextSequence :: StateT Accounting IO SequenceNum
nextSequence = S.state $ \acct ->
    let current = acctSequence acct
    in  (current, acct { acctSequence = current + 1 })

-- | Serialize a packet and send all of it to @addr@ over @sock@.  No
-- sequencing or ack bookkeeping is done here.
sendRaw :: Socket -> SockAddr -> Packet -> IO ()
sendRaw sock addr packet =
    let bytes = serialize packet
    in  sendAllTo sock bytes addr

-- | Remove and return as many pending acks (oldest first) as fit in
-- @size@ bytes.
--
-- Each ack is counted as 4 bytes, and at most 255 are taken.  A
-- non-positive budget takes nothing.
getAcks :: Int -> StateT Accounting IO [SequenceNum]
getAcks size = S.state $ \acct ->
    let limit         = min 255 (size `div` 4)
        (taken, rest) = splitAt limit (acctPendingAcks acct)
    in  (map snd taken, acct { acctPendingAcks = rest })

-- | Send a packet, first filling its ack list with as many pending acks as
-- the remaining space allows.
sendWithAcks :: Socket -> SockAddr -> Packet -> StateT Accounting IO ()
sendWithAcks sock addr packet = do
    acks <- getAcks budget
    liftIO (sendRaw sock addr (packet { packetAcks = acks }))
  where
    -- 10000 matches the receive buffer size in recvRaw.  The 7 bytes are
    -- presumably the packet header plus the ack-count byte (TODO confirm).
    budget = 10000 - 7 - packetLength (packetBody packet)

-- | How 'circuitSendImpl' should treat a packet after sending it.
data Reliability
    = Unreliable
      -- ^ Fire and forget: no reliable flag and no retransmission.
    | Reliable (Maybe (MVar Bool))
      -- ^ Set the reliable flag and retransmit until acked or out of
      --   retries.  An optional MVar is filled with 'True' when the packet
      --   is acked, or with 'False' when retries run out.

-- | Whether packets sent in this mode carry the reliable flag.
isReliable :: Reliability -> Bool
isReliable rel = case rel of
    Reliable _ -> True
    Unreliable -> False

-- | Send one message on the circuit.
--
-- The message gets the next sequence number and any pending acks are
-- piggybacked onto it.  If it is reliable, it is then registered for
-- retransmission.  This is a 'StateT Accounting' action; callers run it
-- under the accounting lock with 'runWithMVar'.
circuitSendImpl :: Circuit
                -> Reliability -- ^ Whether to retransmit until acked, plus an
                               --   optional MVar to notify with the outcome
                -> PacketBody  -- ^ The payload to send along
                -> StateT Accounting IO ()
circuitSendImpl circ rel body = do
    seqNum <- nextSequence
    -- The ack list starts empty; sendWithAcks fills it in.  The third
    -- argument is presumably the retransmit flag (reliableAccounting later
    -- sets packetRetransmit = True on resends).
    let packet = Packet (shouldZerocode body)
                        (isReliable rel)
                        False
                        seqNum
                        B.empty
                        body
                        []
    sendWithAcks (circuitSocket circ) (circuitAddr circ) packet
    reliableAccounting rel circ packet

-- | Register a packet that has just been sent for retransmission, if it was
-- sent reliably.
--
-- For 'Reliable' packets a retry is scheduled on the reliable task queue,
-- keyed by the packet's sequence number.  Each retry resends the packet
-- with the retransmit flag set, until 'retryCount' resends have happened.
-- The next firing then gives up.
--
-- When the caller supplied an 'MVar', it is recorded in
-- 'acctConfirmations'.  Whoever removes that entry under the accounting
-- lock fills the MVar: 'confirmPacket' with 'True', or the give-up step
-- here with 'False'.  Checking the map and acting on it happen in one
-- locked step, so an ack that races the final retry cannot settle the
-- same waiter twice.
reliableAccounting :: Reliability
                   -> Circuit
                   -> Packet
                   -> StateT Accounting IO ()
reliableAccounting Unreliable _ _ = return ()
reliableAccounting (Reliable notify) circ packet = do
    forM_ notify $ \v ->
        modify $ \s -> s { acctConfirmations = M.insert seqNum v (acctConfirmations s) }
    queue <- fmap acctReliableQueue S.get
    liftIO $ schedule queue seqNum retryTime (retry retryCount)
  where
    retryTime  = 1500000 -- TODO: Find the right value
    retryCount = 3 :: Int
    sock    = circuitSocket circ
    addr    = circuitAddr   circ
    seqNum  = packetSequence packet
    retried = packet { packetRetransmit = True }

    -- Runs from the task queue.  It takes the accounting lock itself, so
    -- the lookup below and the action taken are atomic with respect to
    -- confirmPacket.
    retry n = runWithMVar (circuitAccounting circ) $ do
        waiter <- fmap (M.lookup seqNum . acctConfirmations) S.get
        case (notify, waiter) of
            -- A waiter was registered but its entry is gone, so the packet
            -- has already been confirmed.  Stop retrying.
            (Just _, Nothing) -> return ()
            _ | n <= 0    -> giveUp waiter
              | otherwise -> do
                  sendWithAcks sock addr retried
                  queue <- fmap acctReliableQueue S.get
                  liftIO $ schedule queue seqNum retryTime (retry (n - 1))

    -- Out of retries: report failure to a synchronous sender, if any.
    -- tryPutMVar never blocks while the accounting lock is held.
    giveUp waiter = forM_ waiter $ \v -> do
        modify $ \s -> s { acctConfirmations = M.delete seqNum (acctConfirmations s) }
        liftIO $ void (tryPutMVar v False)

-- | Send a message without waiting for it to be acknowledged.
--
-- When @reliable@ is 'True' the packet is still retransmitted until it is
-- acked or runs out of retries.  The caller is just not told the outcome.
circuitSend :: Circuit -> Bool -> PacketBody -> IO ()
circuitSend circ reliable msg =
    runWithMVar (circuitAccounting circ) (circuitSendImpl circ mode msg)
  where
    mode | reliable  = Reliable Nothing
         | otherwise = Unreliable

-- | Send a message reliably and block until it is settled.
--
-- Returns 'True' if the simulator acked it, or 'False' if retransmission
-- gave up.  The accounting lock is released before waiting, so other
-- threads can keep using the circuit in the meantime.
circuitSendSync :: Circuit -> PacketBody -> IO Bool
circuitSendSync circ msg = do
    outcome <- newEmptyMVar
    let send = circuitSendImpl circ (Reliable (Just outcome)) msg
    runWithMVar (circuitAccounting circ) send
    takeMVar outcome

-- | Background loop that sends PacketAck messages for outstanding acks.
--
-- Normally acks are piggybacked on outgoing traffic.  This loop makes sure
-- acks still go out when we have nothing else to send.  Every poll
-- interval it checks the oldest pending ack; if that ack has waited longer
-- than the threshold, it flushes a batch of pending acks in an unreliable
-- PacketAck.  The loop exits once the circuit's 'acctClosed' flag is set.
ackSender :: Circuit -> IO ()
ackSender circ = do
    stillOpen <- runWithMVar (circuitAccounting circ) $ do
        closed <- S.gets acctClosed
        unless closed $ do
            pending <- S.gets acctPendingAcks
            now     <- liftIO getCurrentTime
            case pending of
                -- acctPendingAcks is kept oldest-first, so only the head
                -- needs to be checked against the threshold.
                ((queuedAt, _) : _) | now `diffUTCTime` queuedAt > ackThreshold -> do
                    batch <- getAcks (10000 - 7)
                    circuitSendImpl circ Unreliable (PacketAck (map PacketAck_Packets batch))
                _ -> return ()
        return (not closed)
    when stillOpen $ do
        threadDelay pollInterval
        ackSender circ
  where
    ackThreshold :: NominalDiffTime
    ackThreshold = 0.75  -- TODO: Find the right values

    pollInterval :: Int
    pollInterval = 500000 -- TODO: Find the right frequency (microseconds)

-- | Process the simulator's acknowledgement of sequence number @seqNum@.
--
-- Cancels the scheduled retransmission for that packet on the reliable
-- task queue.  If a synchronous sender registered an 'MVar' for it, the
-- entry is removed from 'acctConfirmations' and the MVar is filled with
-- 'True'.  If no waiter is registered, only the cancel is performed.  The
-- cancel is presumably a no-op when no task is pending (TODO confirm).
confirmPacket :: SequenceNum -> StateT Accounting IO ()
confirmPacket seqNum = do
    queue   <- fmap acctReliableQueue S.get
    waiters <- fmap acctConfirmations S.get
    liftIO $ cancel queue seqNum
    forM_ (M.lookup seqNum waiters) $ \v -> do
        modify $ \s -> s { acctConfirmations = M.delete seqNum (acctConfirmations s) }
        -- This runs while the accounting MVar is held.  A blocking putMVar
        -- on an already-full MVar would stall every thread that needs the
        -- lock, so use the non-blocking variant.
        liftIO $ void (tryPutMVar v True)

-- | Block for one datagram of up to 10000 bytes and decode it.  Returns the
-- packet together with the address it came from.
recvRaw :: Socket -> IO (Packet, SockAddr)
recvRaw sock = do
    (bytes, sender) <- recvFrom sock 10000
    return (deserialize bytes, sender)

-- | Background loop that receives datagrams from the simulator.
--
-- Datagrams from any address other than 'circuitAddr' are ignored.  For
-- each accepted packet, under the accounting lock, it:
--
--   * confirms every ack piggybacked on the packet;
--   * queues an ack for the packet if it is reliable, and remembers its
--     sequence number for duplicate detection;
--   * handles PacketAck and StartPingCheck itself;
--   * forwards every other body to 'circuitIncoming', unless it is a
--     retransmission of a packet that was already seen.
packetReceiver :: Circuit -> IO ()
packetReceiver circ = forever $ do
    (packet, sender) <- recvRaw (circuitSocket circ)
    when (sender == circuitAddr circ) $
        runWithMVar (circuitAccounting circ) (account packet)
  where
    account packet = do
        let sq = packetSequence packet
        mapM_ confirmPacket (packetAcks packet)
        -- Read the recent list before this packet is added to it.
        seenBefore <- fmap (elem sq . acctRecentPackets) S.get
        when (packetReliable packet) $ do
            now <- liftIO getCurrentTime
            modify $ \s -> s
                { acctPendingAcks   = acctPendingAcks s ++ [(now, sq)]
                , acctRecentPackets = take 100 (sq : acctRecentPackets s)
                }
        dispatch packet (packetRetransmit packet && seenBefore)

    dispatch packet duplicate = case packetBody packet of
        PacketAck acks ->
            mapM_ (confirmPacket . packetAck_Packets_ID) acks
        StartPingCheck (StartPingCheck_PingID pingID _) ->
            circuitSendImpl circ Unreliable
                (CompletePingCheck (CompletePingCheck_PingID pingID))
        body ->
            unless duplicate $
                liftIO $ writeChan (circuitIncoming circ) body

-- | Subscribe to incoming packet bodies.
--
-- Each call duplicates the incoming channel and returns an action that
-- blocks until the next body delivered after the subscription.
-- Subscribers read independently of one another.
circuitSource :: Circuit -> IO (IO PacketBody)
circuitSource circ = fmap readChan (dupChan (circuitIncoming circ))

{-
pingSender :: Circuit -> Word8 -> IO ()
pingSender conn n = do
    threadDelay 5000000
    runWithMVar () conn $ sendPacket False (StartPingCheck (StartPingCheck_PingID n 0))
    pingSender conn (n+1)
-}

-- | Open a circuit to the simulator named in a login token.
--
-- Creates an IPv4 datagram socket and the accounting state, then starts the
-- ack-sender and packet-receiver threads.  Finally it sends UseCircuitCode
-- synchronously.
--
-- NOTE(review): the handshake's 'Bool' result is discarded, so a circuit is
-- returned even if the simulator never acknowledged UseCircuitCode.
connectToSim :: MVToken -> IO Circuit
connectToSim token = do
    sock  <- socket AF_INET Datagram defaultProtocol
    simIP <- inet_addr (tokenSimIP token)
    queue <- newTaskQueue

    acct  <- newMVar $ Accounting
        { acctClosed        = False
        , acctSequence      = 1
        , acctRecentPackets = []
        , acctPendingAcks   = []
        , acctReliableQueue = queue
        , acctConfirmations = M.empty
        , acctPendingPings  = []
        }
    inbox <- newChan

    let simPort = fromIntegral (tokenSimPort token)
        circ = Circuit
            { circuitAgentID    = tokenAgentID token
            , circuitSessionID  = tokenSessionID token
            , circuitCode       = tokenCircuitCode token
            , circuitSocket     = sock
            , circuitAddr       = SockAddrInet simPort simIP
            , circuitIncoming   = inbox
            , circuitAccounting = acct
            }

    _ <- forkIO (ackSender      circ)
    _ <- forkIO (packetReceiver circ)
--    forkIO (pingSender     conn 0)

    _ <- circuitSendSync circ $ UseCircuitCode $ UseCircuitCode_CircuitCode
        (circuitCode circ) (circuitSessionID circ) (circuitAgentID circ)

    return circ