{-# LANGUAGE RecordWildCards, RankNTypes #-}
module Network.CommSec
    (
    -- * Types
      Connection(..)
    , CommSecError(..)
    -- * Send and receive operations
    , send, recv
    , sendPtr, recvPtr
    -- * Establishing a connection from a shared secret
    , accept
    , connect
    , close
    -- * Establishing a connection from a public identity (PKI)
    -- , acceptId
    -- , connectId
    -- * Utility
    , expandSecret
    ) where

import Crypto.Classes (buildKey)
import Crypto.Cipher.AES128.Internal (encryptCTR)
import Crypto.Cipher.AES128 (AESKey128, ctr, zeroIV)
import Network.CommSec.Package
import Network.CommSec.Types
import Network.Socket ( Socket, SocketType(..), SockAddr, AddrInfo(..)
                      , defaultHints , getAddrInfo, sendBuf, addrAddress
                      , recvBuf, HostName, PortNumber)
import qualified Network.Socket as Net
import Control.Concurrent.MVar
import Control.Exception (bracket, bracketOnError, throw)
import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Internal as B
import Foreign.Ptr
import Foreign.Marshal.Alloc
import Data.Word
import Data.Maybe (isJust, listToMaybe)
import System.IO.Error (eofErrorType, mkIOError)

-- | A secure, bidirectional communication channel to a remote peer.
--
-- The two contexts live in separate 'MVar's, so the receive path
-- ('recv', 'recvPtr') and the send path ('sendPtr') lock independently.
data Connection = Conn
    { inCtx      :: MVar InContext
      -- ^ state used to decode incoming messages
    , outCtx     :: MVar OutContext
      -- ^ state used to encode outgoing messages
    , socket     :: Socket
      -- ^ socket carrying the length-prefixed ciphertext
    , socketAddr :: SockAddr
      -- ^ address of remote
    }

-- Put a value into an 'MVar', forcing the value to weak head normal form
-- before the put is performed.
pMVar :: MVar v -> v -> IO ()
pMVar var val = putMVar var $! val

-- |Encrypt the given 'B.ByteString' and send it as one message over the
-- secure connection.  This is a thin wrapper that hands the bytes to
-- 'sendPtr'.
send :: Connection -> B.ByteString -> IO ()
send conn msg = B.useAsCStringLen msg (uncurry (sendPtr conn . castPtr))

-- Outcome of a single receive attempt inside 'recv': 'Good' means the packet
-- decoded and can be returned, 'Err' means it was dropped and the receive
-- must be retried.
data RecvRes
    = Good
    | Err
    deriving (Eq)

-- |Blocks until a message has been received and successfully decoded, then
-- returns its plaintext.
--
-- Packets that fail to decode with one of the 'retryOn' errors are dropped
-- and the next packet is awaited; any other decode error is thrown.  A
-- length tag above 2^28 bytes aborts the receive via 'fail'.
recv :: Connection -> IO B.ByteString
recv Conn{inCtx = inVar, socket = sock} =
  allocaBytes sizeTagLen $ \tagPtr ->
    modifyMVar inVar (loop tagPtr)

  where

  loop :: Ptr Word8 -> InContext -> IO (InContext, B.ByteString)
  loop tagPtr ctx = do
    -- Read the fixed-size length tag first.
    -- XXX This is unacceptable for datagrams
    recvBytesPtr sock tagPtr sizeTagLen
    sz <- fromIntegral `fmap` peekBE32 tagPtr
    when (sz > 2^28)
      (fail "recv: A message is over 256MB! Probably corrupt data or the stream is unsyncronized.")
    -- Decode straight into a fresh ByteString buffer of sz bytes, trimmed to
    -- the decoded length on success (or to zero when the packet is dropped).
    (msg, (ctx', outcome)) <- B.createAndTrim' sz $ \ptPtr -> do
        (nextCtx, decoded) <- recvPtrOfSz sock ctx ptPtr sz
        case decoded of
          Right n                 -> return (0, n, (nextCtx, Good))
          Left err
            | err `elem` retryOn  -> return (0, 0, (nextCtx, Err))
            | otherwise           -> throw err
    case outcome of
      Err  -> loop tagPtr ctx'
      Good -> return (ctx', msg)

-- Decode failures that the receive functions treat as "drop this packet and
-- wait for the next one" instead of throwing.
retryOn :: [CommSecError]
retryOn =
    [ DuplicateSeq
    , InvalidICV
    , BadPadding
    ]

-- |Encrypts @ptLen@ bytes starting at @ptPtr@ and sends them over the
-- connection as one packet: a big-endian 32-bit ciphertext length followed
-- by the ciphertext itself.
--
-- Both the encoding and the socket write happen while the outgoing context
-- 'MVar' is held.
sendPtr :: Connection -> Ptr Word8 -> Int -> IO ()
sendPtr Conn{outCtx = outVar, socket = sock} ptPtr ptLen =
    allocaBytes pktLen $ \pktPtr -> do
        -- Length tag first, ciphertext directly behind it.
        pokeBE32 pktPtr (fromIntegral ctLen)
        modifyMVar_ outVar $ \oCtx -> do
            oCtx' <- encodePtr oCtx ptPtr (pktPtr `plusPtr` sizeTagLen) ptLen
            sendBytesPtr sock pktPtr pktLen
            return oCtx'
  where
    ctLen  = encBytes ptLen
    pktLen = sizeTagLen + ctLen

-- |Blocks until it receives a valid message, placing the resulting plaintext
-- in the provided buffer.  If the incoming message is larger than the
-- provided buffer then only the first @maxLen@ bytes are copied out (the
-- message is truncated); that path also incurs an additional copy.
--
-- The returned 'Int' is the plaintext length reported by 'decodePtr', so a
-- result greater than @maxLen@ tells the caller that truncation happened.
--
-- Packets that fail with one of the 'retryOn' errors are dropped and the
-- receive is retried; any other decode error is thrown.  As in 'recv', a
-- length tag above 2^28 bytes aborts the receive via 'fail'.
recvPtr :: Connection -> Ptr Word8 -> Int -> IO Int
recvPtr Conn{..} ptPtr maxLen =
  allocaBytes sizeTagLen $ \szPtr ->
  modifyMVar inCtx       $ \iCtx ->
  go szPtr iCtx

  where

  go szPtr iCtx = do
    recvBytesPtr socket szPtr sizeTagLen
    len <- fromIntegral `fmap` peekBE32 szPtr
    -- Refuse absurd lengths before allocating: the tag comes straight off
    -- the wire and would otherwise drive an allocation of up to 4GB.
    when (len > 2^28)
      (fail "recvPtr: A message is over 256MB! Probably corrupt data or the stream is unsynchronized.")
    -- The size tag holds the ciphertext length only ('sendPtr' writes
    -- @encBytes ptLen@ there), so the plaintext bound must be derived from
    -- the whole of @len@.  Subtracting 'sizeTagLen' first under-sized the
    -- bound and let 'decodePtr' write past the end of the target buffer.
    let ptMaxSize = decBytes len
    (nextCtx, mbOutLen) <- allocaBytes len $ \ctPtr -> do
      recvBytesPtr socket ctPtr len
      let finish pointer = do
              dRes <- decodePtr iCtx ctPtr pointer len
              case dRes of
                Left err
                  | err `elem` retryOn -> return (iCtx, Nothing) -- preserve old context
                  | otherwise          -> throw err
                Right (resLen,i2) -> return (i2, Just resLen)
      if ptMaxSize > maxLen
        -- Caller's buffer is too small: decode into scratch space and copy
        -- out only what fits.
        then allocaBytes ptMaxSize $ \tmp ->
               do res <- finish tmp
                  when (isJust (snd res)) -- don't bother copying in the retry case
                       (B.memcpy ptPtr tmp maxLen)
                  return res
        else finish ptPtr
    case mbOutLen of
      Nothing     -> go szPtr nextCtx -- retry
      Just outLen -> return (nextCtx, outLen)

-- Receive exactly sz bytes of ciphertext from the socket and decode them
-- into ptPtr; helper for 'recv'.  On a decode error the context that was
-- passed in is handed back unchanged alongside the error; on success the
-- updated context is returned with the decoded length.
recvPtrOfSz :: Socket -> InContext -> Ptr Word8 -> Int -> IO (InContext, Either CommSecError Int)
recvPtrOfSz sock ctx ptPtr sz =
    allocaBytes sz $ \ctPtr -> do
        recvBytesPtr sock ctPtr sz
        either keepOld advance `fmap` decodePtr ctx ctPtr ptPtr sz
  where
    keepOld err       = (ctx , Left err)
    advance (n, ctx') = (ctx', Right n)

-- Retry until we have received exactly the specified number of bytes.
--
-- 'recvBuf' may deliver fewer bytes than requested, so the read is repeated
-- on the remaining tail of the buffer.  A non-positive return from 'recvBuf'
-- is treated as end-of-stream and raised as an EOF 'IOError'; without this
-- check a closed peer would make the loop call 'recvBuf' forever with an
-- unchanged length.
recvBytesPtr :: Socket -> Ptr Word8 -> Int -> IO ()
recvBytesPtr _ _ 0 = return ()
recvBytesPtr s p l = do
        nr <- recvBuf s p l
        when (nr <= 0) $
            ioError (mkIOError eofErrorType "Network.CommSec.recvBytesPtr" Nothing Nothing)
        recvBytesPtr s (p `plusPtr` nr) (l - nr)

-- Keep calling 'sendBuf' on the unsent tail of the buffer until exactly the
-- requested number of bytes has been written.
sendBytesPtr :: Socket -> Ptr Word8 -> Int -> IO ()
sendBytesPtr sock = loop
  where
    loop _   0         = return ()
    loop buf remaining = do
        sent <- sendBuf sock buf remaining
        loop (buf `plusPtr` sent) (remaining - sent)

-- | Expand input entropy (which should be at least 16 bytes long) into @sz@
-- bytes by running AES-128 in counter mode, keyed from the entropy and
-- starting at the zero IV, over a buffer of @sz@ zero bytes.
--
-- Calls 'error' if no key can be built from the entropy.
expandSecret :: B.ByteString -> Int -> B.ByteString
expandSecret entropy sz =
    case (buildKey entropy :: Maybe AESKey128) of
        Just key -> fst (ctr key zeroIV zeros)
        Nothing  -> error "Build key failed"
  where
    zeros = B.replicate sz 0

-- |Wait for one incoming connection on the given port and secure it with
-- two keys expanded from the provided secret, which must be at least
-- 128 bits (16 bytes) long.
--
-- ex: accept ent 3134
accept  :: B.ByteString -> PortNumber -> IO Connection
accept secret port = doAccept newMVar secret port

-- Listen on the given port on all IPv4 interfaces, accept a single
-- connection and wrap it in a 'Connection'.
--
-- The secret is expanded to 64 bytes: the first 32 key the incoming
-- context, the last 32 the outgoing context (the mirror image of
-- 'doConnect').  @create@ builds the 'MVar's that hold the two contexts.
-- A secret shorter than 16 bytes is rejected with 'error'.
--
-- The listening socket is always closed before returning, including when
-- any step throws, and the accepted socket is closed if its setup fails, so
-- no descriptor is leaked on an error path.
doAccept  :: (forall x. x -> IO (MVar x)) -> B.ByteString -> PortNumber -> IO Connection
doAccept create s p
  | B.length s < 16 = error "Invalid input entropy"
  | otherwise = do
    let ent   = expandSecret s 64
        k1    = B.take 32 ent
        k2    = B.drop 32 ent
        iCtx  = newInContext k1 Sequential
        oCtx  = newOutContext k2
        sockaddr = Net.SockAddrInet p Net.iNADDR_ANY
    bracket (Net.socket Net.AF_INET Net.Stream Net.defaultProtocol) Net.close $ \sock -> do
      Net.setSocketOption sock Net.ReuseAddr 1
      Net.bind sock sockaddr
      Net.listen sock 10
      -- Only close the accepted socket if finishing the setup fails; on
      -- success ownership passes to the returned Connection.
      bracketOnError (Net.accept sock) (Net.close . fst) $ \(socket,socketAddr) -> do
        Net.setSocketOption socket Net.NoDelay 1
        inCtx  <- create iCtx
        outCtx <- create oCtx
        return (Conn {..})

-- Resolve the host, open a TCP connection to it and wrap the socket in a
-- 'Connection'.
--
-- The secret is expanded to 64 bytes: the first 32 key the outgoing
-- context, the last 32 the incoming context (the mirror image of
-- 'doAccept').  @create@ builds the 'MVar's that hold the two contexts.
-- A secret shorter than 16 bytes is rejected with 'error'.
--
-- If connecting or configuring the socket throws, the socket is closed
-- before the exception propagates, so no descriptor is leaked.
doConnect  :: (forall x. x -> IO (MVar x)) -> B.ByteString -> HostName -> PortNumber -> IO Connection
doConnect create s hn p
  | B.length s < 16 = error "Invalid input entropy"
  | otherwise = do
    socketAddr <- resolve hn p
    let ent  = expandSecret s 64
        k2   = B.take 32 ent
        k1   = B.drop 32 ent
        iCtx = newInContext k1 Sequential
        oCtx = newOutContext k2
    -- On success ownership of the socket passes to the returned Connection;
    -- it is only closed here when one of the steps below fails.
    bracketOnError (Net.socket Net.AF_INET Net.Stream Net.defaultProtocol) Net.close $ \socket -> do
      Net.connect socket socketAddr
      Net.setSocketOption socket Net.NoDelay 1
      Net.setSocketOption socket Net.ReuseAddr 1
      inCtx  <- create iCtx
      outCtx <- create oCtx
      return (Conn {..})
  where
      -- Look up the first IPv4 stream address for the host/port pair.
      resolve :: HostName -> PortNumber -> IO SockAddr
      resolve h port = do
        ai <- getAddrInfo (Just $ defaultHints { addrFamily = Net.AF_INET, addrSocketType = Stream } ) (Just h) (Just (show port))
        return (maybe (error $ "Could not resolve host " ++ h) addrAddress (listToMaybe ai))

-- |Open a connection to the given host and port and secure it with two keys
-- expanded from the provided secret, which must be at least 128 bits
-- (16 bytes) long.
connect :: B.ByteString
        -> HostName
        -> PortNumber
        -> IO Connection
connect secret host port = doConnect newMVar secret host port

-- |Close a connection by closing its underlying socket.
close :: Connection -> IO ()
close = Net.close . socket

-- |Number of bytes in the length prefix of every packet: the size of a
-- datagram is transmitted as a single Word32.
sizeTagLen :: Int
sizeTagLen = 4