module Crypto.Noise.Internal.NoiseState where
import Control.Monad.Catch.Pure
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.State (MonadState(..), runStateT, get, put)
import Control.Monad.Free.Church
import Control.Lens
import Data.ByteArray (ScrubbedBytes, convert, length, splitAt)
import Data.ByteString (ByteString)
import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Prelude hiding (concat, splitAt, length)
import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Internal.Handshake
import Crypto.Noise.Internal.HandshakePattern hiding (e, s, ee, se, es, ss)
import Crypto.Noise.Internal.Types
-- | State needed to drive a Noise handshake one message at a time and, once
-- the handshake has finished, the resulting transport cipher states.
data NoiseState c d h =
  NoiseState
    { _nsHandshakeState       :: HandshakeState c d h
      -- ^ Symmetric state, options and message buffer of the handshake.
    , _nsHandshakeSuspension  :: ScrubbedBytes -> Handshake c d h ()
      -- ^ Continuation of the suspended pattern interpreter; 'runHandshake'
      -- feeds it the next input (a payload to send or a message received).
    , _nsSendingCipherState   :: Maybe (CipherState c)
      -- ^ Cipher state for outgoing messages; 'Nothing' until the handshake
      -- pattern has run to completion.
    , _nsReceivingCipherState :: Maybe (CipherState c)
      -- ^ Cipher state for incoming messages; 'Nothing' until the handshake
      -- pattern has run to completion.
    }

$(makeLenses ''NoiseState)
-- | Options for the given pattern and role with an empty prologue, no
-- pre-shared key, and no local or remote keys configured. Callers are
-- expected to fill in whichever keys the chosen pattern requires.
defaultHandshakeOpts :: HandshakePattern
                     -> HandshakeRole
                     -> HandshakeOpts d
defaultHandshakeOpts pat myRole =
  HandshakeOpts
    { _hoPattern             = pat
    , _hoRole                = myRole
    , _hoPrologue            = ""
    , _hoPreSharedKey        = Nothing
      -- our own key pairs
    , _hoLocalStatic         = Nothing
    , _hoLocalSemiEphemeral  = Nothing
    , _hoLocalEphemeral      = Nothing
      -- the peer's public keys
    , _hoRemoteStatic        = Nothing
    , _hoRemoteSemiEphemeral = Nothing
    , _hoRemoteEphemeral     = Nothing
    }
-- | Assemble the protocol name used to initialise the symmetric state:
-- @\<prefix\>\<pattern\>_\<dh\>_\<cipher\>_\<hash\>@. The prefix is
-- @NoisePSK_@ when the flag is set and @Noise_@ otherwise. The proxy value
-- is never inspected; it only selects the cipher, DH and hash types.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> Bool
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName patternName usesPSK _ =
  mconcat [ prefix
          , convert patternName
          , "_", dhName     (Proxy :: Proxy d)
          , "_", cipherName (Proxy :: Proxy c)
          , "_", hashName   (Proxy :: Proxy h)
          ]
  where
    prefix | usesPSK   = "NoisePSK_"
           | otherwise = "Noise_"
-- | Build the initial 'HandshakeState' for the given options.
--
-- * The symmetric state is initialised from the protocol name produced by
--   'mkHandshakeName'.
-- * The prologue is then mixed in with 'mixHash'.
-- * If a pre-shared key is configured, it is mixed in with 'mixPSK'.
--
-- The message buffer starts empty.
--
-- Calls 'error' when a pre-shared key is supplied whose length is not 32.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakeState c d h
handshakeState ho
  | pskLengthOk = HandshakeState { _hsSymmetricState = keyed
                                 , _hsOpts           = ho
                                 , _hsMsgBuffer      = mempty
                                 }
  | otherwise   = error "pre-shared key must be 32 bytes in length"
  where
    mpsk         = ho ^. hoPreSharedKey
    pskLengthOk  = case mpsk of
                     Nothing  -> True
                     Just psk -> length psk == 32
    protoName    = mkHandshakeName (ho ^. hoPattern . hpName)
                                   (isJust mpsk)
                                   (Proxy :: Proxy (c, d, h))
    fresh        = symmetricState protoName
    withPrologue = mixHash (ho ^. hoPrologue) fresh
    keyed        = maybe withPrologue (`mixPSK` withPrologue) mpsk
