{-# LANGUAGE TemplateHaskell, ScopedTypeVariables, GeneralizedNewtypeDeriving,
FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Crypto.Noise.Internal.Handshake.State where
import Control.Lens
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Catch.Pure
import Control.Monad.State (MonadState(..), StateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.ByteArray (ScrubbedBytes, convert)
import Data.ByteString (ByteString)
import Data.Monoid ((<>))
import Data.Proxy
import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.Handshake.Pattern hiding (ss)
import Crypto.Noise.Internal.SymmetricState
-- | The two roles a handshake participant can be configured with
--   (see the '_hoRole' field of 'HandshakeOpts').
data HandshakeRole
  = InitiatorRole
  | ResponderRole
  deriving (Eq, Show)
-- | User-supplied configuration for a handshake. The type parameter @d@ is
--   the DH type whose key types ('KeyPair', 'PublicKey') appear in the
--   key fields.
data HandshakeOpts d = HandshakeOpts
  { _hoRole            :: HandshakeRole
    -- ^ Role assigned to this side of the handshake.
  , _hoPrologue        :: Plaintext
    -- ^ Prologue data; 'handshakeState' mixes it into the initial
    --   symmetric state via 'mixHash'.
  , _hoLocalEphemeral  :: Maybe (KeyPair d)
    -- ^ Optional local ephemeral key pair.
  , _hoLocalStatic     :: Maybe (KeyPair d)
    -- ^ Optional local static key pair.
  , _hoRemoteEphemeral :: Maybe (PublicKey d)
    -- ^ Optional remote ephemeral public key.
  , _hoRemoteStatic    :: Maybe (PublicKey d)
    -- ^ Optional remote static public key.
  }

-- Template Haskell splice: generates one lens per underscore-prefixed field
-- above (e.g. 'hoPrologue' for '_hoPrologue').
$(makeLenses ''HandshakeOpts)
-- | Internal state threaded through a handshake. The parameters @c@, @d@
--   and @h@ are the cipher, DH and hash types respectively.
data HandshakeState c d h = HandshakeState
  { _hsSymmetricState :: SymmetricState c h
    -- ^ The symmetric state; 'handshakeState' seeds it from the handshake
    --   name and then mixes in the prologue.
  , _hsOpts           :: HandshakeOpts d
    -- ^ The options the handshake was created with.
  , _hsPSKMode        :: Bool
    -- ^ Copied from the pattern's 'hpPSKMode' by 'handshakeState'.
  , _hsMsgBuffer      :: ScrubbedBytes
    -- ^ Message buffer; starts out empty ('mempty').
    --   NOTE(review): how it is filled and drained is not visible in this
    --   module -- confirm against the handshake interpreter.
  }

-- Template Haskell splice: generates one lens per underscore-prefixed field
-- above (e.g. 'hsOpts' for '_hsOpts').
$(makeLenses ''HandshakeState)
-- | The value a 'Handshake' coroutine hands to its caller each time it
--   suspends (it is the request type of the 'Request' suspension functor
--   used by 'Handshake').
data HandshakeResult
  -- | Carries a message as 'ScrubbedBytes'.
  = HandshakeResultMessage ScrubbedBytes
  -- | Carries no payload.
  --   NOTE(review): presumably asks the caller to supply a pre-shared key
  --   before the handshake can continue -- confirm against the interpreter.
  | HandshakeResultNeedPSK
-- | The handshake monad: a coroutine layered over
--   @'StateT' ('HandshakeState' c d h) 'Catch'@.
--
--   Each suspension is a 'Request' that passes a 'HandshakeResult' out to
--   the caller and is resumed with a 'ScrubbedBytes' value. The derived
--   'MonadThrow' and 'MonadState' instances rely on the corresponding
--   'Coroutine' instances defined in this module.
newtype Handshake c d h r = Handshake
  { runHandshake ::
      Coroutine
        (Request HandshakeResult ScrubbedBytes)
        (StateT (HandshakeState c d h) Catch)
        r
    -- ^ Unwraps the underlying coroutine.
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadThrow
    , MonadState (HandshakeState c d h)
    )
-- | Builds a 'HandshakeOpts' from the two mandatory settings. All four key
--   fields start out as 'Nothing'; fill them in afterwards with
--   'setLocalEphemeral', 'setLocalStatic', 'setRemoteEphemeral' and
--   'setRemoteStatic'.
defaultHandshakeOpts
  :: HandshakeRole    -- ^ Role to store in '_hoRole'.
  -> Plaintext        -- ^ Prologue to store in '_hoPrologue'.
  -> HandshakeOpts d
defaultHandshakeOpts handshakeRole prologue = HandshakeOpts
  { _hoRole            = handshakeRole
  , _hoPrologue        = prologue
  , _hoLocalEphemeral  = Nothing
  , _hoLocalStatic     = Nothing
  , _hoRemoteEphemeral = Nothing
  , _hoRemoteStatic    = Nothing
  }
-- | Replaces the local ephemeral key pair in the given options. Passing
--   'Nothing' clears any previously set value.
setLocalEphemeral
  :: Maybe (KeyPair d)  -- ^ New value for '_hoLocalEphemeral'.
  -> HandshakeOpts d    -- ^ Options to update.
  -> HandshakeOpts d
setLocalEphemeral newKeyPair currentOpts =
  currentOpts { _hoLocalEphemeral = newKeyPair }
-- | Replaces the local static key pair in the given options. Passing
--   'Nothing' clears any previously set value.
setLocalStatic
  :: Maybe (KeyPair d)  -- ^ New value for '_hoLocalStatic'.
  -> HandshakeOpts d    -- ^ Options to update.
  -> HandshakeOpts d
setLocalStatic newKeyPair currentOpts =
  currentOpts { _hoLocalStatic = newKeyPair }
-- | Replaces the remote ephemeral public key in the given options. Passing
--   'Nothing' clears any previously set value.
setRemoteEphemeral
  :: Maybe (PublicKey d)  -- ^ New value for '_hoRemoteEphemeral'.
  -> HandshakeOpts d      -- ^ Options to update.
  -> HandshakeOpts d
setRemoteEphemeral newPublicKey currentOpts =
  currentOpts { _hoRemoteEphemeral = newPublicKey }
-- | Replaces the remote static public key in the given options. Passing
--   'Nothing' clears any previously set value.
setRemoteStatic
  :: Maybe (PublicKey d)  -- ^ New value for '_hoRemoteStatic'.
  -> HandshakeOpts d      -- ^ Options to update.
  -> HandshakeOpts d
setRemoteStatic newPublicKey currentOpts =
  currentOpts { _hoRemoteStatic = newPublicKey }
-- | Assembles the full handshake name
--   @Noise_\<protoName\>_\<dh\>_\<cipher\>_\<hash\>@, where the last three
--   components come from 'dhName', 'cipherName' and 'hashName' for the
--   types selected by the proxy argument.
--
--   The proxy is used only to fix the types @c@, @d@ and @h@; its value is
--   never inspected.
mkHandshakeName
  :: forall c d h proxy. (Cipher c, DH d, Hash h)
  => ByteString       -- ^ Protocol (pattern) name component.
  -> proxy (c, d, h)  -- ^ Selects the cipher, DH and hash types.
  -> ScrubbedBytes
mkHandshakeName protoName _ =
  mconcat
    [ "Noise_"
    , convert protoName
    , "_"
    , dhComponent
    , "_"
    , cipherComponent
    , "_"
    , hashComponent
    ]
  where
    dhComponent     = dhName     (Proxy :: Proxy d)
    cipherComponent = cipherName (Proxy :: Proxy c)
    hashComponent   = hashName   (Proxy :: Proxy h)
-- | Constructs the initial 'HandshakeState' for a handshake.
--
--   The symmetric state is created from the handshake name (built by
--   'mkHandshakeName' from the pattern's 'hpName') and then has the
--   prologue from the options mixed in with 'mixHash'. The PSK-mode flag
--   is taken from the pattern and the message buffer starts empty.
handshakeState
  :: forall c d h. (Cipher c, DH d, Hash h)
  => HandshakeOpts d     -- ^ Options, stored unchanged in '_hsOpts'.
  -> HandshakePattern    -- ^ Pattern supplying the name and PSK-mode flag.
  -> HandshakeState c d h
handshakeState opts pat =
  HandshakeState
    { _hsSymmetricState = mixHash (opts ^. hoPrologue) initialState
    , _hsOpts           = opts
    , _hsPSKMode        = pat ^. hpPSKMode
    , _hsMsgBuffer      = mempty
    }
  where
    -- Full handshake name for the selected cipher, DH and hash types.
    protocolName = mkHandshakeName (pat ^. hpName) (Proxy :: Proxy (c, d, h))
    -- Symmetric state before the prologue has been mixed in.
    initialState = symmetricState protocolName
-- | Orphan instance: throws in the underlying monad and lifts the result
--   into the coroutine. Needed so that 'MonadThrow' can be derived for
--   'Handshake'.
instance (Functor f, MonadThrow m) => MonadThrow (Coroutine f m) where
  throwM exc = lift (throwM exc)
-- | Orphan instance: every state operation is delegated to the underlying
--   monad via 'lift'. Needed so that 'MonadState' can be derived for
--   'Handshake'.
instance (Functor f, MonadState s m) => MonadState s (Coroutine f m) where
  get            = lift get
  put newState   = lift (put newState)
  state stepFunc = lift (state stepFunc)