module Spread.Client.Message
    (
     OutMsg(..),InMsg(..),Message(..),Group,PrivateGroup,PrivateName,GroupId,OrderingType(..),Cause(..),groupName,mkPrivateGroup,privateName,GroupMsg(..),
     MembershipMsg(..),receive_internal,multicast_internal,mkGroup,mkPrivateName,putPadded,makeGroup,KillMsg(..),RejectedMsg(..))
     where
import Control.Applicative (Applicative (..))
import Control.Monad
import Data.Bits
import Data.List (elemIndex,find)
import Data.Maybe
import Data.Word
import System.IO (hFlush)

import Data.Binary.Get
import Data.Binary.Put
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Ch
import qualified Data.ByteString.Lazy as L
import Data.Map (fromList,findWithDefault)

import Spread.Constants


-- | Run a 'Get' parser over a strict 'ByteString' (wrapped as a one-chunk
-- lazy string).
runGetS :: ByteString -> Get a -> a
runGetS input parser = runGet parser (L.fromChunks [input])
-- | A computation that may need more input before it can finish.
--
-- @Result a@ is finished.  @Ask i k@ is a request @i@ (a byte count, in
-- 'runAsk') together with a continuation @k@: an action in the inner monad
-- @m@ (a 'Get' parser, in 'runAsk') that yields the next step.
data Ask p m a = Result a | Ask p (m (Ask p m a))

instance Functor m => Functor (Ask p m) where
    fmap f (Result a) = Result $ f a
    fmap f (Ask i m) = Ask i (fmap (fmap f) m)

-- Applicative has been a superclass of Monad since GHC 7.10 (base-4.8);
-- without this instance the Monad instance below does not compile there.
instance Functor m => Applicative (Ask p m) where
    pure = Result
    Result f <*> x = fmap f x
    Ask i mf <*> x = Ask i (fmap (<*> x) mf)

instance Functor m => Monad (Ask p m) where
    return = pure
    Result a >>= f = f a
    -- Push the rest of the computation under the pending request.
    Ask i m >>= f = Ask i (fmap (>>= f) m)

-- | Drive an 'Ask' computation to completion.  For every @Ask i k@ step,
-- @gen@ receives the request @i@ and the continuation @k@ and must produce
-- the next step in the outer monad @n@.
runAsk' :: Monad n => (p -> m (Ask p m a) -> n (Ask p m a)) -> Ask p m a -> n a
runAsk' _ (Result a) = return a
runAsk' gen (Ask i k) = gen i k >>= runAsk' gen

-- | Run an 'Ask' whose steps are 'Get' parsers.  @fetch i@ supplies the input
-- for request @i@, and the pending parser is run on it with 'runGet'.
runAsk :: (Functor n, Monad n) => (p -> n L.ByteString) -> Ask p Get a -> n a
runAsk fetch = runAsk' (\i parser -> fmap (runGet parser) (fetch i))

-- | Turn an action producing a final value into one producing a finished step.
result :: Functor f => f a -> f (Ask p m a)
result i = fmap Result i

-- | Build an 'Ask' step and return it in any monad.
ask :: Monad n => p -> m (Ask p m a) -> n (Ask p m a)
ask i m = return $ Ask i m

-- | Represents the orderings as specified by the Spread toolkit.
data OrderingType = Unreliable | Reliable | Fifo | Causal | Agreed | Safe deriving (Eq,Ord,Read,Show)

-- | Lookup table from each 'OrderingType' to its service-type flag.
orderingTable = fromList orderingList

-- | Every 'OrderingType' paired with its service-type flag, in declaration order.
orderingList = zip [Unreliable,Reliable,Fifo,Causal,Agreed,Safe]
                   [uNRELIABLE_MESS,rELIABLE_MESS,fIFO_MESS,cAUSAL_MESS,aGREED_MESS,sAFE_MESS]

-- | Recover the 'OrderingType' encoded in a service type.
--
-- The first entry of 'orderingList' whose flag is set in @t@ wins.  A service
-- type with no ordering flag at all is reported with a descriptive 'error'
-- instead of an anonymous 'fromJust' failure.
getOrdering :: Word32 -> OrderingType
getOrdering t =
    case find (flip isSet t . snd) orderingList of
      Just (ordering, _) -> ordering
      Nothing -> error "getOrdering: service type carries no ordering flag"
