module Network.CommSec
(
Connection(..)
, CommSecError(..)
, send, recv
, sendPtr, recvPtr
, accept
, connect
) where
import Crypto.Classes (buildKey)
import Crypto.Cipher.AES128.Internal (encryptCTR)
import Crypto.Cipher.AES128 (AESKey)
import Network.CommSec.Package
import Network.CommSec.Types
import Network.Socket ( Socket, SocketType(..), SockAddr, AddrInfo(..)
, defaultHints , getAddrInfo, sendBuf, addrAddress
, recvBuf, HostName, PortNumber)
import qualified Network.Socket as Net
import Control.Concurrent.MVar
import Control.Exception (throw)
import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Internal as B
import Foreign.Ptr
import Foreign.Marshal.Alloc
import Data.Word
import Data.Maybe (listToMaybe)
data Connection
= Conn { inCtx :: MVar InContext
, outCtx :: MVar OutContext
, socket :: Socket
}
pMVar :: MVar v -> v -> IO ()
pMVar m v = v `seq` putMVar m v
send :: Connection -> B.ByteString -> IO ()
send = sendWith takeMVar pMVar
recv :: Connection -> IO B.ByteString
recv = recvWith takeMVar pMVar
sendPtr :: Connection -> Ptr Word8 -> Int -> IO ()
sendPtr = sendPtrWith takeMVar pMVar
recvPtr :: Connection -> Ptr Word8 -> Int -> IO Int
recvPtr = recvPtrWith takeMVar pMVar
sendWith :: (MVar OutContext -> IO OutContext) -> (MVar OutContext -> OutContext -> IO ()) -> Connection -> B.ByteString -> IO ()
sendWith get put conn msg = B.useAsCStringLen msg $ \(ptPtr, ptLen) ->
sendPtrWith get put conn (castPtr ptPtr) ptLen
data RecvRes = Good | Small | Err deriving (Eq)
recvWith :: (MVar InContext -> IO InContext) -> (MVar InContext -> InContext -> IO ()) -> Connection -> IO B.ByteString
recvWith get put conn@(Conn {..}) = allocGo baseSize
where
baseSize = 2048
allocGo :: Int -> IO B.ByteString
allocGo n = allocaBytes sizeTagLen (go n)
go :: Int -> Ptr Word8 -> IO B.ByteString
go sz tmpPtr
| sz > 2^28 = error "recvWith: A message is over 256MB! Probably corrupt data or the stream is unsyncronized."
| otherwise = do
recvBytesPtr socket tmpPtr sizeTagLen
sz <- fromIntegral `fmap` peekBE32 tmpPtr
(b, res) <- B.createAndTrim' sz $ \ptPtr -> do
resSz <- recvPtrOfSz get put conn ptPtr sz
case resSz of
Left err -> if err `elem` retryOn then return (0,0,Err)
else throw err
Right s ->
if s > sz
then return (0,0,Small)
else return (0,s,Good)
case res of
Good -> return b
Small -> go (sz * 2) tmpPtr
Err -> go sz tmpPtr
retryOn :: [CommSecError]
retryOn = [DuplicateSeq, InvalidICV, BadPadding]
sendPtrWith :: (MVar OutContext -> IO OutContext) -> (MVar OutContext -> OutContext -> IO ()) -> Connection -> Ptr Word8 -> Int -> IO ()
sendPtrWith get put c@(Conn {..}) ptPtr ptLen = do
let ctLen = encBytes ptLen
pktLen = sizeTagLen + ctLen
allocaBytes pktLen $ \pktPtr -> do
let ctPtr = pktPtr `plusPtr` sizeTagLen
pokeBE32 pktPtr (fromIntegral ctLen)
o <- get outCtx
o2 <- encodePtr o ptPtr ctPtr ptLen
put outCtx o2
sendBytesPtr socket pktPtr pktLen
return ()
recvPtrWith :: (MVar InContext -> IO InContext) -> (MVar InContext -> InContext -> IO ()) -> Connection -> Ptr Word8 -> Int -> IO Int
recvPtrWith get put c@(Conn{..}) ptPtr maxLen = do
r <- go
case r of
Nothing -> recvPtrWith get put c ptPtr maxLen
Just res -> return res
where
go :: IO (Maybe Int)
go = allocaBytes sizeTagLen $ \szPtr -> do
recvBytesPtr socket szPtr sizeTagLen
len <- fromIntegral `fmap` peekBE32 szPtr
let ptMaxSize = decBytes (len sizeTagLen)
allocaBytes len $ \ctPtr -> do
recvBytesPtr socket ctPtr len
i <- get inCtx
let finish pointer = do
dRes <- decodePtr i ctPtr pointer len
case dRes of
Left err -> if err `elem` retryOn then return Nothing
else throw err
Right (resLen,i2) -> put inCtx i2 >> return (Just resLen)
if ptMaxSize > maxLen
then allocaBytes ptMaxSize (\tmp -> do
res <- finish tmp
B.memcpy ptPtr tmp maxLen
return res)
else finish ptPtr
recvPtrOfSz :: (MVar InContext -> IO InContext) -> (MVar InContext -> InContext -> IO ()) -> Connection -> Ptr Word8 -> Int -> IO (Either CommSecError Int)
recvPtrOfSz get put (Conn {..}) ptPtr sz =
allocaBytes sz $ \ct -> do
recvBytesPtr socket ct sz
i <- get inCtx
dRes <- decodePtr i ct ptPtr sz
case dRes of
Left err -> return (Left err)
Right (resLen,i2) -> put inCtx i2 >> return (Right resLen)
recvBytesPtr :: Socket -> Ptr Word8 -> Int -> IO ()
recvBytesPtr s p 0 = return ()
recvBytesPtr s p l = do
nr <- recvBuf s p l
recvBytesPtr s (p `plusPtr` nr) (l nr)
sendBytesPtr :: Socket -> Ptr Word8 -> Int -> IO ()
sendBytesPtr s p 0 = return ()
sendBytesPtr s p l = do
nr <- sendBuf s p l
sendBytesPtr s (p `plusPtr` nr) (l nr)
expandSecret :: B.ByteString -> Int -> B.ByteString
expandSecret entropy sz =
let k = buildKey entropy
in case k of
Nothing -> error "Build key failed"
Just key ->
let iv = B.replicate 16 0
in enc key iv input
where
input = B.replicate sz 0
enc :: AESKey -> B.ByteString -> B.ByteString -> B.ByteString
enc k i pt = B.unsafeCreate sz $ \ctPtr ->
B.useAsCString pt $ \ptPtr ->
B.useAsCString i $ \iv ->
encryptCTR k (castPtr iv) nullPtr (castPtr ctPtr) (castPtr ptPtr) sz
accept :: B.ByteString -> PortNumber -> IO Connection
accept = doAccept newMVar
doAccept :: (forall x. x -> IO (MVar x)) -> B.ByteString -> PortNumber -> IO Connection
doAccept create s p
| B.length s < 16 = error "Invalid input entropy"
| otherwise = do
let ent = expandSecret s 64
k1 = B.take 32 ent
k2 = B.drop 32 ent
iCtx = newInContext k1 Sequential
oCtx = newOutContext k2
sockaddr = Net.SockAddrInet p Net.iNADDR_ANY
sock <- Net.socket Net.AF_INET Net.Stream Net.defaultProtocol
Net.setSocketOption sock Net.ReuseAddr 1
Net.bind sock sockaddr
Net.listen sock 10
socket <- fst `fmap` Net.accept sock
Net.setSocketOption socket Net.NoDelay 1
Net.close sock
inCtx <- create iCtx
outCtx <- create oCtx
return (Conn {..})
doConnect :: (forall x. x -> IO (MVar x)) -> B.ByteString -> HostName -> PortNumber -> IO Connection
doConnect create s hn p
| B.length s < 16 = error "Invalid input entropy"
| otherwise = do
sockaddr <- resolve hn p
let ent = expandSecret s 64
k2 = B.take 32 ent
k1 = B.drop 32 ent
iCtx = newInContext k1 Sequential
oCtx = newOutContext k2
socket <- Net.socket Net.AF_INET Net.Stream Net.defaultProtocol
Net.connect socket sockaddr
Net.setSocketOption socket Net.NoDelay 1
Net.setSocketOption socket Net.ReuseAddr 1
inCtx <- create iCtx
outCtx <- create oCtx
return (Conn {..})
where
resolve :: HostName -> PortNumber -> IO SockAddr
resolve h port = do
ai <- getAddrInfo (Just $ defaultHints { addrFamily = Net.AF_INET, addrSocketType = Stream } ) (Just h) (Just (show port))
return (maybe (error $ "Could not resolve host " ++ h) addrAddress (listToMaybe ai))
connect :: B.ByteString
-> HostName
-> PortNumber
-> IO Connection
connect = doConnect newMVar
sizeTagLen :: Int
sizeTagLen = 4