-----------------------------------------------------------------------------
--
-- Module      :  Control.Concurrent.Network.Protocol
-- Copyright   :  (C) 2010, Paul Sonkoly
-- License     :  BSD style
--
-- Maintainer  :  Paul Sonkoly
-- Stability   :  provisional
-- Portability :
--
-- | Protocol implementation. (Internal module).
--
--   A packet consists of a 2 byte packet type identifier followed by a
--    (possibly empty) sequence of length-prefixed fields, each being a 2 byte
--    length field and a ByteString of that many bytes. The number of fields is
--    determined by the packet identifier and the direction of the packet (slave
--    to master or master to slave).
-----------------------------------------------------------------------------

module Control.Concurrent.Network.Protocol
    (
    -- * Functions
      readProtoId
    , writeProtoId
    , readBinary
    , writeBinary
    , readByteString
    , writeByteString
    -- * Slave communications with master
    , slaveID
    , numSlaves
    , printMsg
    -- * Protocol identifiers
    , ProtoId(..)
    -- * Poll operators
    , Equality(..)
    ) where

import System.IO
import Data.Int
import Data.Binary
import Data.ByteString.Lazy as DBL
import Data.Maybe

import Control.Concurrent.Network.Slave

-- | Type of the 2 byte length field that prefixes every 'ByteString' frame
--   (see 'readByteString' and 'writeByteString'). Because it is a signed
--   'Int16', it can describe payloads of at most 32767 bytes.
type MsgLen  = Int16

-- | Packet type identifiers. Each one is sent on the wire as its constructor
--   index (via 'Enum') in a 2 byte 'Int16' (see the 'Binary' instance), so the
--   order of the constructors below is part of the wire format.
data ProtoId =
      NNV   -- ^ New empty 'NVar'
    | PNV   -- ^ Put 'NVar'
    | TNV   -- ^ Take 'NVar'
    | PWO   -- ^ Poll with a condition
    | PMS   -- ^ print a message on master
    | SID   -- ^ get slave id
    | NSL   -- ^ number of slaves
    deriving (Enum, Show) -- Enum is only used by the Binary instance

-- | Encodes a 'ProtoId' as its constructor index in a 2 byte 'Int16'.
--
--   Decoding an index that names no constructor fails inside 'Get'. It does
--   not return a 'toEnum' thunk that blows up later. This means callers using
--   'decodeOrFail' can detect a corrupt or unknown packet identifier.
instance Binary ProtoId where
    put x = put (fromIntegral (fromEnum x) :: Int16)
    get   = do
        tag <- get :: Get Int16
        -- [NNV ..] enumerates every constructor in declaration order, so the
        -- lookup is total: unknown tags yield Nothing.
        case lookup tag (zip [0 ..] [NNV ..]) of
            Just p  -> return p
            Nothing -> fail ("ProtoId: unknown packet identifier " ++ show tag)

-- | Comparison operator used by poll requests: equal ('EQOP') or not
--   equal ('NEQOP').
data Equality = EQOP | NEQOP deriving Eq

-- | Encodes 'EQOP' as 0 and 'NEQOP' as 1 in a 2 byte 'Int16'.
--
--   Any other tag fails inside 'Get'. It does not crash with a non-exhaustive
--   @case@ error.
instance Binary Equality where
    put EQOP  = put (0 :: Int16)
    put NEQOP = put (1 :: Int16)
    get = do
        tag <- get :: Get Int16
        case tag of
            0 -> return EQOP
            1 -> return NEQOP
            _ -> fail ("Equality: unknown operator tag " ++ show tag)


-- | Reads the 2 byte packet identifier from 'h' and decodes it as a
--   'ProtoId'.
readProtoId :: Handle -> IO ProtoId
readProtoId h = fmap decode (hGet h 2)


-- | Sends a packet identifier to 'h' using its 'Binary' encoding.
writeProtoId :: Handle -> ProtoId -> IO ()
writeProtoId h = hPut h . encode


-- | Reads one length-prefixed 'ByteString' from 'h': a 2 byte 'MsgLen'
--   length field followed by that many payload bytes. This is the framing
--   produced by 'writeByteString'.
--
--   Throws an 'IOError' (via 'userError') if the handle runs out of data
--   before the complete length field or payload has been read. Previously a
--   truncated payload was returned silently.
readByteString :: Handle -> IO ByteString
readByteString h = do
    hdr <- hGet h 2
    if DBL.length hdr < 2
        then truncated "length field"
        else do
            let len = fromIntegral (decode hdr :: MsgLen)
            body <- hGet h len
            -- hGet returns fewer bytes than requested when it hits EOF.
            if DBL.length body < fromIntegral len
                then truncated "payload"
                else return body
  where
    truncated what = ioError $ userError $
        "readByteString: unexpected end of input while reading " ++ what


-- | Reads one length-prefixed frame from 'h' (see 'readByteString') and
--   decodes its payload with the 'Binary' instance of the result type.
readBinary :: (Binary a) => Handle -> IO a
readBinary = fmap decode . readByteString


-- | Writes a 'ByteString' to 'h' as one frame: a 2 byte 'MsgLen' length
--   field followed by the bytes themselves (read back by 'readByteString').
--
--   Throws an 'IOError' (via 'userError') if the payload is longer than the
--   largest value a 'MsgLen' can hold. Previously the length was silently
--   wrapped by 'fromIntegral', which desynchronised the receiver's framing.
writeByteString :: Handle -> ByteString -> IO ()
writeByteString h bs
    | len > fromIntegral (maxBound :: MsgLen) =
        ioError . userError $
            "writeByteString: payload of " ++ show len
            ++ " bytes exceeds the maximum message length of "
            ++ show (maxBound :: MsgLen) ++ " bytes"
    | otherwise = hPut h (encode (fromIntegral len :: MsgLen)) >> hPut h bs
  where
    len = DBL.length bs


-- | Encodes a value with its 'Binary' instance and writes it to 'h' as a
--   single length-prefixed frame (see 'writeByteString').
writeBinary :: (Binary a) => Handle -> a -> IO ()
writeBinary h a = writeByteString h (encode a)


-- | Asks the master for the caller's slave ID. It sends an 'SID' packet on
--   the context's handle and then reads the 'Binary'-encoded 'Int' reply
--   from the same handle.
--
--   NOTE(review): there is no explicit 'hFlush' between request and reply.
--   This presumably relies on the handle being unbuffered; confirm in the
--   Slave module.
slaveID :: NCContext -> IO Int
slaveID nc = do
    let h = hdl nc
    writeProtoId h SID
    readBinary h


-- | Asks the master to print a message. It sends a 'PMS' packet followed by
--   the 'Binary'-encoded message string. No reply is read.
printMsg :: NCContext -> String -> IO ()
printMsg nc msg = do
    let h = hdl nc
    writeProtoId h PMS
    writeBinary h msg

-- | Asks the master how many slaves there are. It sends an 'NSL' packet on
--   the context's handle and then reads the 'Binary'-encoded 'Int' reply
--   from the same handle.
numSlaves :: NCContext -> IO Int
numSlaves nc = do
    let h = hdl nc
    writeProtoId h NSL
    readBinary h