{-# LANGUAGE TemplateHaskell, ScopedTypeVariables #-}
module Crypto.Noise.Internal.NoiseState where
import Control.Lens
import Control.Monad.Catch.Pure
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.State
import Data.ByteArray (ScrubbedBytes)
import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.Handshake.Interpreter
import Crypto.Noise.Internal.Handshake.Pattern (HandshakePattern)
import Crypto.Noise.Internal.Handshake.State
import Crypto.Noise.Internal.SymmetricState (split)
-- | Bookkeeping for one Noise session: the in-progress handshake (its state,
-- its pattern and the continuation waiting for the next message) and, once
-- the handshake has finished, the pair of cipher states it produced.
--
-- The type parameters are the cipher (@c@), Diffie-Hellman function (@d@) and
-- hash function (@h@) used by the session.
data NoiseState c d h =
  NoiseState { _nsHandshakeState       :: HandshakeState c d h
               -- ^ Handshake state threaded through the handshake coroutine;
               -- it is replaced with the updated state after every run.
             , _nsHandshakePattern     :: HandshakePattern
               -- ^ Pattern used to start the handshake coroutine when no
               -- continuation has been stored yet.
             , _nsHandshakeSuspension  :: Maybe (ScrubbedBytes -> Handshake c d h ())
               -- ^ Continuation of the suspended handshake, fed with the next
               -- message/payload. 'Nothing' until the handshake is first run.
             , _nsSendingCipherState   :: Maybe (CipherState c)
               -- ^ Cipher state for outgoing data; 'Nothing' until the
               -- handshake has completed.
             , _nsReceivingCipherState :: Maybe (CipherState c)
               -- ^ Cipher state for incoming data; 'Nothing' until the
               -- handshake has completed.
             }

-- Generate one lens per field, named without the leading underscore
-- (e.g. 'nsHandshakeState', 'nsHandshakeSuspension').
$(makeLenses ''NoiseState)
-- | Create the initial 'NoiseState' for a session.
--
-- The embedded handshake state is built by 'handshakeState' from the given
-- options and pattern, and the pattern is also kept so the handshake can be
-- started later. No handshake continuation is stored yet, and both cipher
-- states start out as 'Nothing'.
noiseState :: (Cipher c, DH d, Hash h)
           => HandshakeOpts d
           -> HandshakePattern
           -> NoiseState c d h
noiseState opts pat = NoiseState
  { _nsHandshakeState       = handshakeState opts pat
  , _nsHandshakePattern     = pat
  , _nsHandshakeSuspension  = Nothing
  , _nsSendingCipherState   = Nothing
  , _nsReceivingCipherState = Nothing
  }
-- | Deliver @msg@ to the handshake and run it until it suspends again or
-- finishes.
--
-- * If no continuation is stored yet, the handshake pattern is first run from
--   the beginning until its first suspension. The request produced by that
--   priming run is discarded, and the function calls itself with the primed
--   state so that @msg@ reaches the freshly stored continuation.
--
-- * If the handshake suspends, the request it yields is returned together
--   with a state holding the new continuation and the updated handshake state.
--
-- * If the handshake finishes, the symmetric state is 'split' into two cipher
--   states. With 'InitiatorRole' the first is used for sending and the second
--   for receiving; otherwise the roles are swapped. The result is
--   'HandshakeResultMessage' carrying the final message buffer. The previously
--   stored continuation is left in place.
--
-- Any exception raised while running the handshake is rethrown with 'throwM'.
--
-- NOTE(review): if the priming run finishes without ever suspending, the
-- continuation stays 'Nothing' and this function recurses again. This
-- presumably cannot happen for valid patterns — confirm against the
-- interpreter.
resumeHandshake :: (MonadThrow m, Cipher c, DH d, Hash h)
                => ScrubbedBytes
                -> NoiseState c d h
                -> m (HandshakeResult, NoiseState c d h)
resumeHandshake msg ns =
  maybe restartPattern
        (\continueWith -> runUntilSuspended (continueWith msg))
        (ns ^. nsHandshakeSuspension)
  where
    -- Run the pattern up to its first suspension point, then retry with the
    -- continuation that run stored.
    restartPattern = do
      (_, primed) <- runUntilSuspended (runHandshakePattern (ns ^. nsHandshakePattern))
      resumeHandshake msg primed

    -- Step the coroutine purely, starting from the stored handshake state.
    runUntilSuspended action =
      either throwM onOutcome
        . runCatch
        . runStateT (resume (runHandshake action))
        $ ns ^. nsHandshakeState

    -- The coroutine suspended: remember how to continue it.
    onOutcome (Left (Request req resp), hs) =
      return ( req
             , ns & nsHandshakeSuspension ?~ (Handshake . resp)
                  & nsHandshakeState      .~ hs
             )
    -- The coroutine finished: install the cipher states for this role.
    onOutcome (Right _, hs) =
      return (HandshakeResultMessage (hs ^. hsMsgBuffer), finished)
      where
        (cs1, cs2) = split (hs ^. hsSymmetricState)
        (sendCs, recvCs)
          | hs ^. hsOpts . hoRole == InitiatorRole = (cs1, cs2)
          | otherwise                              = (cs2, cs1)
        finished = ns & nsSendingCipherState   ?~ sendCs
                      & nsReceivingCipherState ?~ recvCs
                      & nsHandshakeState       .~ hs