module Crypto.Noise.Internal.NoiseState where
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Except (MonadError(..), runExcept)
import Control.Monad.State (MonadState(..), runStateT, get, put)
import Control.Monad.Trans.Free.Church
import Control.Lens
import Data.ByteString (ByteString)
import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Prelude hiding (concat, splitAt, length)
import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Internal.Handshake
import Crypto.Noise.Internal.HandshakePattern
import Crypto.Noise.Internal.Types
import Data.ByteArray.Extend
-- | Everything needed to drive a Noise session, parameterised over the
--   cipher @c@, the DH function @d@ and the hash @h@.
data NoiseState c d h = NoiseState
  { -- | State threaded through the handshake interpreter.
    _nsHandshakeState :: HandshakeState c d h
  , -- | Continuation of the paused handshake: 'runHandshake' applies it to
    --   the next message bytes to resume the interpreter.
    _nsHandshakeSuspension :: ScrubbedBytes -> Handshake c d h ()
  , -- | Cipher state for outgoing messages. 'noiseState' initialises it to
    --   'Nothing'; 'runHandshake' fills it in once the handshake finishes.
    _nsSendingCipherState :: Maybe (CipherState c)
  , -- | Cipher state for incoming messages; populated together with
    --   '_nsSendingCipherState'.
    _nsReceivingCipherState :: Maybe (CipherState c)
  }

-- Generates the lenses 'nsHandshakeState', 'nsHandshakeSuspension',
-- 'nsSendingCipherState' and 'nsReceivingCipherState'.
$(makeLenses ''NoiseState)
-- | Builds a 'HandshakeOpts' for the given pattern and role with every
--   optional setting left blank: the prologue is the empty string and no
--   pre-shared key, local key pair or remote public key is set.
defaultHandshakeOpts :: HandshakePattern
                     -> HandshakeRole
                     -> HandshakeOpts d
defaultHandshakeOpts hsPattern localRole =
  HandshakeOpts
    { _hoPattern             = hsPattern
    , _hoRole                = localRole
    , _hoPrologue            = ""
    , _hoPreSharedKey        = Nothing
      -- Local key pairs.
    , _hoLocalStatic         = Nothing
    , _hoLocalSemiEphemeral  = Nothing
    , _hoLocalEphemeral      = Nothing
      -- Remote public keys.
    , _hoRemoteStatic        = Nothing
    , _hoRemoteSemiEphemeral = Nothing
    , _hoRemoteEphemeral     = Nothing
    }
-- | Assembles the handshake name as
--   @\<prefix\>\<pattern\>_\<dh\>_\<cipher\>_\<hash\>@.
--
--   The prefix is @NoisePSK_@ when the 'Bool' argument is 'True' and
--   @Noise_@ otherwise. The proxy argument is never inspected; it only
--   selects the cipher, DH and hash types whose names are embedded.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> Bool
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName hpn psk _ =
  mconcat [prefix, convert hpn, "_", dhPart, "_", cipherPart, "_", hashPart]
  where
    prefix
      | psk       = "NoisePSK_"
      | otherwise = "Noise_"
    dhPart     = dhName     (Proxy :: Proxy d)
    cipherPart = cipherName (Proxy :: Proxy c)
    hashPart   = hashName   (Proxy :: Proxy h)
-- | Swaps the two handshake roles.
invertRole :: HandshakeRole -> HandshakeRole
invertRole role = case role of
  InitiatorRole -> ResponderRole
  ResponderRole -> InitiatorRole
