module Network.CommSec.KeyExchange
( connect
, accept
, keyExchangeInit, keyExchangeResp
, CS.send, CS.recv, CS.Connection, Net.HostName, Net.PortNumber
) where
import Control.Concurrent
import Control.Exception (bracket, bracketOnError, onException)
import Control.Monad
import Data.Maybe (listToMaybe)
import Foreign.Storable

import qualified Codec.Crypto.RSA as RSA
import Control.Monad.CryptoRandom
import Crypto.Cipher.AES128
import Crypto.Classes
import Crypto.Hash.CryptoAPI
import Crypto.Modes (zeroIV)
import Crypto.Random.DRBG
import Crypto.Types.PubKey.RSA
import Crypto.Util
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Lazy (fromStrict, toChunks)
import qualified Data.ByteString.Lazy as L
import Data.Serialize
import Data.Serialize.Get
import Data.Serialize.Put
import qualified Network.Socket as Net
import qualified Network.Socket.ByteString as NetBS

import qualified Network.CommSec as CS
import Network.CommSec hiding (accept, connect)
import Network.CommSec.Package (InContext(..), OutContext(..))
-- | Modulus @p@ of the Diffie-Hellman group used by 'getXaX',
-- 'keyExchangeInit' and 'keyExchangeResp'.  Shared secrets are serialized as
-- 2048 / 8 bytes by the key-exchange functions.
--
-- NOTE(review): this value appears to be the 2048-bit prime from RFC 5114
-- section 2.3.  That RFC publishes a different generator than the 5 used in
-- 'theGenerator'; confirm the group/generator pairing and the order of 5
-- modulo this prime.
thePrime :: Integer
thePrime = 0x87A8E61DB4B6663CFFBBD19C651959998CEEF608660DD0F25D2CEED4435E3B00E00DF8F1D61957D4FAF7DF4561B2AA3016C3D91134096FAA3BF4296D830E9A7C209E0C6497517ABD5A8A9D306BCF67ED91F9E6725B4758C022E0B1EF4275BF7B6C5BFC11D45F9088B941F54EB1E59BB8BC39A0BF12307F5C4FDB70C581B23F76B63ACAE1CAA6B7902D52526735488A0EF13C6D9A51BFA4AB3AD8347796524D8EF6A167B5A41825D967E144E5140564251CCACB83E6B486F6B3CA3F7971506026C0B857F689962856DED4010ABD0BE621C3A3960A54E710C375F26375D7014103A4B54330C198AF126116D2276E11715F693877FAD7EF09CADB094AE91E1A1597
-- | Generator @g@ raised to the secret exponent in 'getXaX'
-- (public value = @g^x mod thePrime@).
theGenerator :: Integer
theGenerator = 5
-- | RSA-sign the serialized pair of Diffie-Hellman public values @(a, b)@
-- (see 'encodeExps') and return the signature as a strict 'ByteString'.
signExps :: Integer -> Integer -> PrivateKey -> ByteString
signExps a b k =
  let signature = RSA.sign k (encodeExps a b)
  in B.concat (toChunks signature)
-- | Check an RSA signature (as produced by 'signExps') over the serialized
-- pair @(a, b)@ against the given public key.
verifyExps :: Integer -> Integer -> ByteString -> PublicKey -> Bool
verifyExps a b sig k = RSA.verify k message (L.fromChunks [sig])
  where
    message = encodeExps a b
-- | Canonical byte encoding of two exponents for signing: the cereal
-- encoding of @a@ immediately followed by that of @b@.
encodeExps :: Integer -> Integer -> L.ByteString
encodeExps a b = fromStrict (runPut (mapM_ put [a, b]))
-- | Generate a fresh Diffie-Hellman key pair @(x, g^x mod p)@.
--
-- The secret exponent @x@ is drawn with 'crandomR' over the bounds
-- @(1, p - 2)@ from a newly created 'HmacDRBG'.  The generator state is
-- discarded after this single use.  Calls 'error' (via 'throwLeft') if the
-- DRBG reports a failure.
getXaX :: IO (Integer, Integer)
getXaX = do
  g <- newGenIO :: IO HmacDRBG
  let (x, _) = throwLeft $ crandomR (1, thePrime - 2) g
      ax     = modexp theGenerator x thePrime
  return (x, ax)