-- | Feed one input to the suspended handshake. When it is our turn to write,
-- the input is the payload; when it is our turn to read, it is the received
-- message.
--
-- Returns the message buffer at the point where the interpreter next
-- suspends or finishes, together with the updated state. That buffer is the
-- outgoing handshake message when writing, or the decrypted payload when
-- reading.
--
-- When the pattern runs to completion, the final symmetric state is 'split':
-- the initiator sends with the first cipher state and receives with the
-- second, and the responder does the opposite.
--
-- Exceptions raised inside the handshake are rethrown in @m@ via 'throwM'.
runHandshake :: (MonadThrow m, Cipher c, Hash h)
             => ScrubbedBytes
             -> NoiseState c d h
             -> m (ScrubbedBytes, NoiseState c d h)
runHandshake msg ns = either throwM return . runCatch $ do
  ((out, ns'), hs') <- runStateT step (ns ^. nsHandshakeState)
  return (out, ns' & nsHandshakeState .~ hs')
  where
    step = do
      r <- resume . runHandshake' $ (ns ^. nsHandshakeSuspension) msg
      case r of
        -- Suspended again: keep the new continuation for the next call.
        Left (Request out k) ->
          return (out, ns & nsHandshakeSuspension .~ (Handshake . k))
        -- Pattern finished: derive the transport cipher states.
        Right _ -> finish <$> get

    finish hs =
      let (c1, c2) = split (hs ^. hsSymmetricState)
          (sendCS, recvCS)
            | hs ^. hsOpts . hoRole == InitiatorRole = (c1, c2)
            | otherwise                               = (c2, c1)
      in ( hs ^. hsMsgBuffer
         , ns & nsSendingCipherState   .~ Just sendCS
              & nsReceivingCipherState .~ Just recvCS
         )
-- | Create a fresh 'NoiseState' from the handshake options.
--
-- The pattern interpreter is run up to its first suspension, so any
-- pre-message steps that precede the first message are applied here. The
-- buffer handed out at that first suspension is discarded.
--
-- The resulting continuation and state are stored, and no cipher states are
-- set yet. If the interpreter throws or finishes without suspending, 'error'
-- is raised; since it is a lazy where-binding, this happens when the
-- affected fields are forced.
noiseState :: forall c d h. (Cipher c, DH d, Hash h)
           => HandshakeOpts d
           -> NoiseState c d h
noiseState ho =
  NoiseState
    { _nsHandshakeState       = primedState
    , _nsHandshakeSuspension  = firstStep
    , _nsSendingCipherState   = Nothing
    , _nsReceivingCipherState = Nothing
    }
  where
    initialState = handshakeState ho :: HandshakeState c d h
    interpreter  = iterM evalPattern (ho ^. hoPattern . hpActions)
    (firstStep, primedState) =
      case runCatch (runStateT (resume (runHandshake' interpreter)) initialState) of
        Left err ->
          error $ "handshake pattern interpreter threw exception: " <> show err
        Right (Left (Request _ resp), hs') ->
          (Handshake . resp, hs')
        Right _ ->
          error "handshake pattern interpreter ended pre-maturely"
-- | Run one message of the handshake pattern.
--
-- The interpreter first suspends, handing the current message buffer to the
-- caller and receiving the next input in return.
--
-- If the message belongs to our role, the input is a payload. The tokens are
-- written into a cleared buffer, then the payload is encrypted with
-- 'encryptAndHash' and appended.
--
-- Otherwise the input is an incoming message. The tokens consume its prefix,
-- and the remainder is decrypted with 'decryptAndHash', leaving the plaintext
-- payload in the buffer. A decryption failure is rethrown as 'HandshakeError'.
processPatternOp :: (Cipher c, DH d, Hash h)
                 => HandshakeRole
                 -> F TokenF ()
                 -> Handshake c d h ()
                 -> Handshake c d h ()
processPatternOp opRole t next = do
  pending <- (^. hsMsgBuffer) <$> get
  input   <- Handshake (request pending)
  st      <- get
  if opRole /= st ^. hsOpts . hoRole
    then do
      -- Reading: parse the tokens, then decrypt whatever follows them.
      put $ st & hsMsgBuffer .~ input
      iterM (evalMsgToken opRole) t
      st' <- get
      let sealed = cipherBytesToText (convert (st' ^. hsMsgBuffer))
      (payload, sym) <-
        either (const . throwM . HandshakeError $ "handshake payload failed to decrypt")
               return
               (decryptAndHash sealed (st' ^. hsSymmetricState))
      put $ st' & hsMsgBuffer      .~ convert payload
                & hsSymmetricState .~ sym
    else do
      -- Writing: emit the tokens, then append the encrypted payload.
      put $ st & hsMsgBuffer .~ mempty
      iterM (evalMsgToken opRole) t
      st' <- get
      (ciphertext, sym) <-
        either throwM return (encryptAndHash (convert input) (st' ^. hsSymmetricState))
      put $ st' & hsMsgBuffer      %~ (<> convert ciphertext)
                & hsSymmetricState .~ sym
  next
-- | Interpret one step of a handshake pattern.
--
-- Pre-message steps only mix already-known public keys into the handshake
-- hash, then continue. Message steps are delegated to 'processPatternOp',
-- which suspends to exchange data with the caller.
evalPattern :: (Cipher c, DH d, Hash h)
            => HandshakePatternF (Handshake c d h ())
            -> Handshake c d h ()
evalPattern step = case step of
  PreInitiator tokens next -> iterM (evalPreMsgToken InitiatorRole) tokens >> next
  PreResponder tokens next -> iterM (evalPreMsgToken ResponderRole) tokens >> next
  Initiator    tokens next -> processPatternOp InitiatorRole tokens next
  Responder    tokens next -> processPatternOp ResponderRole tokens next
-- | Process one token of a handshake message.
--
-- When @opRole@ is our own role we are writing: the @e@ and @s@ tokens append
-- key material to the message buffer. Otherwise we are reading: those tokens
-- consume bytes from the front of the buffer and record the peer's keys.
-- The DH tokens (@ee@, @es@, @se@, @ss@) only mix a DH output into the
-- symmetric state with 'mixKey'.
--
-- Throws:
--
-- * 'HandshakeError' when received key material fails to decrypt or parse.
-- * 'InvalidHandshakeOptions' when a required key is missing, or when a
--   remote static key is already set and would be overwritten.
evalMsgToken :: forall c d h. (Cipher c, DH d, Hash h)
             => HandshakeRole
             -> TokenF (Handshake c d h ())
             -> Handshake c d h ()
evalMsgToken opRole token = do
  hs <- get
  let sym       = hs ^. hsSymmetricState
      ourRole   = hs ^. hsOpts . hoRole
      -- We write this message when it belongs to our role, otherwise we read it.
      writing   = opRole == ourRole
      initiator = ourRole == InitiatorRole
      -- Replace the symmetric state with one that has mixed in DH(sk, rpk).
      mixDH sk rpk = put $ hs & hsSymmetricState .~ mixKey (dhPerform sk rpk) sym
  case token of
    E next -> do
      if writing
        then do
          (_, pk) <- getLocalEphemeral hs
          let pkBytes = dhPubToBytes pk
              hashed  = mixHash pkBytes sym
              -- With a PSK the ephemeral public key is also mixed into the key.
              sym'    = if hashed ^. ssHasPSK then mixKey pkBytes hashed else hashed
          put $ hs & hsSymmetricState .~ sym'
                   & hsMsgBuffer      %~ (<> convert pkBytes)
        else do
          let (keyBytes, rest) = splitAt (dhLength (Proxy :: Proxy d)) (hs ^. hsMsgBuffer)
              reBytes = convert keyBytes
              hashed  = mixHash reBytes sym
              -- NOTE(review): the writer checks ssHasPSK after mixHash, the
              -- reader before it; these agree only if mixHash keeps the flag.
              sym'    = if sym ^. ssHasPSK then mixKey reBytes hashed else hashed
          re <- maybe (throwM . HandshakeError $ "invalid remote ephemeral key")
                      return
                      (dhBytesToPub reBytes)
          put $ hs & hsOpts . hoRemoteEphemeral .~ Just re
                   & hsSymmetricState           .~ sym'
                   & hsMsgBuffer                .~ rest
      next

    S next -> do
      if writing
        then do
          pkBytes <- dhPubToBytes . snd <$> getLocalStatic hs
          (ct, sym') <- either throwM return (encryptAndHash (convert pkBytes) sym)
          put $ hs & hsSymmetricState .~ sym'
                   & hsMsgBuffer      %~ (<> convert ct)
        else if isJust (hs ^. hsOpts . hoRemoteStatic)
          then throwM . InvalidHandshakeOptions $ "unable to overwrite remote static key"
          else do
            let keyLen = dhLength (Proxy :: Proxy d)
                -- Once a cipher key exists the encrypted static key is 16 bytes
                -- longer (presumably the AEAD tag).
                width  = if sym ^. ssHasKey then keyLen + 16 else keyLen
                (sealed, rest) = splitAt width (hs ^. hsMsgBuffer)
            (pt, sym') <-
              either (const . throwM . HandshakeError $ "failed to decrypt remote static key")
                     return
                     (decryptAndHash (cipherBytesToText (convert sealed)) sym)
            rs <- maybe (throwM . HandshakeError $ "invalid remote static key provided")
                        return
                        (dhBytesToPub pt)
            put $ hs & hsOpts . hoRemoteStatic .~ Just rs
                     & hsSymmetricState        .~ sym'
                     & hsMsgBuffer             .~ rest
      next

    Ee next -> do
      sk  <- fst <$> getLocalEphemeral hs
      rpk <- getRemoteEphemeral hs
      mixDH sk rpk
      next

    Es next -> do
      if initiator
        then do
          rpk <- getRemoteStatic hs
          sk  <- fst <$> getLocalEphemeral hs
          mixDH sk rpk
        else do
          sk  <- fst <$> getLocalStatic hs
          rpk <- getRemoteEphemeral hs
          mixDH sk rpk
      next

    Se next -> do
      if initiator
        then do
          sk  <- fst <$> getLocalStatic hs
          rpk <- getRemoteEphemeral hs
          mixDH sk rpk
        else do
          rpk <- getRemoteStatic hs
          sk  <- fst <$> getLocalEphemeral hs
          mixDH sk rpk
      next

    Ss next -> do
      sk  <- fst <$> getLocalStatic hs
      rpk <- getRemoteStatic hs
      mixDH sk rpk
      next
-- | Process a pre-message token by mixing the matching public key into the
-- handshake hash.
--
-- If the pre-message belongs to our role, the key comes from our own key
-- pair; otherwise it comes from the configured remote key. Only the @e@
-- (semi-ephemeral) and @s@ (static) tokens are accepted; any other token
-- raises 'error'.
evalPreMsgToken :: (Cipher c, DH d, Hash h)
                => HandshakeRole
                -> TokenF (Handshake c d h ())
                -> Handshake c d h ()
evalPreMsgToken opRole token = case token of
  E next -> do
    hs <- get
    pk <- if opRole == hs ^. hsOpts . hoRole
            then snd <$> getLocalSemiEphemeral hs
            else getRemoteSemiEphemeral hs
    put $ hs & hsSymmetricState %~ mixHash (dhPubToBytes pk)
    next
  S next -> do
    hs <- get
    pk <- if opRole == hs ^. hsOpts . hoRole
            then snd <$> getLocalStatic hs
            else getRemoteStatic hs
    put $ hs & hsSymmetricState %~ mixHash (dhPubToBytes pk)
    next
  _ -> error "invalid pre-message pattern token"
-- | The configured local static key pair, or 'InvalidHandshakeOptions' if it
-- has not been set.
getLocalStatic :: HandshakeState c d h
               -> Handshake c d h (KeyPair d)
getLocalStatic hs =
  case hs ^. hsOpts . hoLocalStatic of
    Just kp -> return kp
    Nothing -> throwM (InvalidHandshakeOptions "local static key not set")
-- | The configured local semi-ephemeral key pair, or 'InvalidHandshakeOptions'
-- if it has not been set.
getLocalSemiEphemeral :: HandshakeState c d h
                      -> Handshake c d h (KeyPair d)
getLocalSemiEphemeral hs =
  case hs ^. hsOpts . hoLocalSemiEphemeral of
    Just kp -> return kp
    Nothing -> throwM (InvalidHandshakeOptions "local semi-ephemeral key not set")
-- | The configured local ephemeral key pair, or 'InvalidHandshakeOptions' if
-- it has not been set.
getLocalEphemeral :: HandshakeState c d h
                  -> Handshake c d h (KeyPair d)
getLocalEphemeral hs =
  case hs ^. hsOpts . hoLocalEphemeral of
    Just kp -> return kp
    Nothing -> throwM (InvalidHandshakeOptions "local ephemeral key not set")
-- | The peer's static public key, or 'InvalidHandshakeOptions' if it is not
-- known (neither configured nor received yet).
getRemoteStatic :: HandshakeState c d h
                -> Handshake c d h (PublicKey d)
getRemoteStatic hs =
  case hs ^. hsOpts . hoRemoteStatic of
    Just pk -> return pk
    Nothing -> throwM (InvalidHandshakeOptions "remote static key not set")
-- | The peer's semi-ephemeral public key, or 'InvalidHandshakeOptions' if it
-- has not been set.
getRemoteSemiEphemeral :: HandshakeState c d h
                       -> Handshake c d h (PublicKey d)
getRemoteSemiEphemeral hs =
  case hs ^. hsOpts . hoRemoteSemiEphemeral of
    Just pk -> return pk
    Nothing -> throwM (InvalidHandshakeOptions "remote semi-ephemeral key not set")
-- | The peer's ephemeral public key, or 'InvalidHandshakeOptions' if it is
-- not known (neither configured nor received yet).
getRemoteEphemeral :: HandshakeState c d h
                   -> Handshake c d h (PublicKey d)
getRemoteEphemeral hs =
  case hs ^. hsOpts . hoRemoteEphemeral of
    Just pk -> return pk
    Nothing -> throwM (InvalidHandshakeOptions "remote ephemeral key not set")