module Crypto.Noise.Internal.HandshakeState
(
HandshakeCallbacks(..),
HandshakeState,
HandshakeStateParams(..),
SendingCipherState,
ReceivingCipherState,
handshakeState,
runHandshake,
evalHandshakePattern,
evalToken,
encryptPayload,
decryptPayload
) where
import Control.Exception (throw, throwIO)
import Control.Lens hiding (re)
import Control.Monad.Free.Church
import Control.Monad.State
import Data.ByteString (ByteString)
import qualified Data.ByteString as B (empty, splitAt)
import Data.Maybe (fromMaybe, isJust)
import Data.Proxy

import Crypto.Noise.Cipher
import Crypto.Noise.Curve
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Internal.HandshakePattern hiding (s, split)
import Crypto.Noise.Types
-- | Everything 'handshakeState' needs in order to build an initial
--   'HandshakeState'.
data HandshakeStateParams c d =
  HandshakeStateParams
    { hspPattern            :: HandshakePattern c
      -- ^ The handshake pattern to run; its name, pre-actions and actions
      --   are read from it.
    , hspPrologue           :: Plaintext
      -- ^ Prologue bytes, hashed into the symmetric state by
      --   'handshakeState'.
    , hspPreSharedKey       :: Maybe Plaintext
      -- ^ Optional pre-shared key. When present it selects the
      --   @NoisePSK_@ protocol-name prefix and is mixed into the symmetric
      --   state.
    , hspLocalStaticKey     :: Maybe (KeyPair d)
      -- ^ Initial value for the local static key pair, if any.
    , hspLocalEphemeralKey  :: Maybe (KeyPair d)
      -- ^ Initial value for the local ephemeral key pair, if any.
    , hspRemoteStaticKey    :: Maybe (PublicKey d)
      -- ^ Initial value for the remote static public key, if any.
    , hspRemoteEphemeralKey :: Maybe (PublicKey d)
      -- ^ Initial value for the remote ephemeral public key, if any.
    , hspInitiator          :: Bool
      -- ^ 'True' when the local party plays the initiator role.
    }
-- | Mutable state threaded through the evaluation of a handshake.
--
--   Lenses for every field are generated by the 'makeLenses' splice, so each
--   @_hsFoo@ field is accessed elsewhere in this module as @hsFoo@.
data HandshakeState c d h =
  HandshakeState
    { _hsSymmetricState     :: SymmetricState c h
      -- ^ Symmetric state updated by every hash / key-mixing step.
    , _hsLocalStaticKey     :: Maybe (KeyPair d)
      -- ^ Local static key pair, if one was supplied.
    , _hsLocalEphemeralKey  :: Maybe (KeyPair d)
      -- ^ Local ephemeral key pair; set when an @e@ token is sent.
    , _hsRemoteStaticKey    :: Maybe (PublicKey d)
      -- ^ Remote static public key; set when an @s@ token is received.
    , _hsRemoteEphemeralKey :: Maybe (PublicKey d)
      -- ^ Remote ephemeral public key; set when an @e@ token is received.
    , _hsInitiator          :: Bool
      -- ^ 'True' when the local party is the initiator.
    , _hsMsgBuffer          :: ByteString
      -- ^ Scratch buffer: outgoing bytes are appended to it while a message
      --   is being written, and incoming bytes are consumed from its front
      --   while a message is being read.
    , _hsPattern            :: HandshakePattern c
      -- ^ The handshake pattern being executed.
    }
-- | IO hooks through which 'runHandshake' talks to the outside world.
data HandshakeCallbacks =
  HandshakeCallbacks
    { hscbSend       :: ByteString -> IO ()
      -- ^ Invoked with each complete outgoing handshake message.
    , hscbRecv       :: IO ByteString
      -- ^ Invoked to obtain the next incoming handshake message.
    , hscbPayloadIn  :: Plaintext -> IO ()
      -- ^ Invoked with the decrypted payload of each received message.
    , hscbPayloadOut :: IO Plaintext
      -- ^ Invoked to obtain the payload for each message about to be sent.
    }