-- | Responder side of the authenticated Diffie-Hellman key exchange.
--
-- Protocol as implemented here (every message framed by 'sendMsg' /
-- 'recvMsg'):
--
--   1. receive the initiator's public value @ax@ (a cereal-encoded 'Integer');
--   2. send our public value @ay@ together with our RSA signature over
--      @(ay, ax)@, CTR-encrypted under the first derived AES key;
--   3. receive the initiator's signature over @(ax, ay)@, CTR-encrypted under
--      the second derived AES key, and verify it against @publicThem@.
--
-- Returns the outgoing and incoming contexts for the data channel.  Calls
-- 'error' if the peer's public value is out of range or its signature does
-- not verify.
keyExchangeResp :: Net.Socket -> PublicKey -> PrivateKey -> IO (OutContext, InContext)
keyExchangeResp sock publicThem privateMe = do
  (y, ay) <- getXaX
  ax <- (either error id . decode) `fmap` recvMsg sock
  -- Public-key validation: reject 0, 1, p-1 and anything outside [2, p-2]
  -- before the value is used to derive key material.
  when (ax < 2 || ax > thePrime - 2)
    (error "RESP: Received an out-of-range Diffie-Hellman public value.")
  let axy = modexp ax y thePrime
      sharedSecret = encode . sha256 $ i2bs (2048 `div` 8) axy
      -- Request the same 64 bytes as keyExchangeInit so both sides expand the
      -- secret identically; only two AES keys and two salts are consumed.
      shared512 = expandSecret sharedSecret 64
      (aesKey1, aesKey2, salt1, salt2) =
        let (key1tmp, rest1) = B.splitAt (keyLengthBytes `for` aesKey1) shared512
            (key2tmp, rest2) = B.splitAt (keyLengthBytes `for` aesKey2) rest1
            (salt1tmp, rest3) = B.splitAt (sizeOf salt1) rest2
            salt2tmp = B.take (sizeOf salt2) rest3
            op = fromIntegral . bs2i
            bk = maybe (error "failed to build key") id . buildKey
        in (bk key1tmp, bk key2tmp, op salt1tmp, op salt2tmp)
      mySig = signExps ay ax privateMe
      (enc, _) = ctr aesKey1 zeroIV mySig
      outCtx = Out 2 salt1 aesKey1
      inCtx = InStrict 1 salt2 aesKey2
  sendMsg sock (runPut (put ay >> put enc))
  encSaAxAy <- recvMsg sock
  let theirSig = fst $ unCtr aesKey2 zeroIV encSaAxAy
  unless (verifyExps ax ay theirSig publicThem)
    (error "RESP: Verification failed when exchanging key. Man in the middle?")
  return (outCtx, inCtx)
-- | Initiator side of the authenticated Diffie-Hellman key exchange.
--
-- Protocol as implemented here (every message framed by 'sendMsg' /
-- 'recvMsg'):
--
--   1. send our public value @ax@ (a cereal-encoded 'Integer');
--   2. receive the responder's public value @ay@ followed by its RSA signature
--      over @(ay, ax)@, decrypt the signature with the first derived AES key
--      and verify it against @publicThem@;
--   3. send our own signature over @(ax, ay)@, CTR-encrypted under the second
--      derived AES key.
--
-- Returns the outgoing and incoming contexts for the data channel.  Calls
-- 'error' if the peer's public value is out of range or its signature does
-- not verify.  In that case our own signature is never sent.
keyExchangeInit :: Net.Socket -> PublicKey -> PrivateKey -> IO (OutContext, InContext)
keyExchangeInit sock publicThem privateMe = do
  (x, ax) <- getXaX
  sendMsg sock (encode ax)
  pkg <- recvMsg sock
  let (ay, encSbAyAx) = either error id (decodePkg pkg)
      -- Responder's message: its public value, then the encrypted signature.
      decodePkg = runGet (liftM2 (,) get get)
      axy = modexp ay x thePrime :: Integer
      sharedSecret = encode . sha256 $ i2bs (2048 `div` 8) axy
      -- Both peers must expand the secret identically; only two AES keys
      -- and two salts are consumed from the expansion.
      shared512 = expandSecret sharedSecret 64
      (aesKey1, aesKey2, salt1, salt2) =
        let (key1tmp, rest1) = B.splitAt (keyLengthBytes `for` aesKey1) shared512
            (key2tmp, rest2) = B.splitAt (keyLengthBytes `for` aesKey2) rest1
            (salt1tmp, rest3) = B.splitAt (sizeOf salt1) rest2
            salt2tmp = B.take (sizeOf salt2) rest3
            op = fromIntegral . bs2i
            bk = maybe (error "failed to build key") id . buildKey
        in (bk key1tmp, bk key2tmp, op salt1tmp, op salt2tmp)
      mySig = signExps ax ay privateMe
      (enc, _) = ctr aesKey2 zeroIV mySig
      outCtx = Out 2 salt2 aesKey2
      inCtx = InStrict 1 salt1 aesKey1
      theirSig = fst $ unCtr aesKey1 zeroIV encSbAyAx
  -- Public-key validation.  The bindings above are lazy, so this check runs
  -- before any key material is derived from ay.
  when (ay < 2 || ay > thePrime - 2)
    (error "INIT: Received an out-of-range Diffie-Hellman public value.")
  unless (verifyExps ay ax theirSig publicThem)
    (error "INIT: Verification failed when exchanging key. Man in the middle?")
  sendMsg sock enc
  return (outCtx, inCtx)
