{-# LANGUAGE TemplateHaskell, ScopedTypeVariables, GeneralizedNewtypeDeriving,
FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Crypto.Noise.Internal.Handshake.State where
import Control.Lens
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Catch.Pure
import Control.Monad.State (MonadState(..), StateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.ByteArray (ScrubbedBytes, convert)
import Data.ByteString (ByteString)
import Data.Monoid ((<>))
import Data.Proxy
import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.Handshake.Pattern hiding (ss)
import Crypto.Noise.Internal.SymmetricState
-- | The role a party plays in a handshake: either the initiator or the
--   responder.
data HandshakeRole
  = InitiatorRole
  | ResponderRole
  deriving (Show, Eq)
-- | Options supplied by the caller when setting up a handshake. The type
--   parameter @d@ selects the DH type used for the key fields.
--
--   Lenses for every field (@hoRole@, @hoPrologue@, ...) are generated by the
--   'makeLenses' splice that follows the declaration.
data HandshakeOpts d =
  HandshakeOpts { _hoRole            :: HandshakeRole
                  -- ^ Role of the local party.
                , _hoPrologue        :: Plaintext
                  -- ^ Prologue data; mixed into the symmetric state's hash by
                  --   'handshakeState'.
                , _hoLocalEphemeral  :: Maybe (KeyPair d)
                  -- ^ Local ephemeral key pair, if any.
                , _hoLocalStatic     :: Maybe (KeyPair d)
                  -- ^ Local static key pair, if any.
                , _hoRemoteEphemeral :: Maybe (PublicKey d)
                  -- ^ Remote party's ephemeral public key, if any.
                , _hoRemoteStatic    :: Maybe (PublicKey d)
                  -- ^ Remote party's static public key, if any.
                }

$(makeLenses ''HandshakeOpts)
-- | State threaded through a running handshake. The type parameters are the
--   cipher (@c@), DH (@d@) and hash (@h@) types.
--
--   Lenses for every field (@hsSymmetricState@, @hsOpts@, ...) are generated
--   by the 'makeLenses' splice that follows the declaration.
data HandshakeState c d h =
  HandshakeState { _hsSymmetricState :: SymmetricState c h
                   -- ^ Current symmetric state.
                 , _hsOpts           :: HandshakeOpts d
                   -- ^ Options the handshake was created with.
                 , _hsPSKMode        :: Bool
                   -- ^ Copied from the pattern's @hpPSKMode@ by
                   --   'handshakeState'.
                 , _hsMsgBuffer      :: ScrubbedBytes
                   -- ^ Message buffer; starts out empty.
                 }

$(makeLenses ''HandshakeState)
-- | Value carried by the 'Request' suspension of the 'Handshake' coroutine.
data HandshakeResult
  = HandshakeResultMessage ScrubbedBytes
    -- ^ A message payload.
  | HandshakeResultNeedPSK
    -- ^ NOTE(review): presumably signals that a pre-shared key must be
    --   supplied before the handshake can continue -- confirm against the
    --   interpreter.
-- | The monad in which handshakes run.
--
--   It wraps a 'Coroutine' whose suspension functor is
--   @'Request' 'HandshakeResult' 'ScrubbedBytes'@, layered over
--   @'StateT' ('HandshakeState' c d h) 'Catch'@.
--
--   All instances are obtained through newtype deriving. The 'MonadThrow' and
--   'MonadState' instances rely on the orphan 'Coroutine' instances defined
--   at the bottom of this module.
newtype Handshake c d h r =
  Handshake { runHandshake :: Coroutine (Request HandshakeResult ScrubbedBytes)
                                        (StateT (HandshakeState c d h) Catch)
                                        r
            } deriving ( Functor
                       , Applicative
                       , Monad
                       , MonadThrow
                       , MonadState (HandshakeState c d h)
                       )
-- | @defaultHandshakeOpts role prologue@ builds a 'HandshakeOpts' with the
--   given role and prologue and with every key field set to 'Nothing'.
--   Use the @set*@ functions in this module to fill in keys afterwards.
defaultHandshakeOpts :: HandshakeRole
                     -> Plaintext
                     -> HandshakeOpts d
defaultHandshakeOpts r p =
  HandshakeOpts { _hoRole            = r
                , _hoPrologue        = p
                , _hoLocalEphemeral  = Nothing
                , _hoLocalStatic     = Nothing
                , _hoRemoteEphemeral = Nothing
                , _hoRemoteStatic    = Nothing
                }
-- | Replaces the local ephemeral key pair in the given options. Passing
--   'Nothing' clears it.
setLocalEphemeral :: Maybe (KeyPair d)
                  -> HandshakeOpts d
                  -> HandshakeOpts d
setLocalEphemeral k opts = opts { _hoLocalEphemeral = k }
-- | Replaces the local static key pair in the given options. Passing
--   'Nothing' clears it.
setLocalStatic :: Maybe (KeyPair d)
               -> HandshakeOpts d
               -> HandshakeOpts d
setLocalStatic k opts = opts { _hoLocalStatic = k }
-- | Replaces the remote ephemeral public key in the given options. Passing
--   'Nothing' clears it.
setRemoteEphemeral :: Maybe (PublicKey d)
                   -> HandshakeOpts d
                   -> HandshakeOpts d
setRemoteEphemeral k opts = opts { _hoRemoteEphemeral = k }
-- | Replaces the remote static public key in the given options. Passing
--   'Nothing' clears it.
setRemoteStatic :: Maybe (PublicKey d)
                -> HandshakeOpts d
                -> HandshakeOpts d
setRemoteStatic k opts = opts { _hoRemoteStatic = k }
-- | Builds the full handshake name from a pattern name and the cipher, DH
--   and hash types selected by the proxy argument.
--
--   The result has the shape @Noise_\<pattern\>_\<dh\>_\<cipher\>_\<hash\>@,
--   where the last three components come from 'dhName', 'cipherName' and
--   'hashName' respectively. The proxy value itself is never inspected; only
--   its type is used.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName protoName _ =
  mconcat [ "Noise_"
          , convert protoName
          , "_"
          , dhPart
          , "_"
          , cipherPart
          , "_"
          , hashPart
          ]
  where
    -- ScopedTypeVariables lets these proxies refer to the c, d and h bound
    -- by the explicit forall in the signature.
    cipherPart = cipherName (Proxy :: Proxy c)
    dhPart     = dhName     (Proxy :: Proxy d)
    hashPart   = hashName   (Proxy :: Proxy h)
-- | Builds the initial 'HandshakeState' from a set of options and a
--   handshake pattern.
--
--   The symmetric state is created by 'symmetricState' from the handshake
--   name (see 'mkHandshakeName'), after which the prologue from the options
--   is folded in with 'mixHash'. The PSK flag is taken from the pattern and
--   the message buffer starts out empty.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakePattern
               -> HandshakeState c d h
handshakeState ho hp =
  HandshakeState { _hsSymmetricState = withPrologue
                 , _hsOpts           = ho
                 , _hsPSKMode        = hp ^. hpPSKMode
                 , _hsMsgBuffer      = mempty
                 }
  where
    -- Full handshake name for this pattern and the chosen c, d, h types.
    name         = mkHandshakeName (hp ^. hpName) (Proxy :: Proxy (c, d, h))
    -- Fresh symmetric state derived from that name.
    fresh        = symmetricState name
    -- The prologue is mixed in before any handshake message is processed.
    withPrologue = mixHash (ho ^. hoPrologue) fresh
-- | Orphan instance: throwing inside a 'Coroutine' delegates to the
--   underlying monad's 'throwM' via 'lift'.
instance (Functor f, MonadThrow m) => MonadThrow (Coroutine f m) where
  throwM = lift . throwM
-- | Orphan instance: every state operation inside a 'Coroutine' is lifted
--   from the underlying monad unchanged.
instance (Functor f, MonadState s m) => MonadState s (Coroutine f m) where
  get   = lift get
  put   = lift . put
  state = lift . state