-- | Builds the initial 'HandshakeState' from the supplied options.
--
--   The symmetric state is seeded with the handshake name (see
--   'mkHandshakeName'), then the prologue is mixed in with 'mixHash', and
--   finally the pre-shared key, if one is configured, is mixed in with
--   'mixPSK'. The message buffer starts out empty.
--
--   Calls 'error' when a pre-shared key is present whose length is not
--   exactly 32.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakeState c d h
handshakeState ho
  | pskIsValid =
      HandshakeState { _hsSymmetricState = keyed
                     , _hsOpts           = ho
                     , _hsMsgBuffer      = mempty
                     }
  | otherwise = error "pre-shared key must be 32 bytes in length"
  where
    mPSK = ho ^. hoPreSharedKey

    -- An absent key is acceptable; a present one must have length 32.
    pskIsValid = case mPSK of
      Nothing  -> True
      Just key -> length key == 32

    named = symmetricState (mkHandshakeName (ho ^. hoPattern . hpName)
                                            (isJust mPSK)
                                            (Proxy :: Proxy (c, d, h)))

    prologued = mixHash (ho ^. hoPrologue) named

    keyed = case mPSK of
      Nothing  -> prologued
      Just key -> mixPSK key prologued
-- | Resumes the suspended handshake with the given bytes and runs it until
--   it either suspends again or finishes.
--
--   * If the interpreter suspends, the bytes it requested with are returned
--     and the new continuation is stored in the returned 'NoiseState'.
--   * If the interpreter finishes, the symmetric state is 'split' into two
--     cipher states. The initiator sends with the first and receives with
--     the second; any other role uses them the other way round. The current
--     message buffer is returned.
--
--   In both cases the handshake state that results from the run is written
--   back into the returned 'NoiseState'. A 'NoiseException' thrown by the
--   interpreter is returned as 'Left'.
runHandshake :: (Cipher c, Hash h)
             => ScrubbedBytes
             -> NoiseState c d h
             -> Either NoiseException (ScrubbedBytes, NoiseState c d h)
runHandshake msg ns = runExcept $ do
  ((output, updated), finalHS) <- runStateT step (ns ^. nsHandshakeState)
  return (output, updated & nsHandshakeState .~ finalHS)
  where
    -- The paused coroutine, fed with the caller's bytes.
    resumed = runHandshake' ((ns ^. nsHandshakeSuspension) msg)

    step = do
      outcome <- resume resumed
      case outcome of
        Left (Request output k) ->
          return (output, ns & nsHandshakeSuspension .~ (Handshake . k))
        Right _ -> do
          hs <- get
          let (cs1, cs2) = split (hs ^. hsSymmetricState)
              (sendCS, recvCS)
                | hs ^. hsOpts . hoRole == InitiatorRole = (cs1, cs2)
                | otherwise                               = (cs2, cs1)
          return ( hs ^. hsMsgBuffer
                 , ns & nsSendingCipherState   ?~ sendCS
                      & nsReceivingCipherState ?~ recvCS
                 )
-- | Creates a fresh 'NoiseState' from handshake options.
--
--   The handshake pattern's actions are interpreted with 'evalPattern' and
--   run, starting from 'handshakeState', up to the first suspension. The
--   continuation and handshake state reached at that point are stored; both
--   cipher states start out as 'Nothing'.
--
--   Calls 'error' if the interpreter throws an exception during that first
--   run, or if it finishes without ever suspending.
noiseState :: forall c d h. (Cipher c, DH d, Hash h)
           => HandshakeOpts d
           -> NoiseState c d h
