{-# LANGUAGE TemplateHaskell, ScopedTypeVariables #-}
-------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.NoiseState
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Internal.NoiseState where

import Control.Lens
import Control.Monad.Catch.Pure
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.State
import Data.ByteArray (ScrubbedBytes)

import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.Handshake.Interpreter
import Crypto.Noise.Internal.Handshake.Pattern (HandshakePattern)
import Crypto.Noise.Internal.Handshake.State
import Crypto.Noise.Internal.SymmetricState (split)

-- | The state of an entire Noise conversation. The same value is threaded
--   through the handshake and through every message read or written once
--   the handshake is complete (transport messages). The type parameters
--   select the 'Cipher', the 'DH' method, and the 'Hash' in use.
data NoiseState c d h = NoiseState
  { _nsHandshakeState       :: HandshakeState c d h
    -- ^ Current handshake state; replaced by the interpreter's updated
    --   state every time the handshake is resumed.
  , _nsHandshakePattern     :: HandshakePattern
    -- ^ Pattern that is run when the handshake is started.
  , _nsHandshakeSuspension  :: Maybe (ScrubbedBytes -> Handshake c d h ())
    -- ^ Continuation saved when the handshake suspends waiting for input;
    --   'Nothing' until the handshake has been started.
  , _nsSendingCipherState   :: Maybe (CipherState c)
    -- ^ Sending 'CipherState'; set once the handshake pattern has finished.
  , _nsReceivingCipherState :: Maybe (CipherState c)
    -- ^ Receiving 'CipherState'; set once the handshake pattern has finished.
  }

-- Generate the lenses (nsHandshakeState, nsHandshakePattern,
-- nsHandshakeSuspension, nsSendingCipherState, nsReceivingCipherState)
-- used throughout this module.
makeLenses ''NoiseState

-- | Builds the initial 'NoiseState' for a conversation from the given
--   handshake options and handshake pattern.
--
--   The embedded 'HandshakeState' is produced by 'handshakeState'. No
--   suspension is stored yet, because the handshake has not been started.
--   Neither 'CipherState' exists yet.
noiseState
  :: (Cipher c, DH d, Hash h)
  => HandshakeOpts d
  -> HandshakePattern
  -> NoiseState c d h
noiseState opts pat = NoiseState
  { _nsHandshakeState       = handshakeState opts pat
  , _nsHandshakePattern     = pat
  , _nsHandshakeSuspension  = Nothing
  , _nsSendingCipherState   = Nothing
  , _nsReceivingCipherState = Nothing
  }

-- | Resumes a handshake in progress using the given input data.
--
--   If no suspension has been stored yet (the handshake has not been
--   started), the handshake pattern is first run up to its first suspension
--   point. The value yielded there is discarded, and the handshake is then
--   resumed with @msg@. If the pattern instead runs to completion without
--   ever suspending, its result is returned directly rather than restarting
--   the pattern forever.
--
--   Each run of the interpreter ends in one of two ways:
--
--   * The handshake suspends. The continuation and the updated
--     'HandshakeState' are stored, and the yielded 'HandshakeResult' is
--     returned.
--
--   * The handshake finishes. The final 'SymmetricState' is split into the
--     sending and receiving 'CipherState's, assigned according to the
--     'HandshakeRole'. The contents of the message buffer are returned as a
--     'HandshakeResultMessage'.
--
--   Any exception raised by the interpreter is rethrown with 'throwM'.
resumeHandshake :: forall m c d h. (MonadThrow m, Cipher c, DH d, Hash h)
                => ScrubbedBytes
                -> NoiseState c d h
                -> m (HandshakeResult, NoiseState c d h)
resumeHandshake msg ns = case ns ^. nsHandshakeSuspension of
  Nothing -> do
    (res, ns') <- runInterpreter . runHandshakePattern $ ns ^. nsHandshakePattern
    case ns' ^. nsHandshakeSuspension of
      -- The pattern finished without ever suspending. Recursing would start
      -- it again from the beginning, forever, so hand back its result.
      Nothing -> return (res, ns')
      -- The pattern is now waiting for input; feed it the caller's data.
      Just _  -> resumeHandshake msg ns'

  Just s -> runInterpreter . s $ msg

  where
    -- Runs the given handshake computation against the HandshakeState stored
    -- in @ns@ until it either suspends or finishes.
    runInterpreter :: Handshake c d h () -> m (HandshakeResult, NoiseState c d h)
    runInterpreter i =
      case runCatch (runStateT (resume (runHandshake i)) (ns ^. nsHandshakeState)) of
        -- The interpreter threw an exception. Propagate it up the chain.
        Left err -> throwM err

        -- The handshake pattern has not finished running. Save the suspension
        -- and the mutated HandshakeState and return what was yielded.
        Right (Left (Request req resp), hs) -> do
          let ns' = ns & nsHandshakeSuspension ?~ (Handshake . resp)
                       & nsHandshakeState      .~ hs
          return (req, ns')

        -- The handshake pattern has finished running. Create the CipherStates.
        Right (Right _, hs) -> do
          let (cs1, cs2) = split (hs ^. hsSymmetricState)
              -- The initiator sends with the first CipherState and receives
              -- with the second; the responder does the opposite.
              (sendCs, recvCs) = if hs ^. hsOpts . hoRole == InitiatorRole
                                   then (cs1, cs2)
                                   else (cs2, cs1)
              ns' = ns & nsSendingCipherState   ?~ sendCs
                       & nsReceivingCipherState ?~ recvCs
                       & nsHandshakeState       .~ hs
          return (HandshakeResultMessage (hs ^. hsMsgBuffer), ns')