module Crypto.Noise.Internal.Handshake where
import Control.Lens
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Except (MonadError(..), Except)
import Control.Monad.State (MonadState(..), StateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Internal.HandshakePattern
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Internal.Types
import Data.ByteArray.Extend
-- | Which side of the handshake the local party plays.
data HandshakeRole
  = InitiatorRole
  | ResponderRole
  deriving Eq
-- | Options a handshake is configured with: the pattern to run, the role
-- this party plays in it, prologue data, and optional key material.
-- Every key-related field is wrapped in 'Maybe'; @d@ parameterises the
-- key pair and public key types.
data HandshakeOpts d =
  HandshakeOpts
    { _hoPattern             :: HandshakePattern
      -- ^ Handshake pattern to execute.
    , _hoRole                :: HandshakeRole
      -- ^ Whether this party acts as initiator or responder.
    , _hoPrologue            :: Plaintext
      -- ^ Prologue data for the handshake.
    , _hoPreSharedKey        :: Maybe Plaintext
      -- ^ Optional pre-shared key.
    , _hoLocalStatic         :: Maybe (KeyPair d)
      -- ^ Local static key pair, if one is supplied.
    , _hoLocalSemiEphemeral  :: Maybe (KeyPair d)
      -- ^ Local semi-ephemeral key pair, if one is supplied.
    , _hoLocalEphemeral      :: Maybe (KeyPair d)
      -- ^ Local ephemeral key pair, if one is supplied.
    , _hoRemoteStatic        :: Maybe (PublicKey d)
      -- ^ Remote party's static public key, if known.
    , _hoRemoteSemiEphemeral :: Maybe (PublicKey d)
      -- ^ Remote party's semi-ephemeral public key, if known.
    , _hoRemoteEphemeral     :: Maybe (PublicKey d)
      -- ^ Remote party's ephemeral public key, if known.
    }

-- Template Haskell: generate a lens for each underscore-prefixed field
-- above (the underscore is dropped, e.g. 'hoPattern').
$(makeLenses ''HandshakeOpts)
-- | State carried through a handshake computation: the symmetric state
-- (parameterised by @c@ and @h@), the options the handshake was set up
-- with (parameterised by @d@), and a scrubbed byte buffer.
data HandshakeState c d h =
  HandshakeState
    { _hsSymmetricState :: SymmetricState c h
      -- ^ Current symmetric state.
    , _hsOpts           :: HandshakeOpts d
      -- ^ Options for this handshake.
    , _hsMsgBuffer      :: ScrubbedBytes
      -- ^ Message buffer. NOTE(review): presumably accumulates the handshake
      -- message being written/read — confirm against users of 'hsMsgBuffer'.
    }

-- Template Haskell: generate a lens for each underscore-prefixed field
-- above (the underscore is dropped, e.g. 'hsOpts').
$(makeLenses ''HandshakeState)
-- | The handshake monad.
--
-- A 'Coroutine' whose suspension functor is
-- @'Request' 'ScrubbedBytes' 'ScrubbedBytes'@. It can pause by emitting a
-- 'ScrubbedBytes' request and resumes once it is handed a 'ScrubbedBytes'
-- reply. Beneath the coroutine it threads a 'HandshakeState' through
-- 'StateT' and reports failures as 'NoiseException' through 'Except'.
--
-- The 'MonadError' and 'MonadState' instances are obtained by newtype
-- deriving from the underlying 'Coroutine', so they depend on the
-- 'Coroutine' instances defined in this module.
newtype Handshake c d h r =
  Handshake
    { runHandshake' :: Coroutine (Request ScrubbedBytes ScrubbedBytes)
                                 (StateT (HandshakeState c d h)
                                         (Except NoiseException))
                                 r
    }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadError NoiseException
           , MonadState (HandshakeState c d h)
           )
-- | Lifts 'MonadError' from the base monad through 'Coroutine'.
--
-- 'throwError' lifts the base monad's 'throwError'.
--
-- 'catchError' guards every step of the coroutine, not only the first.
-- When the action suspends, the handler is re-attached to each
-- continuation held by the suspension functor, so an error raised after
-- a resume is still caught. Errors thrown by the handler itself are not
-- re-caught; they propagate as the usual 'catchError' contract requires.
-- Together these give @catchError (throwError e) h == h e@.
instance (Functor f, MonadError e m) => MonadError e (Coroutine f m) where
  throwError = lift . throwError
  catchError action handler = Coroutine $ do
    -- Tag the result of running one step so that an error from the action
    -- is distinguishable from anything the handler later throws.
    outcome <- fmap Right (resume action) `catchError` (return . Left)
    case outcome of
      Left err            -> resume (handler err)
      -- Suspended: keep the handler installed around the rest of the action.
      Right (Left susp)   -> return (Left (fmap (`catchError` handler) susp))
      Right (Right value) -> return (Right value)
-- | Lifts 'MonadState' from the base monad through 'Coroutine'. Each
-- operation runs as a single non-suspending step in the base monad.
instance (Functor f, MonadState s m) => MonadState s (Coroutine f m) where
  state f = lift (state f)
  put s   = lift (put s)
  get     = lift get