{-# LANGUAGE RankNTypes, ScopedTypeVariables #-}
------------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.Handshake.Interpreter
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Internal.Handshake.Interpreter where

import Control.Applicative.Free
import Control.Exception.Safe
import Control.Lens
import Control.Monad.Coroutine.SuspensionFunctors
import Data.ByteArray (splitAt)
import Data.Maybe     (isJust)
import Data.Proxy
import Prelude hiding (splitAt)

import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Exception
import Crypto.Noise.Hash
import Crypto.Noise.Internal.Handshake.Pattern hiding (ss)
import Crypto.Noise.Internal.Handshake.State
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState

-- [ E ] -----------------------------------------------------------------------

-- | Interpret a single handshake 'Token', applying its effects to the
-- handshake state and returning the continuation carried by the token.
--
-- The first argument (@opRole@) is compared with our own role ('hoRole'):
-- when they match, this side writes the token into 'hsMsgBuffer'; otherwise
-- it reads the token from 'hsMsgBuffer'. (Presumably @opRole@ is the role of
-- the party sending the current message.)
--
-- Clauses may throw 'InvalidKey' when a received public key cannot be
-- decoded, and 'StaticKeyOverwrite' when a remote static key is received but
-- one was already configured.
interpretToken :: forall c d h r. (Cipher c, DH d, Hash h)
               => HandshakeRole
               -> Token r
               -> Handshake c d h r
-- E: write our ephemeral public key, or read the peer's.
interpretToken opRole (E next) = do
  myRole  <- use $ hsOpts . hoRole
  pskMode <- use hsPSKMode

  -- Symmetric-state update for an ephemeral public key. In PSK mode the key
  -- is passed to both 'mixHash' and 'mixKey'. Because '.' composes
  -- right-to-left, 'mixHash' is applied first.
  let mixEphemeral pkBytes
        | pskMode   = mixKey pkBytes . mixHash pkBytes
        | otherwise = mixHash pkBytes

  if opRole == myRole then do
    -- Writing: append our (pre-configured) ephemeral public key.
    (_, pk) <- getKeyPair hoLocalEphemeral LocalEphemeral
    let pkBytes = dhPubToBytes pk

    hsSymmetricState %= mixEphemeral pkBytes
    hsMsgBuffer      <>= pkBytes

  else do
    -- Reading: split the peer's ephemeral public key (dhLength bytes) off
    -- the front of the incoming buffer.
    buf <- use hsMsgBuffer
    let (pkBytes, remainingBytes) = splitAt (dhLength (Proxy :: Proxy d)) buf
    theirKey <- maybe (throwM . InvalidKey $ RemoteEphemeral)
                      return
                      (dhBytesToPub pkBytes)
    hsOpts . hoRemoteEphemeral .= Just theirKey

    hsSymmetricState %= mixEphemeral pkBytes
    hsMsgBuffer      .= remainingBytes

  return next

-- [ S ] -----------------------------------------------------------------------

-- S: write our static public key (encrypted if a cipher key is already
-- established), or read the peer's.
interpretToken opRole (S next) = do
  myRole <- use $ hsOpts . hoRole

  if opRole == myRole then do
    -- Writing: encryptAndHash our static public key and append the result.
    ss        <- use hsSymmetricState
    (_, pk)   <- getKeyPair hoLocalStatic LocalStatic
    (ct, ss') <- encryptAndHash (dhPubToBytes pk) ss
    hsSymmetricState .= ss'
    hsMsgBuffer      <>= cipherTextToBytes ct

  else do
    -- A remote static key that was configured up front is never replaced by
    -- one received in a message.
    configuredRemoteStatic <- use $ hsOpts . hoRemoteStatic
    if isJust configuredRemoteStatic
      then throwM StaticKeyOverwrite
      else do
        ss  <- use hsSymmetricState
        buf <- use hsMsgBuffer

        -- If a SymmetricKey has been established, the static key will be
        -- encrypted. In that case, the number of bytes to be read off the
        -- buffer is the length of the public key plus a 16 byte
        -- authentication tag.
        let dhLen      = dhLength (Proxy :: Proxy d)
            authTagLen = 16 :: Int
            keyIsSet   = isJust (ss ^. ssCipher . csk)
            lenToRead  = if keyIsSet then dhLen + authTagLen else dhLen
            (b, rest)  = splitAt lenToRead buf

        (pk, ss') <- decryptAndHash (cipherBytesToText b) ss
        pk'       <- maybe (throwM . InvalidKey $ RemoteStatic)
                           return
                           (dhBytesToPub pk)

        hsOpts . hoRemoteStatic .= Just pk'
        hsSymmetricState        .= ss'
        hsMsgBuffer             .= rest

  return next

-- [ EE ] -----------------------------------------------------------------------

-- EE: mix DH(local ephemeral, remote ephemeral) into the symmetric state.
-- Both sides perform the same computation, so the role argument is unused.
interpretToken _ (Ee next) = do
  ~(sk, _) <- getKeyPair   hoLocalEphemeral  LocalEphemeral
  rpk      <- getPublicKey hoRemoteEphemeral RemoteEphemeral
  hsSymmetricState %= mixKey (dhPerform sk rpk)

  return next

-- [ ES ] -----------------------------------------------------------------------

-- ES: mix a DH result into the symmetric state. Which keys are combined
-- depends on our own role ('hoRole'), not on the role argument:
--   initiator: DH(local ephemeral, remote static)
--   responder: DH(local static,    remote ephemeral)
interpretToken _ (Es next) = do
  myRole <- use $ hsOpts . hoRole

  -- Key lookup order within each branch matches the original, so the
  -- same exception is raised first when several keys are missing.
  sharedSecret <-
    if myRole == InitiatorRole
      then do
        rpk      <- getPublicKey hoRemoteStatic   RemoteStatic
        ~(sk, _) <- getKeyPair   hoLocalEphemeral LocalEphemeral
        return (dhPerform sk rpk)
      else do
        ~(sk, _) <- getKeyPair   hoLocalStatic     LocalStatic
        rpk      <- getPublicKey hoRemoteEphemeral RemoteEphemeral
        return (dhPerform sk rpk)

  hsSymmetricState %= mixKey sharedSecret

  return next

-- [ SE ] -----------------------------------------------------------------------

interpretToken HandshakeRole
_ (Se r
next) = do
  HandshakeRole
myRole <- Getting HandshakeRole (HandshakeState c d h) HandshakeRole
-> Handshake c d h HandshakeRole
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
use (Getting HandshakeRole (HandshakeState c d h) HandshakeRole
 -> Handshake c d h HandshakeRole)
-> Getting HandshakeRole (HandshakeState c d h) HandshakeRole
-> Handshake c d h HandshakeRole
forall a b. (a -> b) -> a -> b
$ (HandshakeOpts d -> Const HandshakeRole (HandshakeOpts d))
-> HandshakeState c d h
-> Const HandshakeRole (HandshakeState c d h)
forall c d1 h d2 (f :: * -> *).
Functor f =>
(HandshakeOpts d1 -> f (HandshakeOpts d2))
-> HandshakeState c d1 h -> f (HandshakeState c d2 h)
hsOpts ((HandshakeOpts d -> Const HandshakeRole (HandshakeOpts d))
 -> HandshakeState c d h
 -> Const HandshakeRole (HandshakeState c d h))
-> ((HandshakeRole -> Const HandshakeRole HandshakeRole)
    -> HandshakeOpts d -> Const HandshakeRole (HandshakeOpts d))
-> Getting HandshakeRole (HandshakeState c d h) HandshakeRole
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HandshakeRole -> Const HandshakeRole HandshakeRole)
-> HandshakeOpts d -> Const HandshakeRole (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(HandshakeRole -> f HandshakeRole)
-> HandshakeOpts d -> f (HandshakeOpts d)
hoRole

  if HandshakeRole
myRole HandshakeRole -> HandshakeRole -> Bool
forall a. Eq a => a -> a -> Bool
== HandshakeRole
InitiatorRole then do
    ~(SecretKey d
sk, PublicKey d
_) <- Lens' (HandshakeOpts d) (Maybe (SecretKey d, PublicKey d))
-> ExceptionKeyType -> Handshake c d h (SecretKey d, PublicKey d)
forall d c h.
Lens' (HandshakeOpts d) (Maybe (KeyPair d))
-> ExceptionKeyType -> Handshake c d h (KeyPair d)
getKeyPair   (Maybe (SecretKey d, PublicKey d)
 -> f (Maybe (SecretKey d, PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(Maybe (KeyPair d) -> f (Maybe (KeyPair d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
Lens' (HandshakeOpts d) (Maybe (SecretKey d, PublicKey d))
hoLocalStatic     ExceptionKeyType
LocalStatic
    PublicKey d
rpk      <- Lens' (HandshakeOpts d) (Maybe (PublicKey d))
-> ExceptionKeyType -> Handshake c d h (PublicKey d)
forall d c h.
Lens' (HandshakeOpts d) (Maybe (PublicKey d))
-> ExceptionKeyType -> Handshake c d h (PublicKey d)
getPublicKey (Maybe (PublicKey d) -> f (Maybe (PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(Maybe (PublicKey d) -> f (Maybe (PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
Lens' (HandshakeOpts d) (Maybe (PublicKey d))
hoRemoteEphemeral ExceptionKeyType
RemoteEphemeral
    (SymmetricState c h -> Identity (SymmetricState c h))
-> HandshakeState c d h -> Identity (HandshakeState c d h)
forall c1 d h1 c2 h2 (f :: * -> *).
Functor f =>
(SymmetricState c1 h1 -> f (SymmetricState c2 h2))
-> HandshakeState c1 d h1 -> f (HandshakeState c2 d h2)
hsSymmetricState ((SymmetricState c h -> Identity (SymmetricState c h))
 -> HandshakeState c d h -> Identity (HandshakeState c d h))
-> (SymmetricState c h -> SymmetricState c h) -> Handshake c d h ()
forall s (m :: * -> *) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
forall c h.
(Cipher c, Hash h) =>
ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixKey (SecretKey d -> PublicKey d -> ScrubbedBytes
forall d. DH d => SecretKey d -> PublicKey d -> ScrubbedBytes
dhPerform SecretKey d
sk PublicKey d
rpk)
  else do
    PublicKey d
rpk      <- Lens' (HandshakeOpts d) (Maybe (PublicKey d))
-> ExceptionKeyType -> Handshake c d h (PublicKey d)
forall d c h.
Lens' (HandshakeOpts d) (Maybe (PublicKey d))
-> ExceptionKeyType -> Handshake c d h (PublicKey d)
getPublicKey (Maybe (PublicKey d) -> f (Maybe (PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(Maybe (PublicKey d) -> f (Maybe (PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
Lens' (HandshakeOpts d) (Maybe (PublicKey d))
hoRemoteStatic   ExceptionKeyType
RemoteStatic
    ~(SecretKey d
sk, PublicKey d
_) <- Lens' (HandshakeOpts d) (Maybe (SecretKey d, PublicKey d))
-> ExceptionKeyType -> Handshake c d h (SecretKey d, PublicKey d)
forall d c h.
Lens' (HandshakeOpts d) (Maybe (KeyPair d))
-> ExceptionKeyType -> Handshake c d h (KeyPair d)
getKeyPair   (Maybe (SecretKey d, PublicKey d)
 -> f (Maybe (SecretKey d, PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(Maybe (KeyPair d) -> f (Maybe (KeyPair d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
Lens' (HandshakeOpts d) (Maybe (SecretKey d, PublicKey d))
hoLocalEphemeral ExceptionKeyType
LocalEphemeral
    (SymmetricState c h -> Identity (SymmetricState c h))
-> HandshakeState c d h -> Identity (HandshakeState c d h)
forall c1 d h1 c2 h2 (f :: * -> *).
Functor f =>
(SymmetricState c1 h1 -> f (SymmetricState c2 h2))
-> HandshakeState c1 d h1 -> f (HandshakeState c2 d h2)
hsSymmetricState ((SymmetricState c h -> Identity (SymmetricState c h))
 -> HandshakeState c d h -> Identity (HandshakeState c d h))
-> (SymmetricState c h -> SymmetricState c h) -> Handshake c d h ()
forall s (m :: * -> *) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
forall c h.
(Cipher c, Hash h) =>
ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixKey (SecretKey d -> PublicKey d -> ScrubbedBytes
forall d. DH d => SecretKey d -> PublicKey d -> ScrubbedBytes
dhPerform SecretKey d
sk PublicKey d
rpk)

  r -> Handshake c d h r
forall a. a -> Handshake c d h a
forall (m :: * -> *) a. Monad m => a -> m a
return r
next

-- [ SS ] -----------------------------------------------------------------------

interpretToken HandshakeRole
_ (Ss r
next) = do
  ~(SecretKey d
sk, PublicKey d
_) <- Lens' (HandshakeOpts d) (Maybe (SecretKey d, PublicKey d))
-> ExceptionKeyType -> Handshake c d h (SecretKey d, PublicKey d)
forall d c h.
Lens' (HandshakeOpts d) (Maybe (KeyPair d))
-> ExceptionKeyType -> Handshake c d h (KeyPair d)
getKeyPair   (Maybe (SecretKey d, PublicKey d)
 -> f (Maybe (SecretKey d, PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(Maybe (KeyPair d) -> f (Maybe (KeyPair d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
Lens' (HandshakeOpts d) (Maybe (SecretKey d, PublicKey d))
hoLocalStatic  ExceptionKeyType
LocalStatic
  PublicKey d
rpk      <- Lens' (HandshakeOpts d) (Maybe (PublicKey d))
-> ExceptionKeyType -> Handshake c d h (PublicKey d)
forall d c h.
Lens' (HandshakeOpts d) (Maybe (PublicKey d))
-> ExceptionKeyType -> Handshake c d h (PublicKey d)
getPublicKey (Maybe (PublicKey d) -> f (Maybe (PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
forall d (f :: * -> *).
Functor f =>
(Maybe (PublicKey d) -> f (Maybe (PublicKey d)))
-> HandshakeOpts d -> f (HandshakeOpts d)
Lens' (HandshakeOpts d) (Maybe (PublicKey d))
hoRemoteStatic ExceptionKeyType
RemoteStatic
  (SymmetricState c h -> Identity (SymmetricState c h))
-> HandshakeState c d h -> Identity (HandshakeState c d h)
forall c1 d h1 c2 h2 (f :: * -> *).
Functor f =>
(SymmetricState c1 h1 -> f (SymmetricState c2 h2))
-> HandshakeState c1 d h1 -> f (HandshakeState c2 d h2)
hsSymmetricState ((SymmetricState c h -> Identity (SymmetricState c h))
 -> HandshakeState c d h -> Identity (HandshakeState c d h))
-> (SymmetricState c h -> SymmetricState c h) -> Handshake c d h ()
forall s (m :: * -> *) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
forall c h.
(Cipher c, Hash h) =>
ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixKey (SecretKey d -> PublicKey d -> ScrubbedBytes
forall d. DH d => SecretKey d -> PublicKey d -> ScrubbedBytes
dhPerform SecretKey d
sk PublicKey d
rpk)

  r -> Handshake c d h r
forall a. a -> Handshake c d h a
forall (m :: * -> *) a. Monad m => a -> m a
return r
next

-- [ PSK ] -----------------------------------------------------------------------

interpretToken HandshakeRole
_ (Psk r
next) = do
  ScrubbedBytes
input <- Coroutine
  (Request HandshakeResult ScrubbedBytes)
  (StateT (HandshakeState c d h) Catch)
  ScrubbedBytes
-> Handshake c d h ScrubbedBytes
forall c d h r.
Coroutine
  (Request HandshakeResult ScrubbedBytes)
  (StateT (HandshakeState c d h) Catch)
  r
-> Handshake c d h r
Handshake (Coroutine
   (Request HandshakeResult ScrubbedBytes)
   (StateT (HandshakeState c d h) Catch)
   ScrubbedBytes
 -> Handshake c d h ScrubbedBytes)
-> (HandshakeResult
    -> Coroutine
         (Request HandshakeResult ScrubbedBytes)
         (StateT (HandshakeState c d h) Catch)
         ScrubbedBytes)
-> HandshakeResult
-> Handshake c d h ScrubbedBytes
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> HandshakeResult
-> Coroutine
     (Request HandshakeResult ScrubbedBytes)
     (StateT (HandshakeState c d h) Catch)
     ScrubbedBytes
forall (m :: * -> *) x y.
Monad m =>
x -> Coroutine (Request x y) m y
request (HandshakeResult -> Handshake c d h ScrubbedBytes)
-> HandshakeResult -> Handshake c d h ScrubbedBytes
forall a b. (a -> b) -> a -> b
$ HandshakeResult
HandshakeResultNeedPSK
  (SymmetricState c h -> Identity (SymmetricState c h))
-> HandshakeState c d h -> Identity (HandshakeState c d h)
forall c1 d h1 c2 h2 (f :: * -> *).
Functor f =>
(SymmetricState c1 h1 -> f (SymmetricState c2 h2))
-> HandshakeState c1 d h1 -> f (HandshakeState c2 d h2)
hsSymmetricState ((SymmetricState c h -> Identity (SymmetricState c h))
 -> HandshakeState c d h -> Identity (HandshakeState c d h))
-> (SymmetricState c h -> SymmetricState c h) -> Handshake c d h ()
forall s (m :: * -> *) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
forall c h.
(Cipher c, Hash h) =>
ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixKeyAndHash ScrubbedBytes
input

  r -> Handshake c d h r
forall a. a -> Handshake c d h a
forall (m :: * -> *) a. Monad m => a -> m a
return r
next

-- | Run one (non-pre-message) 'MessagePattern' on behalf of the side
-- @opRole@.
--
-- First, the current contents of 'hsMsgBuffer' are yielded to the caller
-- as a 'HandshakeResultMessage'. The bytes the caller resumes the
-- coroutine with are then used in one of two ways:
--
-- * @opRole@ differs from our own role (we are reading a message): the
--   caller's bytes become the message buffer and the tokens are
--   interpreted. Whatever is left in the buffer afterwards is decrypted
--   with 'decryptAndHash', and the plaintext replaces the buffer.
--
-- * @opRole@ is our own role (we are writing a message): the buffer is
--   cleared and the tokens are interpreted. The caller's bytes are then
--   treated as the payload, encrypted with 'encryptAndHash', and appended
--   to the buffer.
--
-- In both cases the resulting symmetric state is stored back into
-- 'hsSymmetricState', and the value produced by the token run is
-- returned. (The token interpreters are expected to consume from / append
-- to 'hsMsgBuffer'; see 'interpretToken'.)
processMsgPattern :: (Cipher c, DH d, Hash h)
                  => HandshakeRole
                  -> MessagePattern r
                  -> Handshake c d h r
processMsgPattern opRole mp = do
  ourRole <- use (hsOpts . hoRole)
  pending <- use hsMsgBuffer
  -- Hand the pending buffer out and wait for the caller's bytes.
  fromCaller <- Handshake (request (HandshakeResultMessage pending))

  if opRole /= ourRole
    then do
      -- Reading: tokens operate on the received message.
      hsMsgBuffer .= fromCaller
      result   <- runAp (interpretToken opRole) mp
      leftover <- use hsMsgBuffer
      symState <- use hsSymmetricState
      (plaintext, symState') <-
        decryptAndHash (cipherBytesToText leftover) symState
      hsMsgBuffer      .= plaintext
      hsSymmetricState .= symState'
      return result
    else do
      -- Writing: tokens build the outgoing message from an empty buffer.
      hsMsgBuffer .= mempty
      result   <- runAp (interpretToken opRole) mp
      symState <- use hsSymmetricState
      (sealed, symState') <- encryptAndHash fromCaller symState
      hsMsgBuffer      <>= cipherTextToBytes sealed
      hsSymmetricState .= symState'
      return result

-- | Interpret a single token of a pre-message pattern belonging to the
-- side @opRole@.
--
-- Only 'E' and 'S' tokens are valid here. For either one, the matching
-- public key is mixed into the handshake hash with 'mixHash':
--
-- * if @opRole@ is our own role, the public half of our local key pair is
--   used ('hoLocalEphemeral' / 'hoLocalStatic');
--
-- * otherwise the remote public key is used ('hoRemoteEphemeral' /
--   'hoRemoteStatic').
--
-- A missing key surfaces as whatever 'getKeyPair' / 'getPublicKey' throw.
-- Any other token throws 'InvalidPattern'.
interpretPreToken :: forall c d h r. (Cipher c, DH d, Hash h)
                  => HandshakeRole
                  -> Token r
                  -> Handshake c d h r
interpretPreToken opRole token = case token of
  E next -> do
    mixPreKey (snd <$> getKeyPair hoLocalEphemeral LocalEphemeral)
              (getPublicKey hoRemoteEphemeral RemoteEphemeral)
    return next
  S next -> do
    mixPreKey (snd <$> getKeyPair hoLocalStatic LocalStatic)
              (getPublicKey hoRemoteStatic RemoteStatic)
    return next
  _ -> throwM InvalidPattern
  where
    -- Run exactly one of the two key lookups (ours when the pattern is
    -- ours, the peer's otherwise), then hash the resulting public key.
    mixPreKey :: Handshake c d h (PublicKey d)
              -> Handshake c d h (PublicKey d)
              -> Handshake c d h ()
    mixPreKey ownKey peerKey = do
      ourRole <- use (hsOpts . hoRole)
      pubKey  <- if opRole == ourRole then ownKey else peerKey
      hsSymmetricState %= mixHash (dhPubToBytes pubKey)

-- | Interpret one 'Message' of a handshake pattern's message sequence.
--
-- Pre-messages run each of their tokens through 'interpretPreToken' for
-- the owning role. Regular messages are handed to 'processMsgPattern'
-- for the sending role. In every case the message's continuation value
-- is returned once its pattern has been processed.
interpretMessage :: (Cipher c, DH d, Hash h)
                 => Message r
                 -> Handshake c d h r
interpretMessage msg = case msg of
  PreInitiator mp next ->
    runAp (interpretPreToken InitiatorRole) mp >> return next
  PreResponder mp next ->
    runAp (interpretPreToken ResponderRole) mp >> return next
  Initiator mp next ->
    processMsgPattern InitiatorRole mp >> return next
  Responder mp next ->
    processMsgPattern ResponderRole mp >> return next

-- | Execute a complete 'HandshakePattern' by interpreting every message of
-- its 'hpMsgSeq', in order, with 'interpretMessage'.
runHandshakePattern :: (Cipher c, DH d, Hash h)
                    => HandshakePattern
                    -> Handshake c d h ()
runHandshakePattern hp =
  let msgSeq = view hpMsgSeq hp
  in  runAp interpretMessage msgSeq

-- | Look up a public key held in the handshake options.
--
-- The lens @k@ selects an optional 'PublicKey' inside the 'HandshakeOpts'
-- (reached through 'hsOpts'). When that slot holds a key it is returned;
-- when it is empty, a 'NoiseException' built with @'KeyMissing' keyType@
-- is thrown via 'throwM'.
getPublicKey :: Lens' (HandshakeOpts d) (Maybe (PublicKey d))
             -> ExceptionKeyType
             -> Handshake c d h (PublicKey d)
getPublicKey k keyType =
  use (hsOpts . k) >>= \stored ->
    case stored of
      Just pk -> return pk
      Nothing -> throwM (KeyMissing keyType)

-- | Look up a key pair held in the handshake options.
--
-- The lens @k@ selects an optional 'KeyPair' inside the 'HandshakeOpts'
-- (reached through 'hsOpts'). A present key pair is returned unchanged;
-- an absent one causes a 'NoiseException' built with
-- @'KeyMissing' keyType@ to be thrown via 'throwM'.
getKeyPair :: Lens' (HandshakeOpts d) (Maybe (KeyPair d))
           -> ExceptionKeyType
           -> Handshake c d h (KeyPair d)
getKeyPair k keyType = do
  slot <- use (hsOpts . k)
  case slot of
    Nothing -> throwM (KeyMissing keyType)
    Just kp -> return kp