-- | Monad in which handshake message patterns are interpreted: a
--   'HandshakeState' threaded through 'IO'.
type HandshakeMonad c d h = StateT (HandshakeState c d h) IO
-- | Monad in which pre-message patterns are interpreted: a 'HandshakeState'
--   threaded through 'Identity', so no I/O is available.
type PreMsgMonad c d h = StateT (HandshakeState c d h) Identity
-- | A 'CipherState' designated for sending; produced by 'runHandshake' and
--   consumed by 'encryptPayload'.
newtype SendingCipherState c = SCS { _unSCS :: CipherState c }
-- | A 'CipherState' designated for receiving; produced by 'runHandshake' and
--   consumed by 'decryptPayload'.
newtype ReceivingCipherState c = RCS { _unRCS :: CipherState c }
-- Template Haskell splice that derives the @hs...@ lenses used throughout
-- this module from the underscore-prefixed fields of 'HandshakeState'.
$(makeLenses ''HandshakeState)
-- | Constructs the initial 'HandshakeState' from the supplied parameters.
--
--   The symmetric state is seeded with the protocol name built by 'mkHPN'.
--   The prologue is then hashed in, followed by the pre-shared key when one
--   is given, and finally the pattern's pre-message actions (if any) are
--   evaluated.
handshakeState :: forall c d h. (Cipher c, Curve d, Hash h)
               => HandshakeStateParams c d
               -> HandshakeState c d h
handshakeState HandshakeStateParams{..} =
  case hspPattern ^. hpPreActions of
    Nothing     -> keyed
    Just preMsg -> runIdentity (execStateT (iterM evalPreMsgPattern preMsg) keyed)
  where
    -- 'blank' is passed to 'mkHPN' purely to pin down the c, d and h type
    -- parameters. 'mkHPN' never inspects it, so the circular reference
    -- between 'blank' and 'protocolName' is harmless under laziness.
    protocolName = mkHPN blank (hspPattern ^. hpName) (isJust hspPreSharedKey)

    blank = HandshakeState (symmetricState protocolName)
                           hspLocalStaticKey
                           hspLocalEphemeralKey
                           hspRemoteStaticKey
                           hspRemoteEphemeralKey
                           hspInitiator
                           ""
                           hspPattern

    withPrologue = doPrologue hspPrologue blank

    keyed = case hspPreSharedKey of
      Nothing  -> withPrologue
      Just psk -> doPSK psk withPrologue
-- | Hashes the prologue bytes into the symmetric state of the given
--   handshake state.
doPrologue :: forall c d h. (Cipher c, Curve d, Hash h)
           => Plaintext
           -> HandshakeState c d h
           -> HandshakeState c d h
doPrologue (Plaintext prologue) = over hsSymmetricState (mixHash prologue)
-- | Mixes the pre-shared key bytes into the symmetric state of the given
--   handshake state.
doPSK :: forall c d h. (Cipher c, Curve d, Hash h)
      => Plaintext
      -> HandshakeState c d h
      -> HandshakeState c d h
doPSK (Plaintext key) = over hsSymmetricState (mixPSK key)
-- | Builds the full protocol name:
--   @\<prefix\>\<pattern\>_\<curve\>_\<cipher\>_\<hash\>@, where the prefix
--   is @NoisePSK_@ when the flag is 'True' and @Noise_@ otherwise.
--
--   The first argument is never inspected; it exists only so the curve,
--   cipher and hash types can be resolved.
mkHPN :: forall c d h. (Cipher c, Curve d, Hash h)
      => HandshakeState c d h
      -> ByteString
      -> Bool
      -> ScrubbedBytes
mkHPN _ hpn psk =
  concatSB [ prefix
           , bsToSB' hpn
           , underscore, curvePart
           , underscore, cipherPart
           , underscore, hashPart
           ]
  where
    prefix
      | psk       = bsToSB' "NoisePSK_"
      | otherwise = bsToSB' "Noise_"
    underscore = bsToSB' "_"
    curvePart  = curveName  (Proxy :: Proxy d)
    cipherPart = cipherName (Proxy :: Proxy c)
    hashPart   = hashName   (Proxy :: Proxy h)
-- | Runs the handshake pattern stored in the given state, using the
--   callbacks for all I/O, and returns the resulting pair of cipher states.
--
--   The interpreter yields two cipher states. The initiator sends with the
--   first and receives with the second; the responder uses them the other
--   way round.
runHandshake :: (Cipher c, Curve d, Hash h)
             => HandshakeState c d h
             -> HandshakeCallbacks
             -> IO (SendingCipherState c, ReceivingCipherState c)