-- | Message to be sent.
data OutMsg = Outgoing
    { outOrdering :: !OrderingType -- ^ Ordering requested for delivery.
    , outDiscard  :: !Bool         -- ^ If True you won't get a copy of this message back from the server.
    , outData     :: !ByteString   -- ^ Message body.
    , outGroups   :: ![Group]      -- ^ Recipients of the message.
    , outMsgType  :: !Word16       -- ^ To be used by the application to identify the kind of message.
    }
    deriving (Show)

-- | Regular (data) message received from the server.
data InMsg = Incoming
    { inOrdering       :: !OrderingType -- ^ Ordering flag found in the message's service type.
    , inSender         :: !PrivateGroup -- ^ Private group of the sender.
    , inData           :: !ByteString   -- ^ Message body.
    , inGroups         :: ![Group]      -- ^ Groups listed in the message.
    , inMsgType        :: !Word16       -- ^ Application-defined message kind (cf. 'outMsgType').
    , inEndianMismatch :: !Bool         -- ^ True if the message has been sent with a different endian order.
    }
    deriving (Show)

-- | Union type of messages that can be received from the server.
data Message
    = Regular    !InMsg         -- ^ Ordinary data message.
    | Membership !MembershipMsg -- ^ Change in group membership.
    | Rejected   !RejectedMsg   -- ^ Message flagged as rejected.
    deriving (Show)

-- Application data message: the service type is the ordering flag,
-- optionally combined with the self-discard flag.
instance Sendable OutMsg where
    getType = outMsgType
    getData = outData
    getGroups = outGroups
    getServiceType m
        | outDiscard m = sELF_DISCARD .|. ordering
        | otherwise    = ordering
      where ordering = findWithDefault 0 (outOrdering m) orderingTable

-- Join/leave request addressed to the single group concerned.
instance Sendable GroupMsg where
    getGroups m = [grp m]
    getServiceType m = case m of
        Leaving _ -> lEAVE_MESS
        Joining _ -> jOIN_MESS

-- Kill request addressed to the given private group.
instance Sendable KillMsg where
    getGroups (Kill target) = [fromPrivateGroup target]
    getServiceType (Kill _) = kILL_MESS

--TODO Rejected Message
                       
-- | Message regarding changes in group membership.
data MembershipMsg
    -- | Transitional notice for 'changingGroup'.
    = Transient { changingGroup :: !Group }
    -- | Regular membership of 'changingGroup'.
    | Reg { changingGroup :: !Group
          , index         :: !Int          -- ^ Position of this connection's private group in 'members'.
          , numMembers    :: !Int          -- ^ Length of 'members'.
          , members       :: ![PrivateGroup]
          , groupId       :: !GroupId
          , cause         :: !Cause
          }
    -- | This connection itself left 'changingGroup'.
    | SelfLeave { changingGroup :: !Group }
    deriving (Show)

-- | Messages used to join or leave a group.
data GroupMsg
    = Joining { grp :: !Group } -- ^ Request to join 'grp'.
    | Leaving { grp :: !Group } -- ^ Request to leave 'grp'.
    deriving Show

-- | Message sent with the kill service type, addressed to the given private
-- group (presumably to terminate that connection — confirm with the daemon docs).
data KillMsg = Kill PrivateGroup
-- | A 'Group' is a collection of clients identified by a name.
newtype Group = G {groupName :: ByteString } deriving (Eq,Show)

-- | Validate and build a 'Group'.
--
-- A name is accepted when it is non-empty and every byte lies in the ASCII
-- range 36 (@$@) to 126 (@~@), both inclusive.  This is the range the Spread
-- C library checks in @SP_join@ (NOTE(review): confirm against the daemon
-- version in use).  Accepted names are truncated to 'mAX_GROUP_NAME' bytes.
mkGroup :: B.ByteString -> Maybe Group
mkGroup s
    | B.null s               = Nothing
    | B.all validGroupChar s = Just (G (B.take mAX_GROUP_NAME s))
    | otherwise              = Nothing
  where validGroupChar b = b >= 36 && b <= 126

