{-# LANGUAGE TemplateHaskell, ScopedTypeVariables, GeneralizedNewtypeDeriving,
             FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.Handshake.State
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Internal.Handshake.State where

import Control.Lens
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Catch.Pure
import Control.Monad.State       (MonadState(..), StateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.ByteArray            (ScrubbedBytes, convert)
import Data.ByteString           (ByteString)
import Data.Monoid               ((<>))
import Data.Proxy

import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.Handshake.Pattern hiding (ss)
import Crypto.Noise.Internal.SymmetricState

-- | Represents the side of the conversation upon which a party resides.
data HandshakeRole = InitiatorRole  -- ^ The party that sends the first message.
                   | ResponderRole  -- ^ The party that answers the initiator.
                   deriving (Show, Eq)

-- | Represents the various options and keys for a handshake parameterized by
--   the 'DH' method.
--
--   Lenses (@hoRole@, @hoPrologue@, @hoLocalEphemeral@, ...) are generated for
--   every field by the 'makeLenses' splice below.
data HandshakeOpts d =
  HandshakeOpts { _hoRole                :: HandshakeRole
                  -- ^ Which side of the conversation this party is on.
                , _hoPrologue            :: Plaintext
                  -- ^ Prologue data; 'handshakeState' mixes it into the
                  --   initial symmetric state via 'mixHash'.
                , _hoLocalEphemeral      :: Maybe (KeyPair d)
                  -- ^ Local ephemeral key pair, or 'Nothing' if not supplied.
                , _hoLocalStatic         :: Maybe (KeyPair d)
                  -- ^ Local static key pair, or 'Nothing' if not supplied.
                , _hoRemoteEphemeral     :: Maybe (PublicKey d)
                  -- ^ Peer's ephemeral public key, or 'Nothing' if unknown.
                , _hoRemoteStatic        :: Maybe (PublicKey d)
                  -- ^ Peer's static public key, or 'Nothing' if unknown.
                }

$(makeLenses ''HandshakeOpts)

-- | Holds all state associated with the interpreter.
--
--   Lenses (@hsSymmetricState@, @hsOpts@, @hsPSKMode@, @hsMsgBuffer@) are
--   generated for every field by the 'makeLenses' splice below.
data HandshakeState c d h =
  HandshakeState { _hsSymmetricState :: SymmetricState c h
                   -- ^ Symmetric state parameterized by the cipher and hash.
                 , _hsOpts           :: HandshakeOpts d
                   -- ^ Role, prologue and keys this handshake was built with.
                 , _hsPSKMode        :: Bool
                   -- ^ PSK-mode flag; 'handshakeState' copies it from the
                   --   pattern's 'hpPSKMode'.
                 , _hsMsgBuffer      :: ScrubbedBytes
                   -- ^ Message buffer; 'handshakeState' initializes it to
                   --   'mempty'.
                 }

$(makeLenses ''HandshakeState)

-- | Value the handshake coroutine suspends with whenever it needs more data
--   from the caller before it can continue.
data HandshakeResult
  = HandshakeResultMessage ScrubbedBytes
    -- ^ Suspension carrying a payload (presumably a handshake message for the
    --   caller to transmit — confirm against the interpreter).
  | HandshakeResultNeedPSK
    -- ^ Suspension with no payload (presumably asking the caller for a
    --   pre-shared key — confirm against the interpreter).

-- | All HandshakePattern interpreters run within this Monad.
--
--   It is a coroutine whose suspensions are 'Request's: it yields a
--   'HandshakeResult' and is resumed with 'ScrubbedBytes'. It runs over a
--   'StateT' carrying the 'HandshakeState', on top of the pure 'Catch'
--   exception monad. 'MonadThrow' and 'MonadState' are derived through the
--   'Coroutine' instances defined at the bottom of this module.
newtype Handshake c d h r =
  Handshake { runHandshake :: Coroutine (Request HandshakeResult ScrubbedBytes)
                                        (StateT (HandshakeState c d h) Catch)
                                        r
            } deriving ( Functor
                       , Applicative
                       , Monad
                       , MonadThrow
                       , MonadState (HandshakeState c d h)
                       )

-- | @defaultHandshakeOpts role prologue@ returns a default set of handshake
--   options for the given role and prologue. All keys are set to 'Nothing';
--   use 'setLocalEphemeral', 'setLocalStatic', 'setRemoteEphemeral' and
--   'setRemoteStatic' to fill them in.
defaultHandshakeOpts :: HandshakeRole
                     -> Plaintext
                     -> HandshakeOpts d
defaultHandshakeOpts r p =
  HandshakeOpts { _hoRole            = r
                , _hoPrologue        = p
                , _hoLocalEphemeral  = Nothing
                , _hoLocalStatic     = Nothing
                , _hoRemoteEphemeral = Nothing
                , _hoRemoteStatic    = Nothing
                }

-- | Sets the local ephemeral key pair, replacing any previous value.
--   Passing 'Nothing' clears it. All other options are left unchanged.
setLocalEphemeral :: Maybe (KeyPair d)
                  -> HandshakeOpts d
                  -> HandshakeOpts d
setLocalEphemeral k opts = opts { _hoLocalEphemeral = k }

-- | Sets the local static key pair, replacing any previous value.
--   Passing 'Nothing' clears it. All other options are left unchanged.
setLocalStatic :: Maybe (KeyPair d)
               -> HandshakeOpts d
               -> HandshakeOpts d
setLocalStatic k opts = opts { _hoLocalStatic = k }

-- | Sets the remote party's ephemeral public key (rarely needed), replacing
--   any previous value. Passing 'Nothing' clears it. All other options are
--   left unchanged.
setRemoteEphemeral :: Maybe (PublicKey d)
                   -> HandshakeOpts d
                   -> HandshakeOpts d
setRemoteEphemeral k opts = opts { _hoRemoteEphemeral = k }

-- | Sets the remote party's static public key, replacing any previous value.
--   Passing 'Nothing' clears it. All other options are left unchanged.
setRemoteStatic :: Maybe (PublicKey d)
                -> HandshakeOpts d
                -> HandshakeOpts d
setRemoteStatic k opts = opts { _hoRemoteStatic = k }

-- | Given a handshake pattern name (e.g. @"XX"@), returns the full protocol
--   name following the naming rules in section 8 of the Noise Protocol
--   Framework specification:
--
--   @Noise_\<pattern\>_\<DH\>_\<cipher\>_\<hash\>@
--
--   The proxy argument is only used to select the cipher, DH and hash types;
--   its value is never inspected.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName protoName _ =
  "Noise_" <> convert protoName <> "_" <> d <> "_" <> c <> "_" <> h
  where
    -- The DH name precedes the cipher name, as the specification requires.
    c = cipherName (Proxy :: Proxy c)
    d = dhName     (Proxy :: Proxy d)
    h = hashName   (Proxy :: Proxy h)

-- | Constructs the initial 'HandshakeState' from a set of options and a
--   'HandshakePattern'.
--
--   The symmetric state is seeded with the full protocol name built by
--   'mkHandshakeName' from the pattern's name and the @c@, @d@, @h@ types.
--   The prologue from the options is then mixed into it with 'mixHash'. The
--   PSK-mode flag is copied from the pattern, and the message buffer starts
--   empty.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakePattern
               -> HandshakeState c d h
handshakeState ho hp =
  HandshakeState { _hsSymmetricState = ss'
                 , _hsOpts           = ho
                 , _hsPSKMode        = hp ^. hpPSKMode
                 , _hsMsgBuffer      = mempty
                 }
  where
    -- Symmetric state initialized from the full protocol name.
    ss  = symmetricState $ mkHandshakeName (hp ^. hpName)
                                           (Proxy :: Proxy (c, d, h))
    -- Same state after absorbing the prologue.
    ss' = mixHash (ho ^. hoPrologue) ss

-- | Orphan instance (warning suppressed by @-fno-warn-orphans@): exceptions
--   thrown inside a 'Coroutine' are thrown in the underlying monad. This
--   instance lets 'Handshake' derive 'MonadThrow'.
instance (Functor f, MonadThrow m) => MonadThrow (Coroutine f m) where
  throwM = lift . throwM

-- | Orphan instance (warning suppressed by @-fno-warn-orphans@): state
--   operations inside a 'Coroutine' are lifted to the underlying monad's
--   state. This instance lets 'Handshake' derive
--   @MonadState (HandshakeState c d h)@. It needs @UndecidableInstances@
--   because the context mentions @s@ without a smaller head.
instance (Functor f, MonadState s m) => MonadState s (Coroutine f m) where
  get   = lift get
  put   = lift . put
  state = lift . state