runHandshake hs hscb = do
  (cs1, cs2) <- evalStateT interpreter hs
  return $ if hs ^. hsInitiator
             then (SCS cs1, RCS cs2)
             else (SCS cs2, RCS cs1)
  where
    interpreter = iterM (evalHandshakePattern hscb) (hs ^. hsPattern . hpActions)
-- | Interprets one step of a handshake pattern.
--
--   For an @Initiator@ or @Responder@ step the local party either writes a
--   message (when the step's role matches its own) or reads one (otherwise),
--   then continues with the rest of the pattern. For @Split@ the symmetric
--   state is split into the final pair of cipher states.
evalHandshakePattern :: (Cipher c, Curve d, Hash h)
                     => HandshakeCallbacks
                     -> HandshakePatternF (HandshakeMonad c d h (CipherState c, CipherState c))
                     -> HandshakeMonad c d h (CipherState c, CipherState c)
evalHandshakePattern hscb pat = do
  start <- get
  case pat of
    Initiator toks next -> step start True  toks >> next
    Responder toks next -> step start False toks >> next
    Split               -> return (split (start ^. hsSymmetricState))
  where
    -- Dispatch on whether the local party is the sender of this message.
    step start sender toks
      | sender == start ^. hsInitiator = writeMessage sender toks
      | otherwise                      = readMessage start sender toks

    -- Evaluate the tokens (which fill the message buffer), append the
    -- encrypted payload, hand the message to the send callback and clear
    -- the buffer.
    writeMessage sender toks = do
      iterM (evalToken sender) toks
      st      <- get
      payload <- liftIO (hscbPayloadOut hscb)
      let (encPayload, symState) = encryptAndHash payload (st ^. hsSymmetricState)
          outgoing               = (st ^. hsMsgBuffer) `mappend` sbToBS' encPayload
      liftIO (hscbSend hscb outgoing)
      put $ st & hsMsgBuffer      .~ B.empty
               & hsSymmetricState .~ symState

    -- Load the received message into the buffer, let the tokens consume
    -- their share, decrypt whatever is left as the payload, pass it to the
    -- payload callback and clear the buffer.
    readMessage start sender toks = do
      incoming <- liftIO (hscbRecv hscb)
      put (start & hsMsgBuffer .~ incoming)
      iterM (evalToken sender) toks
      st <- get
      let leftover             = st ^. hsMsgBuffer
          (payload, symState)  = decryptAndHash (cipherBytesToText (bsToSB' leftover))
                                                (st ^. hsSymmetricState)
      liftIO (hscbPayloadIn hscb payload)
      put $ st & hsMsgBuffer      .~ B.empty
               & hsSymmetricState .~ symState
-- | Interprets a single handshake token.
--
--   The 'Bool' is the role (@True@ = initiator) of the party that sends the
--   message containing the token. When it equals the local role the token
--   is written (bytes are appended to the message buffer); otherwise it is
--   read (bytes are consumed from the front of the message buffer).
--
--   Throws 'HandshakeStateFailure' (as an IO exception) when an @s@ token is
--   received although a remote static key is already set. Missing local or
--   remote keys surface as 'HandshakeStateFailure' from the key accessors
--   when the affected value is forced.
evalToken :: forall c d h. (Cipher c, Curve d, Hash h)
          => Bool
          -> TokenF (HandshakeMonad c d h ())
          -> HandshakeMonad c d h ()
evalToken i (E next) = do
  hs <- get
  if i == hs ^. hsInitiator
    then do
      -- Sending: generate a fresh ephemeral key pair and emit its public
      -- half in the clear.
      ~kp@(_, pk) <- liftIO curveGenKey
      let pkBytes = curvePubToBytes pk
      put $ hs & hsLocalEphemeralKey .~ Just kp
               & hsSymmetricState    %~ mixEphemeral pkBytes
               & hsMsgBuffer         %~ (`mappend` convert pkBytes)
    else do
      -- Receiving: take one public key's worth of bytes from the buffer.
      -- NOTE(review): 'B.splitAt' returns fewer bytes than requested when
      -- the buffer is short and no length check is made here; confirm that
      -- 'curveBytesToPub' rejects truncated input.
      let (keyBytes, rest) = B.splitAt (curveLength (Proxy :: Proxy d)) $ hs ^. hsMsgBuffer
          reBytes          = convert keyBytes
      put $ hs & hsRemoteEphemeralKey .~ Just (curveBytesToPub reBytes)
               & hsSymmetricState     %~ mixEphemeral reBytes
               & hsMsgBuffer          .~ rest
  next
  where
    -- Hash the ephemeral public key and, when the symmetric state carries a
    -- PSK, also mix it in as key material. The PSK flag is read after
    -- hashing on both the send and the receive path; this assumes 'mixHash'
    -- does not modify the flag -- TODO confirm.
    mixEphemeral keyBytes ss =
      let hashed = mixHash keyBytes ss
      in  if hashed ^. ssHasPSK then mixKey keyBytes hashed else hashed
evalToken i (S next) = do
  hs <- get
  if i == hs ^. hsInitiator
    then do
      -- Sending: encrypt-and-hash the local static public key and append
      -- the result to the message buffer.
      let pk        = curvePubToBytes . snd . getLocalStaticKey $ hs
          (ct, ss') = encryptAndHash ((Plaintext . convert) pk) (hs ^. hsSymmetricState)
      put $ hs & hsSymmetricState .~ ss'
               & hsMsgBuffer      %~ (`mappend` convert ct)
    else case hs ^. hsRemoteStaticKey of
      -- 'throwIO' rather than 'throw': this is an action in an IO-based
      -- monad, so the exception is raised exactly when this step runs.
      Just _ ->
        liftIO . throwIO . HandshakeStateFailure $
          "unable to overwrite remote static key"
      Nothing -> do
        -- Receiving: once a key has been established the static key
        -- arrives encrypted, which adds 16 bytes to its length on the
        -- wire.
        let hasKey        = hs ^. hsSymmetricState . ssHasKey
            len           = curveLength (Proxy :: Proxy d)
            wireLen       = if hasKey then len + 16 else len
            (raw, rest)   = B.splitAt wireLen $ hs ^. hsMsgBuffer
            ct            = cipherBytesToText . convert $ raw
            (Plaintext pt, ss') = decryptAndHash ct (hs ^. hsSymmetricState)
        put $ hs & hsRemoteStaticKey .~ Just (curveBytesToPub pt)
                 & hsSymmetricState  .~ ss'
                 & hsMsgBuffer       .~ rest
  next
evalToken _ (Dhee next) =
  mixDHOutput getLocalEphemeralKey getRemoteEphemeralKey >> next
evalToken i (Dhes next) = do
  hs <- get
  -- The sender combines its ephemeral key with the peer's static key; the
  -- receiver performs the mirror-image operation.
  if i == hs ^. hsInitiator
    then mixDHOutput getLocalEphemeralKey getRemoteStaticKey
    else mixDHOutput getLocalStaticKey getRemoteEphemeralKey
  next
evalToken i (Dhse next) = do
  hs <- get
  -- The sender combines its static key with the peer's ephemeral key; the
  -- receiver performs the mirror-image operation.
  if i == hs ^. hsInitiator
    then mixDHOutput getLocalStaticKey getRemoteEphemeralKey
    else mixDHOutput getLocalEphemeralKey getRemoteStaticKey
  next
evalToken _ (Dhss next) =
  mixDHOutput getLocalStaticKey getRemoteStaticKey >> next

-- | Performs a DH between the secret half of the selected local key pair
--   and the selected remote public key, and mixes the result into the
--   symmetric state as key material.
mixDHOutput :: (Cipher c, Curve d, Hash h)
            => (HandshakeState c d h -> KeyPair d)
            -> (HandshakeState c d h -> PublicKey d)
            -> HandshakeMonad c d h ()
mixDHOutput localPair remoteKey = modify $ \hs ->
  let (sk, _) = localPair hs
      dh      = curveDH sk (remoteKey hs)
  in  hs & hsSymmetricState %~ mixKey dh
-- | Interprets one step of a pre-message pattern. Only @Initiator@ and
--   @Responder@ steps are permitted; anything else is a programming error
--   and calls 'error'.
evalPreMsgPattern :: forall c d h. (Cipher c, Curve d, Hash h)
                  => HandshakePatternF (PreMsgMonad c d h ())
                  -> PreMsgMonad c d h ()
evalPreMsgPattern pat = case pat of
  Initiator toks next -> iterM (evalPreMsgToken True)  toks >> next
  Responder toks next -> iterM (evalPreMsgToken False) toks >> next
  _                   -> error "invalid pre-message pattern operation"
-- | Interprets a single pre-message token by hashing the corresponding
--   public key into the symmetric state.
--
--   The 'Bool' is the role (@True@ = initiator) of the party the token
--   belongs to: when it equals the local role the local key is hashed,
--   otherwise the remote one. Only @e@ and @s@ tokens are permitted; any
--   other token calls 'error'.
evalPreMsgToken :: forall c d h. (Cipher c, Curve d, Hash h)
                => Bool
                -> TokenF (PreMsgMonad c d h ())
                -> PreMsgMonad c d h ()
evalPreMsgToken i tok = case tok of
  E next -> hashPublicKey (snd . getLocalEphemeralKey) getRemoteEphemeralKey >> next
  S next -> hashPublicKey (snd . getLocalStaticKey)    getRemoteStaticKey    >> next
  _      -> error "invalid pre-message pattern token"
  where
    -- Pick our own or the peer's public key depending on whose token this
    -- is, then hash its byte encoding into the symmetric state.
    hashPublicKey :: (HandshakeState c d h -> PublicKey d)
                  -> (HandshakeState c d h -> PublicKey d)
                  -> PreMsgMonad c d h ()
    hashPublicKey ours theirs = modify $ \hs ->
      let pk = if i == hs ^. hsInitiator then ours hs else theirs hs
      in  hs & hsSymmetricState %~ mixHash (curvePubToBytes pk)
-- | Returns the local static key pair, or throws 'HandshakeStateFailure'
--   if none is set.
getLocalStaticKey :: Curve d => HandshakeState c d h -> KeyPair d
getLocalStaticKey hs = case hs ^. hsLocalStaticKey of
  Just kp -> kp
  Nothing -> throw (HandshakeStateFailure "local static key not set")
-- | Returns the local ephemeral key pair, or throws 'HandshakeStateFailure'
--   if none is set.
getLocalEphemeralKey :: Curve d => HandshakeState c d h -> KeyPair d
getLocalEphemeralKey hs = case hs ^. hsLocalEphemeralKey of
  Just kp -> kp
  Nothing -> throw (HandshakeStateFailure "local ephemeral key not set")
-- | Returns the remote static public key, or throws 'HandshakeStateFailure'
--   if none is set.
getRemoteStaticKey :: Curve d => HandshakeState c d h -> PublicKey d
getRemoteStaticKey hs = case hs ^. hsRemoteStaticKey of
  Just pk -> pk
  Nothing -> throw (HandshakeStateFailure "remote static key not set")
-- | Returns the remote ephemeral public key, or throws
--   'HandshakeStateFailure' if none is set.
getRemoteEphemeralKey :: Curve d => HandshakeState c d h -> PublicKey d
getRemoteEphemeralKey hs = case hs ^. hsRemoteEphemeralKey of
  Just pk -> pk
  Nothing -> throw (HandshakeStateFailure "remote ephemeral key not set")
-- | Encrypts a payload with the sending cipher state, using empty
--   associated data. Returns the ciphertext bytes together with the
--   advanced cipher state, which must be used for the next call.
encryptPayload :: Cipher c
               => Plaintext
               -> SendingCipherState c
               -> (ByteString, SendingCipherState c)
encryptPayload pt (SCS cs) =
  let emptyAD        = AssocData (bsToSB' "")
      (ct, advanced) = encryptAndIncrement emptyAD pt cs
  in  (convert (cipherTextToBytes ct), SCS advanced)
-- | Decrypts a payload with the receiving cipher state, using empty
--   associated data. Returns the plaintext together with the advanced
--   cipher state, which must be used for the next call.
decryptPayload :: Cipher c
               => ByteString
               -> ReceivingCipherState c
               -> (Plaintext, ReceivingCipherState c)
decryptPayload ct (RCS cs) =
  let emptyAD        = AssocData (bsToSB' "")
      cipherText     = cipherBytesToText (convert ct)
      (pt, advanced) = decryptAndIncrement emptyAD cipherText cs
  in  (pt, RCS advanced)