-- | 'String' convenience wrapper around 'mkGroup' (converted with 'Ch.pack').
makeGroup = mkGroup . Ch.pack

-- | A 'PrivateGroup' identifies a connection.
type PrivateGroup = Group

-- | Wrap raw bytes as a 'PrivateGroup', keeping at most 'mAX_GROUP_NAME'
-- bytes.  No character validation is performed.
mkPrivateGroup name = G (B.take mAX_GROUP_NAME name)

-- | View a 'Group' as a 'PrivateGroup' (identity).
toPrivateGroup g = g

-- | Name bytes of a 'PrivateGroup'.
privateGroupName g = groupName g

-- | Initial part of a 'PrivateGroup' name that is chosen by the client when connecting.
newtype PrivateName = PrivateName {privateName :: ByteString} deriving (Eq,Show)

-- | Build a 'PrivateName', keeping at most 'mAX_PRIVATE_NAME' bytes.
mkPrivateName name = PrivateName (B.take mAX_PRIVATE_NAME name)
-- | Identifier for a membership message (three 32-bit words).
data GroupId = GId !Word32 !Word32 !Word32 deriving (Eq,Show)

-- | What caused a membership message.
data Cause
    = Join       { joining :: !PrivateGroup }       -- ^ The given member joined.
    | Leave      { leaving :: !PrivateGroup }       -- ^ The given member left.
    | Disconnect { disconnecting :: !PrivateGroup } -- ^ The given member disconnected.
    -- | A network change: 'sets' holds every set reported in the message and
    -- 'localSet' the one found at the message's local-set offset.
    | Network    { sets :: ![[PrivateGroup]], localSet :: ![PrivateGroup] }
    deriving (Show)


