{-# LANGUAGE RecordWildCards #-}
module Network.Stun.Base where

import           Control.Monad
import           Data.Bits
import qualified Data.ByteString as BS
import           Data.Digest.CRC32
import           Data.Serialize
import           Data.Word

-- | STUN method number. Only the low 12 bits fit into the message-type field
-- (see 'encodeMessageType'); higher bits are dropped when encoding.
type Method = Word16

-- | Class of a STUN message. 'encodeMessageType' stores it in two bits of the
-- message-type field.
data MessageClass
    = Request
    | Success
    | Failure
    | Indication
    deriving (Eq, Show)

-- | A raw message attribute: its 16-bit type tag and the undecoded value
-- bytes (without the 4-byte alignment padding used on the wire).
data Attribute = Attribute
    { attributeType  :: {-# UNPACK #-} !Word16
    , attributeValue :: BS.ByteString
    }
    deriving (Eq, Show)


-- | 96-bit transaction ID, held as three 32-bit words in the order they
-- appear on the wire.
data TransactionID = TID
    {-# UNPACK #-} !Word32
    {-# UNPACK #-} !Word32
    {-# UNPACK #-} !Word32
    deriving (Eq, Read, Show)

-- | A STUN message.
--
-- When parsing, a trailing FINGERPRINT attribute is removed from
-- 'messageAttributes' and recorded by setting 'fingerprint'; when
-- serialising with 'fingerprint' set, one is computed and appended.
data Message = Message
    { messageMethod     :: !Method
    , messageClass      :: !MessageClass
    , transactionID     :: !TransactionID
    , messageAttributes :: [Attribute]
    , fingerprint       :: !Bool
    }
    deriving (Eq, Show)

-- | The "magic cookie" written into every message header right after the
-- length field; 'getMessage' rejects headers that do not carry it.
cookie :: Word32
cookie = 0x2112a442

-- | Why an 'Attribute' could not be converted into a typed value.
data AttributeError
    = AttributeWrongType   -- ^ the type tag belongs to a different attribute
    | AttributeDecodeError -- ^ the tag matched but the value did not decode
    deriving (Eq, Show)

-- | Typed attributes that can be converted to and from a raw 'Attribute'.
--
-- Instances must define 'attributeTypeValue'; the default 'toAttribute' and
-- 'fromAttribute' use the 'Serialize' instance for the value bytes.
-- 'attributeTypeValue' must not inspect its argument: the default
-- 'fromAttribute' calls it with a placeholder only to select the tag that
-- belongs to the requested result type.
class Serialize a => IsAttribute a where
    -- | The 16-bit type tag of this attribute (must not depend on the value).
    attributeTypeValue :: a -> Word16

    -- | Wrap a value as a raw attribute tagged with 'attributeTypeValue'.
    toAttribute        :: a -> Attribute
    toAttribute x = Attribute { attributeType  = attributeTypeValue x
                              , attributeValue = encode x
                              }

    -- | Decode a raw attribute. Returns 'AttributeWrongType' when the tag
    -- differs from the one expected for the result type, and
    -- 'AttributeDecodeError' when the tag matches but the value bytes do not
    -- decode.
    fromAttribute      :: Attribute -> Either AttributeError a
    fromAttribute (Attribute tp vl) = result
      where
        result
            | tp /= attributeTypeValue (typeWitness result) = Left AttributeWrongType
            | otherwise = either (const (Left AttributeDecodeError)) Right (decode vl)
        -- Total replacement for the former partial lambda: it only fixes the
        -- type passed to 'attributeTypeValue' and is never evaluated as long
        -- as instances ignore their argument.
        typeWitness :: Either AttributeError b -> b
        typeWitness _ = error
            "IsAttribute.fromAttribute: attributeTypeValue must not inspect its argument"

-- | Decode every attribute in the list whose tag matches the result type.
--
-- Attributes with other tags are skipped. If a matching attribute's value
-- fails to decode, the whole result is 'Left' 'AttributeDecodeError', and
-- later attributes are not examined.
findAttribute :: IsAttribute a => [Attribute] -> Either AttributeError [a]
findAttribute = foldr step (Right [])
  where
    step attr rest = case fromAttribute attr of
        Left AttributeDecodeError -> Left AttributeDecodeError
        Left AttributeWrongType   -> rest
        Right value               -> fmap (value :) rest


-- | Write an attribute: the 16-bit type, the 16-bit value length, the value,
-- and then zero bytes up to the next 4-byte boundary. The padding is not
-- counted in the length field.
--
-- NOTE(review): the length is narrowed to 16 bits with 'fromIntegral', so a
-- value longer than 65535 bytes would be written with a wrong length.
putAttribute :: Attribute -> PutM ()
putAttribute (Attribute tag value) = do
    putWord16be tag
    putWord16be (fromIntegral len)
    putByteString value
    putByteString (BS.replicate padLen 0)
  where
    len    = BS.length value
    padLen = (4 - len `mod` 4) `mod` 4

-- | Read an attribute in the format produced by 'putAttribute'. It reads the
-- type, the length, and that many value bytes, then skips padding up to the
-- next 4-byte boundary. The padding bytes' values are not checked.
getAttribute :: Get Attribute
getAttribute = do
    tag <- getWord16be
    len <- fmap fromIntegral getWord16be
    value <- getBytes len
    replicateM_ ((4 - len `mod` 4) `mod` 4) getWord8
    return (Attribute tag value)

-- | Wire format of a single attribute, including its alignment padding.
instance Serialize Attribute where
    get = getAttribute
    put attr = putAttribute attr

-- | Pack a method and a class into the message-type field.
--
-- Method bits 0-3 stay at 0-3, bits 4-6 move to 5-7, and bits 7-11 move to
-- 9-13. The class occupies bit 4 (its low bit) and bit 8 (its high bit).
-- Method bits above bit 11 are dropped, and the two most significant bits of
-- the result are always 0.
encodeMessageType :: Method -> MessageClass -> Word16
encodeMessageType method messageClass =
    foldr (.|.) classBits
        [ method .&. 0x000f
        , (method .&. 0x0070) `shiftL` 1
        , (method .&. 0x0f80) `shiftL` 2
        ]
  where
    -- class low bit -> bit 4 (0x0010), class high bit -> bit 8 (0x0100)
    classBits = case messageClass of
        Request    -> 0x0000
        Indication -> 0x0010
        Success    -> 0x0100
        Failure    -> 0x0110

-- | Split a message-type field into method and class. This is the inverse of
-- 'encodeMessageType' on the low 14 bits; bits 14 and 15 are ignored here.
decodeMessageType :: Word16 -> (Method, MessageClass)
decodeMessageType word = (method, messageClassOf (word .&. 0x0110))
  where
    -- reassemble the 12 method bits from the three runs around the class bits
    method = (word .&. 0x000f)
         .|. ((word .&. 0x00e0) `shiftR` 1)
         .|. ((word .&. 0x3e00) `shiftR` 2)
    -- bit 4 is the class low bit, bit 8 the class high bit
    messageClassOf 0x0000 = Request
    messageClassOf 0x0010 = Indication
    messageClassOf 0x0100 = Success
    messageClassOf _      = Failure


-- | Constant XORed with the CRC-32 to form the FINGERPRINT value; its four
-- bytes spell "STUN" in ASCII.
fingerprintXorConstant :: Word32
fingerprintXorConstant = 0x5354554e

-- | Build the FINGERPRINT attribute (type 0x8028) from a CRC-32. Its value is
-- the 'Serialize' encoding of the CRC XORed with 'fingerprintXorConstant'.
fingerprintAttribute :: Word32 -> Attribute
fingerprintAttribute crc =
    Attribute 0x8028 (encode (fingerprintXorConstant `xor` crc))

-- | Write the 20-byte header followed by the encoded attributes, without any
-- FINGERPRINT.
--
-- The first argument is added to the length field in the header. This lets a
-- caller append further bytes after this output and still have a correct
-- length; 'putMessage' uses it for the fingerprint attribute.
putPlainMessage :: Int -> Message -> PutM ()
putPlainMessage extra msg = do
    putWord16be $ encodeMessageType (messageMethod msg) (messageClass msg)
    putWord16be . fromIntegral $ BS.length body + extra
    putWord32be cookie
    putTID (transactionID msg)
    putByteString body
  where
    body = runPut $ mapM_ put (messageAttributes msg)
    putTID (TID w1 w2 w3) = mapM_ putWord32be [w1, w2, w3]

-- | Serialise a message.
--
-- With 'fingerprint' set, the header's length field already counts the
-- 8-byte FINGERPRINT attribute. The CRC-32 is computed over everything
-- written before that attribute, and then the attribute itself is appended.
putMessage :: Message -> PutM ()
putMessage m
    | not (fingerprint m) = putPlainMessage 0 m
    | otherwise = do
        let prefix = runPut (putPlainMessage 8 m)
        putByteString prefix
        put (fingerprintAttribute (crc32 prefix))

-- | Parse one STUN message and consume exactly its bytes (header plus body).
--
-- The header and attributes are parsed under 'lookAhead' first. The message
-- bytes are then consumed, so the remaining input starts right after the
-- message; previously nothing, or all but the fingerprint, was consumed.
--
-- If the last attribute has type 0x8028 (FINGERPRINT), its value is checked
-- against the CRC-32 of the bytes preceding it. On a match, the attribute is
-- removed from 'messageAttributes' and 'fingerprint' is set to 'True'; on a
-- mismatch, the parser fails.
getMessage :: Get Message
getMessage = do
    (mlen, msg) <- lookAhead getHeaderAndAttributes
    -- lookAhead does not consume input, so take the whole message here
    raw <- getBytes (headerLength + mlen)
    case reverse (messageAttributes msg) of
        -- Fingerprint has to be the last attribute
        Attribute 0x8028 fp : revRest -> do
            -- CRC covers everything up to the beginning of the 8-byte
            -- fingerprint attribute; the length field already includes it
            let start = BS.take (headerLength + mlen - 8) raw
                crc   = fingerprintXorConstant `xor` crc32 start
            label "fingerprint does not match" $ guard (encode crc == fp)
            return msg { fingerprint       = True
                       , messageAttributes = reverse revRest
                       }
        _ -> return msg
  where
    headerLength :: Int
    headerLength = 20

    -- Parse header and attributes; returns the header's length field too.
    getHeaderAndAttributes = do
        tp <- getWord16be
        guard $ 0xc000 .&. tp == 0 -- highest 2 bits are always 0
        let (messageMethod, messageClass) = decodeMessageType tp
        messageLength <- fromIntegral `fmap` getWord16be
        guard $ messageLength `mod` 4 == 0
        guard . (== cookie) =<< getWord32be -- "Magic cookie"
        transactionID <- liftM3 TID getWord32be getWord32be getWord32be
        messageAttributes <- isolate messageLength getMessageAttributes
        let fingerprint = False
        return (messageLength, Message{..})

    -- Read attributes until the isolated body is exhausted.
    getMessageAttributes = do
        done <- isEmpty
        if done then return [] else liftM2 (:) getAttribute getMessageAttributes

-- | Wire format of a whole message (see 'putMessage' and 'getMessage').
instance Serialize Message where
    get = getMessage
    put msg = putMessage msg

-- | Debugging helper for bit-twiddling: render a fixed-width value as binary,
-- most significant bit first (e.g. @showBits (5 :: Word8) == "00000101"@).
--
-- Uses 'bitSizeMaybe' instead of the deprecated 'bitSize'. Types without a
-- fixed width (such as 'Integer') fail with an explicit error.
showBits :: Bits a => a -> [Char]
showBits a = case bitSizeMaybe a of
    Nothing    -> error "showBits: argument type has no fixed bit size"
    Just width -> [ if testBit a i then '1' else '0'
                  | i <- [width - 1, width - 2 .. 0]
                  ]