noiseState ho =
  NoiseState { _nsHandshakeState       = primedHS
             , _nsHandshakeSuspension  = firstSuspension
             , _nsSendingCipherState   = Nothing
             , _nsReceivingCipherState = Nothing
             }
  where
    initialHS = handshakeState ho :: HandshakeState c d h

    interpreter =
      iterM evalPattern (hoistFT (return . runIdentity) (ho ^. hoPattern . hpActions))

    firstStep = runExcept (runStateT (resume (runHandshake' interpreter)) initialHS)

    (firstSuspension, primedHS) = case firstStep of
      Left err ->
        error ("handshake pattern interpreter threw exception: " <> show err)
      Right (Left (Request _ k), hs') ->
        (Handshake . k, hs')
      Right _ ->
        error "handshake pattern interpreter ended pre-maturely"
-- | Interprets one message of the handshake pattern.
--
--   The interpreter first suspends, handing the current message buffer to
--   the caller and waiting for input bytes. What happens on resumption
--   depends on whether the message's role (@opRole@) equals the local role:
--
--   * Equal (this side writes the message): the buffer is cleared, the
--     message tokens are evaluated, and the input is passed through
--     'encryptAndHash' and appended to the buffer.
--   * Different (this side reads the message): the buffer is set to the
--     input, the message tokens are evaluated, and whatever remains in the
--     buffer is passed through 'decryptAndHash'; the result replaces the
--     buffer. A decryption failure is reported as a 'HandshakeError'.
--
--   Afterwards the rest of the pattern (@next@) runs.
processPatternOp :: (Cipher c, DH d, Hash h)
                 => HandshakeRole
                 -> FT TokenF Identity ()
                 -> Handshake c d h ()
                 -> Handshake c d h ()
processPatternOp opRole t next = do
  outgoing  <- use hsMsgBuffer
  input     <- Handshake (request outgoing)
  localRole <- use (hsOpts . hoRole)

  if opRole == localRole
    then do
      hsMsgBuffer .= mempty
      runTokens
      hs <- get
      case encryptAndHash (convert input) (hs ^. hsSymmetricState) of
        Left err       -> throwError err
        Right (ep, ss) ->
          put $ hs & hsMsgBuffer      %~ (<> convert ep)
                   & hsSymmetricState .~ ss
    else do
      hsMsgBuffer .= input
      runTokens
      hs <- get
      let leftover   = hs ^. hsMsgBuffer
          ciphertext = cipherBytesToText (convert leftover)
      case decryptAndHash ciphertext (hs ^. hsSymmetricState) of
        Left _         ->
          throwError (HandshakeError "handshake payload failed to decrypt")
        Right (dp, ss) ->
          put $ hs & hsMsgBuffer      .~ convert dp
                   & hsSymmetricState .~ ss

  next
  where
    -- Evaluates every token of this message in order.
    runTokens = iterM (evalMsgToken opRole) (hoistFT (return . runIdentity) t)
-- | Interprets a single handshake pattern operation.
--
--   Pre-message operations evaluate their tokens with 'evalPreMsgToken' for
--   the corresponding role and then continue with the rest of the pattern.
--   Message operations are delegated to 'processPatternOp'.
evalPattern :: (Cipher c, DH d, Hash h)
            => HandshakePatternF (Handshake c d h ())
            -> Handshake c d h ()
evalPattern op = case op of
  PreInitiator t next -> preMessage InitiatorRole t >> next
  PreResponder t next -> preMessage ResponderRole t >> next
  Initiator    t next -> processPatternOp InitiatorRole t next
  Responder    t next -> processPatternOp ResponderRole t next
  where
    -- Evaluates every pre-message token on behalf of the given role.
    preMessage role t =
      iterM (evalPreMsgToken role) (hoistFT (return . runIdentity) t)
-- | Interprets a single message token. @opRole@ is the role the enclosing
--   message belongs to; "writing" below means @opRole@ equals the local
--   role, "reading" means it does not.
--
--   * @E@, writing: the local ephemeral public key is hashed into the
--     symmetric state (and also mixed in as a key when the state reports a
--     PSK) and appended to the message buffer.
--   * @E@, reading: a DH-key-sized prefix is taken off the message buffer,
--     absorbed the same way and stored as the remote ephemeral key.
--   * @S@, writing: the local static public key is passed through
--     'encryptAndHash' and appended to the message buffer.
--   * @S@, reading: fails if a remote static key is already set; otherwise
--     the key field is taken off the buffer, passed through
--     'decryptAndHash' and stored as the remote static key.
--   * @Dhee@ / @Dhss@: the DH of the two ephemeral (resp. static) keys is
--     mixed into the symmetric state.
--   * @Dhes@ / @Dhse@: the writer combines its ephemeral (resp. static)
--     secret with the remote static (resp. ephemeral) key; the reader
--     evaluates the mirrored token under the inverted role.
--
--   Missing keys, undecodable keys and decryption failures are reported
--   with 'throwError'.
evalMsgToken :: forall c d h. (Cipher c, DH d, Hash h)
             => HandshakeRole
             -> TokenF (Handshake c d h ())
             -> Handshake c d h ()
evalMsgToken opRole (E next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState

  if opRole == hs ^. hsOpts . hoRole
    then do
      (_, pk) <- getLocalEphemeral hs
      let keyBytes = dhPubToBytes pk
          hashed   = mixHash keyBytes ss
          keyed    = if hashed ^. ssHasPSK then mixKey keyBytes hashed else hashed
      put $ hs & hsSymmetricState .~ keyed
               & hsMsgBuffer      %~ (<> convert keyBytes)
    else do
      let (keyPart, rest) = splitAt (dhLength (Proxy :: Proxy d)) (hs ^. hsMsgBuffer)
          reBytes         = convert keyPart
          hashed          = mixHash reBytes ss
          -- The PSK flag is read from the state as it was before hashing.
          keyed           = if ss ^. ssHasPSK then mixKey reBytes hashed else hashed
      case dhBytesToPub reBytes of
        Nothing        ->
          throwError (HandshakeError "invalid remote ephemeral key")
        Just remoteKey ->
          put $ hs & hsOpts . hoRemoteEphemeral ?~ remoteKey
                   & hsSymmetricState           .~ keyed
                   & hsMsgBuffer                .~ rest

  next

evalMsgToken opRole (S next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState

  if opRole == hs ^. hsOpts . hoRole
    then do
      pk <- dhPubToBytes . snd <$> getLocalStatic hs
      case encryptAndHash (convert pk) ss of
        Left err        -> throwError err
        Right (ct, ss') ->
          put $ hs & hsSymmetricState .~ ss'
                   & hsMsgBuffer      %~ (<> convert ct)
    else case hs ^. hsOpts . hoRemoteStatic of
      Just _  ->
        throwError (InvalidHandshakeOptions "unable to overwrite remote static key")
      Nothing -> do
        let keyLen   = dhLength (Proxy :: Proxy d)
            -- When the symmetric state already has a key, the field is 16
            -- bytes longer (presumably the authentication tag of the
            -- encrypted key -- confirm against the cipher implementation).
            fieldLen = if ss ^. ssHasKey then keyLen + 16 else keyLen
            (field, rest) = splitAt fieldLen (hs ^. hsMsgBuffer)
            ct       = cipherBytesToText (convert field)
        (pt, ss') <- case decryptAndHash ct ss of
          Left _   -> throwError (HandshakeError "failed to decrypt remote static key")
          Right ok -> return ok
        remoteKey <- case dhBytesToPub pt of
          Nothing -> throwError (HandshakeError "invalid remote static key provided")
          Just k  -> return k
        put $ hs & hsOpts . hoRemoteStatic ?~ remoteKey
                 & hsSymmetricState        .~ ss'
                 & hsMsgBuffer             .~ rest

  next

evalMsgToken _ (Dhee next) = do
  hs  <- get
  sk  <- fst <$> getLocalEphemeral hs
  rpk <- getRemoteEphemeral hs
  put $ hs & hsSymmetricState %~ mixKey (dhPerform sk rpk)
  next

evalMsgToken opRole (Dhes next) = do
  hs <- get
  if opRole == hs ^. hsOpts . hoRole
    then do
      rpk <- getRemoteStatic hs
      sk  <- fst <$> getLocalEphemeral hs
      put $ hs & hsSymmetricState %~ mixKey (dhPerform sk rpk)
      next
    -- The reader performs the mirrored operation as if it were the writer.
    else evalMsgToken (invertRole opRole) (Dhse next)

evalMsgToken opRole (Dhse next) = do
  hs <- get
  if opRole == hs ^. hsOpts . hoRole
    then do
      sk  <- fst <$> getLocalStatic hs
      rpk <- getRemoteEphemeral hs
      put $ hs & hsSymmetricState %~ mixKey (dhPerform sk rpk)
      next
    -- The reader performs the mirrored operation as if it were the writer.
    else evalMsgToken (invertRole opRole) (Dhes next)

evalMsgToken _ (Dhss next) = do
  hs  <- get
  sk  <- fst <$> getLocalStatic hs
  rpk <- getRemoteStatic hs
  put $ hs & hsSymmetricState %~ mixKey (dhPerform sk rpk)
  next
-- | Interprets a single pre-message token by hashing a public key into the
--   symmetric state.
--
--   For @E@ the key is the semi-ephemeral one, for @S@ the static one. When
--   @opRole@ equals the local role the local key pair's public half is
--   used, otherwise the remote public key. A missing key is reported by the
--   respective getter.
--
--   Any other token calls 'error'.
evalPreMsgToken :: (Cipher c, DH d, Hash h)
                => HandshakeRole
                -> TokenF (Handshake c d h ())
                -> Handshake c d h ()
evalPreMsgToken opRole (E next) = do
  hs <- get
  let fetchKey
        | opRole == hs ^. hsOpts . hoRole = snd <$> getLocalSemiEphemeral hs
        | otherwise                       = getRemoteSemiEphemeral hs
  pk <- fetchKey
  put $ hs & hsSymmetricState %~ mixHash (dhPubToBytes pk)
  next

evalPreMsgToken opRole (S next) = do
  hs <- get
  let fetchKey
        | opRole == hs ^. hsOpts . hoRole = snd <$> getLocalStatic hs
        | otherwise                       = getRemoteStatic hs
  pk <- fetchKey
  put $ hs & hsSymmetricState %~ mixHash (dhPubToBytes pk)
  next

evalPreMsgToken _ _ = error "invalid pre-message pattern token"
-- | Returns the local static key pair from the handshake options, or throws
--   'InvalidHandshakeOptions' if none is set.
getLocalStatic :: HandshakeState c d h
               -> Handshake c d h (KeyPair d)
getLocalStatic hs = case hs ^. hsOpts . hoLocalStatic of
  Just keyPair -> return keyPair
  Nothing      -> throwError (InvalidHandshakeOptions "local static key not set")
-- | Returns the local semi-ephemeral key pair from the handshake options,
--   or throws 'InvalidHandshakeOptions' if none is set.
getLocalSemiEphemeral :: HandshakeState c d h
                      -> Handshake c d h (KeyPair d)
getLocalSemiEphemeral hs = case hs ^. hsOpts . hoLocalSemiEphemeral of
  Just keyPair -> return keyPair
  Nothing      -> throwError (InvalidHandshakeOptions "local semi-ephemeral key not set")
-- | Returns the local ephemeral key pair from the handshake options, or
--   throws 'InvalidHandshakeOptions' if none is set.
getLocalEphemeral :: HandshakeState c d h
                  -> Handshake c d h (KeyPair d)
getLocalEphemeral hs = case hs ^. hsOpts . hoLocalEphemeral of
  Just keyPair -> return keyPair
  Nothing      -> throwError (InvalidHandshakeOptions "local ephemeral key not set")
-- | Returns the remote static public key from the handshake options, or
--   throws 'InvalidHandshakeOptions' if none is set.
getRemoteStatic :: HandshakeState c d h
                -> Handshake c d h (PublicKey d)
getRemoteStatic hs = case hs ^. hsOpts . hoRemoteStatic of
  Just key -> return key
  Nothing  -> throwError (InvalidHandshakeOptions "remote static key not set")
-- | Returns the remote semi-ephemeral public key from the handshake
--   options, or throws 'InvalidHandshakeOptions' if none is set.
getRemoteSemiEphemeral :: HandshakeState c d h
                       -> Handshake c d h (PublicKey d)
getRemoteSemiEphemeral hs = case hs ^. hsOpts . hoRemoteSemiEphemeral of
  Just key -> return key
  Nothing  -> throwError (InvalidHandshakeOptions "remote semi-ephemeral key not set")
-- | Returns the remote ephemeral public key from the handshake options, or
--   throws 'InvalidHandshakeOptions' if none is set.
getRemoteEphemeral :: HandshakeState c d h
                   -> Handshake c d h (PublicKey d)
getRemoteEphemeral hs = case hs ^. hsOpts . hoRemoteEphemeral of
  Just key -> return key
  Nothing  -> throwError (InvalidHandshakeOptions "remote ephemeral key not set")