-- | True when the endian marker bits ('eNDIAN_TYPE') of @i@ equal
-- 'aRCH_ENDIAN' (presumably this architecture's marker).
sameEndian i = aRCH_ENDIAN == eNDIAN_TYPE .&. i
-- | Zero the endian marker bits.
clearEndian i = complement eNDIAN_TYPE .&. i
-- | Replace the endian marker bits with 'aRCH_ENDIAN'.
setEndian i = aRCH_ENDIAN .|. clearEndian i
-- | Reverse the byte order of a 32-bit word.
flip32 :: Word32 -> Word32
flip32 w = foldr (.|.) 0 [ moved k | k <- [0 .. 3] ]
  where
    -- byte k (0 = least significant) is relocated to byte slot 3 - k
    moved k = ((w `shiftR` (8 * k)) .&. 0xff) `shiftL` (8 * (3 - k))

-- | A decoded wire message before it is classified into a 'Message'.
data Raw = Raw
    { serviceType          :: !Word32     -- ^ Service-type flags with the endian bits cleared.
    , isender              :: !Group      -- ^ Group name from the header (sender, or changing group for memberships).
    , igroups              :: ![Group]    -- ^ Group names following the header.
    , itype                :: !Word16     -- ^ Message type, or own index for regular memberships.
    , daemonEndianMismatch :: !Bool       -- ^ Header words had to be byte-swapped.
    , iendianMismatch      :: !Bool       -- ^ The hint (message type) had to be byte-swapped.
    , body                 :: !ByteString -- ^ Message payload.
    } deriving Show

-- | Classify a 'Raw' message.  The reject flag takes precedence over the
-- membership flag; anything else is a regular data message.
parseRaw :: Raw -> Message
parseRaw i =
    case serviceType i of
      t | isSet rEJECT_MESS t     -> Rejected (asRejected i)
        | isSet mEMBERSHIP_MESS t -> Membership (asMembership i)
        | otherwise               -> Regular (asRegular i)

-- | The original request carried by a rejected message.
data RejectedMsg = WasGroup !GroupMsg | WasOut !OutMsg deriving Show

-- | Rebuild the join/leave request that a rejected message refers to.
--
-- The message must list exactly one group and have the leave or join flag
-- set (leave is checked first).  Any other input is reported with a
-- descriptive 'error' rather than an anonymous pattern-match failure.
asGroupMsg :: Raw -> GroupMsg
asGroupMsg Raw{serviceType = t, igroups = [g]}
    | isSet lEAVE_MESS t = Leaving g
    | isSet jOIN_MESS t  = Joining g
asGroupMsg _ = error "asGroupMsg: rejected message is neither a single-group join nor a leave"
-- | Rebuild an 'OutMsg' from a 'Raw' message (used for rejected regular
-- messages).  Indented with spaces: the previous tab indentation triggered
-- GHC's @-Wtabs@ warning.
asOutMsg :: Raw -> OutMsg
asOutMsg i = Outgoing { outOrdering = getOrdering $ serviceType i
                      , outDiscard  = isSet sELF_DISCARD $ serviceType i
                      , outData     = body i
                      , outGroups   = igroups i
                      , outMsgType  = itype i
                      }

-- | Decode a rejected message.  A rejected regular data message becomes
-- 'WasOut'; everything else is treated as a join/leave request ('WasGroup').
asRejected :: Raw -> RejectedMsg
asRejected i
    | isSet rEGULAR_MESS (serviceType i) = WasOut (asOutMsg i)
    | otherwise                          = WasGroup (asGroupMsg i)

-- | Convert a regular data message (neither rejected nor membership) into an 'InMsg'.
asRegular :: Raw -> InMsg
asRegular raw =
    let flags = serviceType raw
    in Incoming
        { inOrdering       = getOrdering flags
        , inSender         = toPrivateGroup (isender raw)
        , inData           = body raw
        , inGroups         = igroups raw
        , inMsgType        = itype raw
        , inEndianMismatch = iendianMismatch raw
        }

-- | Read one message from handle @h@ and classify it with 'parseRaw'.
--
-- @prvg@ is this connection's own private group.  It is used only to fill
-- in 'itype' for regular membership messages: there 'itype' becomes the
-- index of @prvg@ within the received group list (via 'fromJust', so such a
-- message that does not list @prvg@ makes this call fail).
--
-- 'runAsk' pulls input from @h@ with 'L.hGet' in up to three reads:
--
--   1. a fixed header of @mAX_GROUP_NAME + 16@ bytes: service type, sender
--      group name, then three 32-bit words (group count, hint, data length);
--   2. only if the reject flag is set, 4 more bytes holding the original
--      service type, which is OR-ed with 'rEJECT_MESS';
--   3. @numGroups@ group names followed by @dataLen@ bytes of body.
--
-- NOTE(review): 'L.hGet' returns fewer bytes at end of file; the resulting
-- 'runGet' failure is not caught here.
receive_internal h prvg = liftM parseRaw $ runAsk (L.hGet h) $ Ask (mAX_GROUP_NAME + 16) getInternal
    where 
          -- Step 1: decode the fixed-size header and return the remaining steps.
          getInternal = do srvT <- getInt
                           -- If the service type's endian marker differs from ours, the
                           -- service type, group count and data length are byte-swapped.
                           let (dEM,maybeFlip) = if sameEndian srvT then (False,id) else (True, flip32)-- daemon endian mismatch
                           senderbs <- getGroup
                           [ng,hint,dl] <- replicateM 3 getInt
                           let [serviceType',numGroups,dataLen] = map maybeFlip [srvT,ng,dl]
                               -- The hint carries its own endian marker (reported as
                               -- iendianMismatch); the message type sits in bits 8-23.
                               eM  = not (sameEndian hint)
                               hint' = if eM then flip32 hint else hint
                               type' = fromIntegral $ (clearEndian hint' `shiftR` 8) .&. 0x0000FFFF
                               -- Step 2 parser: the rejected message's original service type.
                               getOldType :: Get Word32
                               getOldType = do oldt <- maybeFlip `fmap` getInt
                                               return $ rEJECT_MESS .|. oldt
                           return $ do 
                             -- Step 2 runs only for rejected messages; endian bits are cleared either way.
                             serviceType <- fmap clearEndian $ if isSet rEJECT_MESS serviceType' 
                                                               then Ask 4 (result getOldType)
                                                               else return serviceType'
                             -- Step 3: the group names, then the body.
                             Ask ((fromIntegral numGroups * mAX_GROUP_NAME) + (fromIntegral dataLen)) $ do 
                                    groups <- readGroups (fromIntegral numGroups)
                                    body <- getByteString (fromIntegral dataLen )
                                    -- The local serviceType and body shadow the Raw field selectors.
                                    result . return $ Raw { serviceType = serviceType,
                                                    isender = senderbs,
                                                    igroups = groups,
                                                    itype = if isSet mEMBERSHIP_MESS serviceType && isSet rEG_MEMB_MESS serviceType 
                                                              then fromIntegral . fromJust . elemIndex (fromPrivateGroup prvg) $ groups
                                                              else type',
                                                    daemonEndianMismatch = dEM,
                                                    body = body,
                                                    iendianMismatch = eM
                                                  }

-- | Interpret a 'Raw' message whose membership flag is set.
--
-- * Transitional notices become 'Transient'.
-- * Regular memberships become 'Reg'.  The body starts with a 12-byte
--   'GroupId'.  It continues with a set count, the byte offset of the local
--   set (relative to the end of those two words) and the sets themselves,
--   each a count @n@ followed by @n@ group names; these give the 'Cause'.
--   'index' is taken from 'itype', which 'receive_internal' sets to this
--   connection's position in the member list.
-- * A leave notice without the regular-membership flag becomes 'SelfLeave'.
--
-- Unexpected flags or malformed sets are reported with a descriptive 'error'.
asMembership :: Raw -> MembershipMsg
asMembership i@Raw{serviceType=t}
    | isSet tRANSITION_MESS t = Transient (isender i)
    | isSet rEG_MEMB_MESS t   = Reg { changingGroup = isender i,
                                      index = fromIntegral $ itype i,
                                      members = map toPrivateGroup (igroups i),
                                      numMembers = Prelude.length (igroups i),
                                      groupId = gid,
                                      cause = memberCause
                                    }
    -- the regular-membership flag is known to be clear here
    | isSet cAUSED_BY_LEAVE t = SelfLeave { changingGroup = isender i }
    | otherwise = error "asMembership: unexpected message type"
  where
    -- Body words are byte-swapped when the daemon's byte order differs.
    getInt' = (if daemonEndianMismatch i then flip32 else id) `fmap` getInt
    (gids, rest) = B.splitAt 12 (body i)
    gid = runGetS gids (liftM3 GId getInt' getInt' getInt')
    -- A set is a count n followed by n group names.
    getSet = fmap (map toPrivateGroup) . readGroups . fromIntegral =<< getInt'
    memberCause = runGetS rest $ do
        numSets <- fmap fromIntegral getInt'
        offsetToLocalSet <- fmap fromIntegral getInt'
        firstSetPos <- bytesRead
        let localSetPos = firstSetPos + offsetToLocalSet
        -- Remember where each set starts so the local set can be found by offset.
        positioned <- replicateM numSets $ do
            pos <- bytesRead
            grps <- getSet
            return (pos, grps)
        let allSets = map snd positioned
            firstMember = case allSets of
                ((g:_):_) -> g
                _ -> error "asMembership: membership cause lists no members"
            localSet' = fromMaybe (error "asMembership: no set starts at the local-set offset")
                                  (lookup localSetPos positioned)
        return $ causeFrom allSets firstMember localSet'
    causeFrom allSets firstMember localSet'
        | isSet cAUSED_BY_JOIN t       = Join firstMember
        | isSet cAUSED_BY_LEAVE t      = Leave firstMember
        | isSet cAUSED_BY_DISCONNECT t = Disconnect firstMember
        | isSet cAUSED_BY_NETWORK t    = Network allSets localSet'
        | otherwise = error "asMembership: membership message has no recognised cause"
           
-- | Read @n@ consecutive fixed-width group names.
readGroups :: Int -> Get [Group]
readGroups n = replicateM n getGroup

-- | Read one group name stored in a 'mAX_GROUP_NAME'-byte NUL-padded field.
--
-- The name ends at the first NUL byte, as with a C string; bytes after it
-- are ignored, so stale buffer contents cannot leak into the name.  Uses
-- 'getByteString' rather than the deprecated 'getBytes'.
getGroup :: Get Group
getGroup = (G . B.takeWhile (/= 0)) `fmap` getByteString mAX_GROUP_NAME
        
                          


                
-- | Lift a 'Get' parser into 'IO': read up to @n@ bytes from handle @h@
-- and run the parser on them.
-- Intended type: Int -> Handle -> Get a -> IO a ('Handle' is not imported here).
runGetNH n h parser = do
    bytes <- L.hGet h n
    return (runGet parser bytes)
                         
-- | Read a 32-bit word in host byte order (callers byte-swap with 'flip32'
-- when the endian marker says the peer differs).
getInt :: Get Word32
getInt = getWord32host
-- | Write a 32-bit word in host byte order.
putInt :: Word32 -> Put
putInt = putWord32host
-- | @isSet flag word@ is True when any bit of @flag@ is also set in @word@.
--
-- 'Num' is needed for the literal 0: since base-4.6 it is no longer a
-- superclass of 'Bits', so a bare @Bits a@ context does not compile.
isSet :: (Bits a, Num a) => a -> a -> Bool
isSet f t = t .&. f /= 0

-- | Things that 'sendable' can serialize into a multicast request.
--
-- 'getGroups' and 'getServiceType' must always be supplied; the body
-- defaults to empty and the message type to 0.
class Sendable a where
    -- | Recipients written after the header.
    getGroups :: a -> [Group]
    -- | Message body.
    getData :: a -> ByteString
    getData = const B.empty
    -- | Service-type flags written in the header.
    getServiceType :: a -> Word32
    -- | Application-defined message type, stored in the header hint.
    getType :: a -> Word16
    getType = const 0

-- | Write a group name as a fixed-width, NUL-padded 'mAX_GROUP_NAME'-byte field.
putGroup :: Group -> PutM ()
putGroup g = putPadded mAX_GROUP_NAME (B.take mAX_GROUP_NAME (groupName g))

-- | Write @s@ as a fixed-width field of exactly @n@ bytes: the bytes of @s@
-- followed by NUL padding.  Longer input is truncated, so the field can
-- never overflow and shift every field written after it.
putPadded :: Int -> ByteString -> PutM ()
putPadded n s = putByteString field >> putByteString padding
  where field   = B.take n s
        padding = B.replicate (n - B.length field) 0

-- | Serialize @s@ (sent as @prvg@) and write it to handle @h@, then flush.
-- Returns False without writing anything when 'sendable' rejects the
-- message as too long, True otherwise.
-- Intended type: (Sendable a) => PrivateGroup -> a -> Handle -> IO Bool
multicast_internal prvg s h =
    case sendable prvg s of
      Nothing    -> return False
      Just bytes -> do
          L.hPut h bytes
          hFlush h
          return True

-- | View a 'PrivateGroup' as a plain 'Group' (identity; the two share a representation).
fromPrivateGroup :: PrivateGroup -> Group
fromPrivateGroup g = g

-- | Encode a message sent by private group @prvg@ into its wire form.
--
-- Layout: service type, sender name, recipient count, hint (message type in
-- bits 8-23), body length, recipient names, body.  Returns 'Nothing' when
-- the encoded size would exceed 'mAX_MESSAGE_LENGTH'.
sendable :: (Sendable a) => PrivateGroup -> a -> Maybe L.ByteString
sendable prvg m
    | totalBytes > mAX_MESSAGE_LENGTH = Nothing
    | otherwise                       = Just (runPut frame)
  where
    recipients     = getGroups m
    recipientCount = length recipients
    payload        = getData m
    -- four 32-bit header words (16 bytes) plus the sender and every recipient name
    totalBytes     = 16 + mAX_GROUP_NAME * (recipientCount + 1) + B.length payload
    hint           = (fromIntegral (getType m) `shiftL` 8) .&. 0x00FFFF00
    frame = do
        putInt (setEndian (getServiceType m))
        putGroup (fromPrivateGroup prvg)
        putInt (fromIntegral recipientCount)
        putInt (setEndian hint)
        putInt (fromIntegral (B.length payload))
        mapM_ putGroup recipients
        putByteString payload