{-# LANGUAGE RecordWildCards, RankNTypes #-}
module Network.CommSec
    (
    -- * Types
      Connection(..)
    , CommSecError(..)
    -- * Send and receive operations
    , send, recv
    , sendPtr, recvPtr
    -- * Establishing a connection from a shared secret
    , accept
    , connect
    -- * Establishing a connection from a public identity (PKI)
    -- , acceptId
    -- , connectId
    ) where

import Control.Concurrent.MVar
import Control.Exception (bracket, finally, onException, throw, throwIO)
import Control.Monad
import Data.Maybe (listToMaybe)
import Data.Word
import Foreign.Marshal.Alloc
import Foreign.Ptr
import System.IO.Error (eofErrorType, mkIOError)

import Crypto.Classes (buildKey)
import Crypto.Cipher.AES128 (AESKey)
import Crypto.Cipher.AES128.Internal (encryptCTR)
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe as B
import Network.Socket ( Socket, SocketType(..), SockAddr, AddrInfo(..)
                      , defaultHints , getAddrInfo, sendBuf, addrAddress
                      , recvBuf, HostName, PortNumber)
import qualified Network.Socket as Net

import Network.CommSec.Package
import Network.CommSec.Types

-- | A connection is a secure bidirectional communication channel.
--
-- Each direction has its own cryptographic context stored in an 'MVar'.
-- The send/receive helpers take a context out, use it, and put the
-- updated context back.  The fields are strict, so a 'Connection' never
-- stores an unevaluated thunk in place of its MVars or socket.
data Connection
            = Conn { inCtx        :: !(MVar InContext)   -- ^ Context used to decode incoming packets.
                   , outCtx       :: !(MVar OutContext)  -- ^ Context used to encode outgoing packets.
                   , socket       :: !Socket             -- ^ The underlying stream socket.
                   }

-- | Store a value in an 'MVar' after forcing it to weak head normal form.
-- This keeps the 'MVar' from ever holding an unevaluated thunk at the top level.
pMVar :: MVar v -> v -> IO ()
pMVar var x = putMVar var $! x

-- | Encrypt a strict 'B.ByteString' and send it over the connection as a
-- single message.
send :: Connection -> B.ByteString -> IO ()
send conn msg = sendWith takeMVar pMVar conn msg

-- | Block until a message arrives that decodes successfully, then return
-- its plaintext.
recv :: Connection -> IO B.ByteString
recv conn = recvWith takeMVar pMVar conn

-- | Encrypt the given number of bytes starting at the pointer and send them
-- over the connection as a single message.
sendPtr :: Connection -> Ptr Word8 -> Int -> IO ()
sendPtr conn buf len = sendPtrWith takeMVar pMVar conn buf len

-- | Block until a message arrives that decodes successfully, and place its
-- plaintext in the provided buffer, whose size is given by the last argument.
--
-- Messages that fail with one of the errors in 'retryOn' are skipped.  If
-- the plaintext might not fit, it is first decoded into a temporary buffer
-- and then copied, truncated, into the caller's buffer.  That path costs one
-- additional copy.  The returned 'Int' is the length reported by
-- 'recvPtrWith'.
recvPtr :: Connection -> Ptr Word8 -> Int -> IO Int
recvPtr conn buf maxLen = recvPtrWith takeMVar pMVar conn buf maxLen

-- | Shared implementation of 'send'.  It exposes the 'B.ByteString' contents
-- as a pointer/length pair and delegates to 'sendPtrWith'.  @get@ and @put@
-- take and return the outgoing context.
sendWith :: (MVar OutContext -> IO OutContext) -> (MVar OutContext -> OutContext -> IO ()) -> Connection -> B.ByteString -> IO ()
sendWith get put conn msg =
  B.useAsCStringLen msg sendRaw
 where
  sendRaw (cstr, len) = sendPtrWith get put conn (castPtr cstr) len
{-# INLINE sendWith #-}

-- | Outcome of a single decode attempt inside 'recvWith'.
data RecvRes
    = Good   -- ^ The frame decoded and its plaintext fit in the buffer.
    | Small  -- ^ The decoder reported more plaintext than the buffer holds.
    | Err    -- ^ Decoding failed with a retryable error (see 'retryOn').
    deriving (Eq)

-- | Shared implementation of 'recv'.  Reads one length-prefixed frame at a
-- time and decodes it:
--
-- * First comes the 'sizeTagLen'-byte ciphertext length, read with
--   'peekBE32'.  Then comes the ciphertext itself, via 'recvPtrOfSz'.
--
-- * The plaintext is decoded into a fresh 'B.ByteString' of the ciphertext
--   length, which is then trimmed to the decoded length.
--
-- The length prefix arrives unauthenticated.  It is therefore checked
-- against a 256MB limit before being used as an allocation size.
--
-- Frames that fail with an error listed in 'retryOn' are dropped and the
-- next frame is read.  Any other 'CommSecError' is thrown.
recvWith :: (MVar InContext -> IO InContext) -> (MVar InContext -> InContext -> IO ()) -> Connection -> IO B.ByteString
recvWith get put conn@(Conn {..}) = allocaBytes sizeTagLen go
 where
    -- Largest ciphertext length accepted from a size tag (256MB).
    maxMsgLen :: Int
    maxMsgLen = 2 ^ (28 :: Int)

    go :: Ptr Word8 -> IO B.ByteString
    go tagPtr = do
        recvBytesPtr socket tagPtr sizeTagLen
        sz <- fromIntegral `fmap` peekBE32 tagPtr
        -- Validate the *received* length, not a local loop parameter.
        when (sz > maxMsgLen) $
            error "recvWith: A message is over 256MB! Probably corrupt data or the stream is unsyncronized."
        -- The buffer is sized to the ciphertext length.  This presumes the
        -- plaintext is never longer than its ciphertext.
        (b, res) <- B.createAndTrim' sz $ \ptPtr -> do
            resSz <- recvPtrOfSz get put conn ptPtr sz
            case resSz of
                Left err
                    | err `elem` retryOn -> return (0, 0, Err)
                    | otherwise          -> throwIO err
                Right s
                    | s > sz    -> return (0, 0, Small)
                    | otherwise -> return (0, s, Good)
        case res of
            Good  -> return b
            -- NOTE(review): a decoded length above the buffer size should not
            -- happen.  The frame is dropped and the next one is read.
            Small -> go tagPtr
            Err   -> go tagPtr
{-# INLINE recvWith #-}

-- | Decode failures that are treated as a dropped packet rather than a fatal
-- error.  The receive functions discard the offending message and wait for
-- the next one; any other 'CommSecError' is thrown to the caller.
retryOn :: [CommSecError]
retryOn =
    [ DuplicateSeq
    , InvalidICV
    , BadPadding
    ]

-- | Shared implementation of 'sendPtr'.  Encrypts @ptLen@ bytes at @ptPtr@
-- and writes one frame to the socket.  The frame is the ciphertext length
-- (written with 'pokeBE32' in 'sizeTagLen' bytes) followed by the
-- ciphertext.
--
-- The outgoing context is taken with @get@ and handed back with @put@ only
-- after the whole frame has been written.  This stops concurrent senders
-- sharing the 'MVar' from interleaving partial frames on the socket.  It
-- also stops them from transmitting frames in a different order than they
-- were encoded.
sendPtrWith :: (MVar OutContext -> IO OutContext) -> (MVar OutContext -> OutContext -> IO ()) -> Connection -> Ptr Word8 -> Int -> IO ()
sendPtrWith get put (Conn {..}) ptPtr ptLen =
    allocaBytes pktLen $ \pktPtr -> do
        let ctPtr = pktPtr `plusPtr` sizeTagLen
        pokeBE32 pktPtr (fromIntegral ctLen)
        o  <- get outCtx
        -- If encoding fails nothing has been sent, so the old context is
        -- restored and the MVar is not left empty.
        o2 <- encodePtr o ptPtr ctPtr ptLen `onException` put outCtx o
        -- The advanced context is stored even if the write fails.  This
        -- avoids rolling back a context whose packet may already be
        -- (partially) on the wire.
        sendBytesPtr socket pktPtr pktLen `finally` put outCtx o2
  where
    ctLen  = encBytes ptLen
    pktLen = sizeTagLen + ctLen

-- | Shared implementation of 'recvPtr'.  Reads length-prefixed frames until
-- one decodes successfully.  Returns the number of plaintext bytes written
-- to @ptPtr@, which is at most @maxLen@.
--
-- * The size tag holds only the ciphertext length, which is what
--   'sendPtrWith' writes.  The plaintext bound is therefore @decBytes len@.
--   Subtracting 'sizeTagLen' first would under-estimate it and could let
--   the decoder write past the end of the caller's buffer.
-- * When the plaintext may exceed @maxLen@, it is decoded into a temporary
--   buffer and then truncated into the caller's buffer.
-- * The incoming context is always handed back with @put@: the updated
--   context on success, the unchanged one on a decode error or exception.
--   Otherwise a retry after a 'retryOn' error would block forever on an
--   empty 'MVar'.
-- * Frames failing with an error in 'retryOn' are skipped.  Other
--   'CommSecError's are thrown.
recvPtrWith :: (MVar InContext -> IO InContext) -> (MVar InContext -> InContext -> IO ()) -> Connection -> Ptr Word8 -> Int -> IO Int
recvPtrWith get put (Conn {..}) ptPtr maxLen = loop
 where
  loop :: IO Int
  loop = attempt >>= maybe loop return

  -- Receive and try to decode a single frame.  'Nothing' means "skip it".
  attempt :: IO (Maybe Int)
  attempt = allocaBytes sizeTagLen $ \szPtr -> do
    recvBytesPtr socket szPtr sizeTagLen
    len <- fromIntegral `fmap` peekBE32 szPtr
    let ptMaxSize = decBytes len
    allocaBytes len $ \ctPtr -> do
      recvBytesPtr socket ctPtr len
      i <- get inCtx
      let finish pointer = do
            dRes <- decodePtr i ctPtr pointer len `onException` put inCtx i
            case dRes of
              Left err -> do
                put inCtx i
                if err `elem` retryOn then return Nothing else throwIO err
              Right (resLen, i2) -> put inCtx i2 >> return (Just resLen)
      if ptMaxSize > maxLen
          then allocaBytes ptMaxSize $ \tmp -> do
                 res <- finish tmp
                 case res of
                   Nothing     -> return Nothing
                   Just resLen -> do
                     -- Copy only what fits, and report what was copied.
                     let n = min resLen maxLen
                     B.memcpy ptPtr tmp n
                     return (Just n)
          else finish ptPtr

-- | Helper for 'recvWith', which has already consumed the size tag.
-- Receives exactly @sz@ bytes of ciphertext from the socket and decodes them
-- into @ptPtr@; 'recvWith' passes a buffer of @sz@ bytes.
--
-- The incoming context is taken with @get@ and always handed back with
-- @put@: the updated context on success, the unchanged one on a decode
-- error or exception.  'recvWith' retries on some errors, so leaving the
-- 'MVar' empty here would make that retry block forever.
recvPtrOfSz :: (MVar InContext -> IO InContext) -> (MVar InContext -> InContext -> IO ()) -> Connection -> Ptr Word8 -> Int -> IO (Either CommSecError Int)
recvPtrOfSz get put (Conn {..}) ptPtr sz =
    allocaBytes sz $ \ct -> do
        recvBytesPtr socket ct sz
        i <- get inCtx
        dRes <- decodePtr i ct ptPtr sz `onException` put inCtx i
        case dRes of
            Left err           -> put inCtx i >> return (Left err)
            Right (resLen, i2) -> put inCtx i2 >> return (Right resLen)

-- | Read exactly @l@ bytes from the socket into @p@, looping over short
-- reads.
--
-- A zero-byte read means the peer closed the connection; network >= 3.0
-- reports end-of-stream from 'recvBuf' this way.  It is turned into an EOF
-- 'IOError' instead of spinning forever without making progress.
recvBytesPtr :: Socket -> Ptr Word8 -> Int -> IO ()
recvBytesPtr _ _ 0 = return ()
recvBytesPtr s p l = do
    nr <- recvBuf s p l
    when (nr <= 0) $
        ioError (mkIOError eofErrorType "Network.CommSec.recvBytesPtr" Nothing Nothing)
    recvBytesPtr s (p `plusPtr` nr) (l - nr)

-- | Write the given number of bytes, starting at the pointer, to the socket.
-- Calls 'sendBuf' repeatedly until every byte has been accepted.
sendBytesPtr :: Socket -> Ptr Word8 -> Int -> IO ()
sendBytesPtr sock = loop
  where
    loop _   0         = return ()
    loop ptr remaining = do
        sent <- sendBuf sock ptr remaining
        loop (ptr `plusPtr` sent) (remaining - sent)

-- | Expand a secret into @sz@ pseudo-random bytes.  It runs AES in counter
-- mode, with a zero IV, over an all-zero input; the output is the CTR
-- keystream keyed by the secret.  Calls 'error' if 'buildKey' rejects the
-- secret.
--
-- NOTE(review): the callers require at least 16 bytes of secret.  Whether
-- bytes beyond the AES-128 key size influence the output depends on
-- 'buildKey' for 'AESKey' — confirm.
expandSecret :: B.ByteString -> Int -> B.ByteString
expandSecret entropy sz =
    maybe (error "Build key failed") keystream (buildKey entropy)
 where
  zeroIV, zeroes :: B.ByteString
  zeroIV = B.replicate 16 0
  zeroes = B.replicate sz 0

  keystream :: AESKey -> B.ByteString
  keystream key =
    B.unsafeCreate sz $ \outPtr ->
      B.useAsCString zeroes $ \inPtr ->
        B.useAsCString zeroIV $ \ivPtr ->
          encryptCTR key (castPtr ivPtr) nullPtr (castPtr outPtr) (castPtr inPtr) sz

-- | Listen on the given port on all IPv4 interfaces and accept a single
-- incoming connection.  The connection keys come from a shared secret of at
-- least 128 bits (16 bytes), which is expanded into two keys.
--
-- Example: @accept secret 3134@
accept  :: B.ByteString -> PortNumber -> IO Connection
accept secret port = doAccept newMVar secret port

-- | Shared implementation of 'accept'.  Listens on all IPv4 interfaces on
-- the given port, accepts exactly one connection and closes the listening
-- socket.  It then builds the connection contexts from the expanded secret,
-- allocating the context 'MVar's with @create@.
--
-- Key layout of the 64-byte expansion: bytes 0-31 key the incoming context
-- and bytes 32-63 key the outgoing one.  This is the reverse of 'doConnect'.
--
-- The listening socket is managed with 'bracket'.  It is therefore released
-- even when binding, listening or accepting throws.
doAccept  :: (forall x. x -> IO (MVar x)) -> B.ByteString -> PortNumber -> IO Connection
doAccept create s p
  | B.length s < 16 = error "Invalid input entropy"
  | otherwise = do
    let ent   = expandSecret s 64
        k1    = B.take 32 ent
        k2    = B.drop 32 ent
        iCtx  = newInContext k1 Sequential
        oCtx  = newOutContext k2
        sockaddr = Net.SockAddrInet p Net.iNADDR_ANY
    socket <- bracket (Net.socket Net.AF_INET Net.Stream Net.defaultProtocol) Net.close $ \sock -> do
        Net.setSocketOption sock Net.ReuseAddr 1
        Net.bind sock sockaddr
        Net.listen sock 10
        conn <- fst `fmap` Net.accept sock
        -- Do not leak the accepted socket if configuring it fails.
        Net.setSocketOption conn Net.NoDelay 1 `onException` Net.close conn
        return conn
    inCtx  <- create iCtx
    outCtx <- create oCtx
    return (Conn {..})

-- | Shared implementation of 'connect'.  Resolves the host (IPv4 only),
-- opens a stream connection to it and builds the connection contexts from
-- the expanded secret.  The context 'MVar's are allocated with @create@.
--
-- Key layout of the 64-byte expansion: bytes 0-31 key the outgoing context
-- and bytes 32-63 key the incoming one.  This is the reverse of 'doAccept'.
--
-- If connecting or configuring the socket fails, the socket is closed
-- before the exception propagates.
doConnect  :: (forall x. x -> IO (MVar x)) -> B.ByteString -> HostName -> PortNumber -> IO Connection
doConnect create s hn p
  | B.length s < 16 = error "Invalid input entropy"
  | otherwise = do
    sockaddr <- resolve hn p
    let ent  = expandSecret s 64
        k2   = B.take 32 ent
        k1   = B.drop 32 ent
        iCtx = newInContext k1 Sequential
        oCtx = newOutContext k2
    socket <- Net.socket Net.AF_INET Net.Stream Net.defaultProtocol
    flip onException (Net.close socket) $ do
        Net.connect socket sockaddr
        Net.setSocketOption socket Net.NoDelay 1
        Net.setSocketOption socket Net.ReuseAddr 1
    inCtx  <- create iCtx
    outCtx <- create oCtx
    return (Conn {..})
  where
      resolve :: HostName -> PortNumber -> IO SockAddr
      resolve h port = do
        ai <- getAddrInfo (Just $ defaultHints { addrFamily = Net.AF_INET, addrSocketType = Stream } ) (Just h) (Just (show port))
        -- Fail here, inside IO.  Returning a lazy 'error' thunk would only
        -- explode later, when the address is first used.
        maybe (error $ "Could not resolve host " ++ h) (return . addrAddress) (listToMaybe ai)

-- | Connect to the given host and port.  The connection keys come from a
-- shared secret of at least 128 bits (16 bytes), which is expanded into two
-- keys.  The key assignment mirrors 'accept', so a 'connect'/'accept' pair
-- using the same secret can talk to each other.
connect :: B.ByteString
        -> HostName
        -> PortNumber
        -> IO Connection
connect secret host port = doConnect newMVar secret host port

-- | Length in bytes of the size tag that precedes every ciphertext on the
-- stream.  The tag is a 32-bit ciphertext length, written with 'pokeBE32'
-- and read with 'peekBE32'.
sizeTagLen :: Int
sizeTagLen = 4