-- | Connect to @host:port@ over IPv4 TCP and run the initiator side of the
-- key exchange ('keyExchangeInit').  @them@ is used to authenticate the peer
-- and @us@ to sign our half of the exchange.
--
-- If connecting or the key exchange throws, the socket is closed before the
-- exception propagates, so a failed handshake does not leak a descriptor.
connect :: Net.HostName -> Net.PortNumber -> PublicKey -> PrivateKey -> IO Connection
connect host port them us = do
  sockaddr <- resolve host port
  bracketOnError (Net.socket Net.AF_INET Net.Stream Net.defaultProtocol) Net.close $ \socket -> do
    Net.connect socket sockaddr
    Net.setSocketOption socket Net.NoDelay 1
    Net.setSocketOption socket Net.ReuseAddr 1
    (oCtx, iCtx) <- keyExchangeInit socket them us
    inCtx <- newMVar iCtx
    outCtx <- newMVar oCtx
    return (Conn {..})
  where
    -- Resolve the host to the address of its first IPv4 stream result.
    resolve :: Net.HostName -> Net.PortNumber -> IO Net.SockAddr
    resolve h p = do
      ai <- Net.getAddrInfo (Just $ Net.defaultHints
                               { Net.addrFamily = Net.AF_INET
                               , Net.addrSocketType = Net.Stream })
                            (Just h) (Just (show p))
      return (maybe (error $ "Could not resolve host " ++ h) Net.addrAddress (listToMaybe ai))
-- | Listen on @port@ on all IPv4 interfaces and accept a single connection.
-- The listening socket is then closed and the responder side of the key
-- exchange ('keyExchangeResp') runs on the accepted socket.  @them@ is used
-- to authenticate the peer and @us@ to sign our half of the exchange.
--
-- The listening socket is closed whether or not bind/listen/accept succeed.
-- The accepted socket is closed if the key exchange throws.
accept :: Net.PortNumber -> PublicKey -> PrivateKey -> IO Connection
accept port them us = do
  let sockaddr = Net.SockAddrInet port Net.iNADDR_ANY
  socket <- bracket (Net.socket Net.AF_INET Net.Stream Net.defaultProtocol) Net.close $ \sock -> do
    Net.setSocketOption sock Net.ReuseAddr 1
    Net.bind sock sockaddr
    Net.listen sock 1
    fst `fmap` Net.accept sock
  let handshake = do
        Net.setSocketOption socket Net.NoDelay 1
        (oCtx, iCtx) <- keyExchangeResp socket them us
        outCtx <- newMVar oCtx
        inCtx <- newMVar iCtx
        return (Conn {..})
  handshake `onException` Net.close socket
-- | Receive one framed message: a 4-byte big-endian length header followed
-- by exactly that many payload bytes.  Calls 'error' with the decoder's
-- message if the header cannot be parsed.
recvMsg :: Net.Socket -> IO ByteString
recvMsg s = do
  header <- recvAll s 4
  case runGet getWord32be header of
    Left err  -> error err
    Right len -> recvAll s (fromIntegral len)
-- | Read exactly @nr@ bytes from the socket, blocking until they have all
-- arrived, and return them concatenated in arrival order.  A non-positive
-- @nr@ yields an empty 'ByteString'.
--
-- Throws an 'IOError' if a read returns no data (the peer closed the
-- connection) before @nr@ bytes have been received.
recvAll :: Net.Socket -> Int -> IO ByteString
recvAll s nr = go nr []
  where
    go n acc
      | n <= 0 = return $ B.concat (reverse acc)
      | otherwise = do
          bs <- NetBS.recv s n
          -- An empty read signals end-of-stream; without this check the loop
          -- would spin forever making no progress.
          if B.null bs
            then ioError (userError ("recvAll: connection closed with "
                                     ++ show n ++ " of " ++ show nr
                                     ++ " bytes still expected"))
            else go (n - B.length bs) (bs : acc)
-- | Send one framed message: a 4-byte big-endian length prefix followed by
-- the payload, written in full with 'NetBS.sendAll'.
sendMsg :: Net.Socket -> ByteString -> IO ()
sendMsg s msg = NetBS.sendAll s (B.concat [lengthPrefix, msg])
  where
    lengthPrefix = runPut (putWord32be (fromIntegral (B.length msg)))
-- | The cipher's key length converted from bits to bytes, rounding up.
keyLengthBytes = fmap (\bits -> (bits + 7) `div` 8) keyLength
-- | SHA-256 digest of a strict 'ByteString' (crypto-api 'hash'').
sha256 :: ByteString -> SHA256
sha256 = hash'
-- | Modular exponentiation @b^e mod n@ by right-to-left binary
-- square-and-multiply.
--
-- The result always lies in @[0, n)@.  Calls 'error' for a negative
-- exponent (which previously never terminated) or a non-positive modulus.
modexp :: Integer -> Integer -> Integer -> Integer
modexp b e n
  | e < 0 = error "modexp: negative exponent"
  | n <= 0 = error "modexp: modulus must be positive"
  | otherwise = go (1 `mod` n) (b `mod` n) e
  where
    go !acc _ 0 = acc
    go !acc !base !k
      | even k = go acc (base * base `mod` n) (k `div` 2)
      | otherwise = go (acc * base `mod` n) base (k - 1)