module Crypto.Noise.Internal.HandshakeState
(
MonadHandshake(..),
MessagePattern,
MessagePatternIO,
HandshakePattern(HandshakePattern),
HandshakeState,
runMessagePatternT,
getLocalStaticKey,
getLocalEphemeralKey,
getRemoteStaticKey,
getRemoteEphemeralKey,
handshakeState,
writeMessage,
readMessage,
writeMessageFinal,
readMessageFinal,
encryptPayload,
decryptPayload
) where
import Control.Exception (throw)
import Control.Lens hiding (re)
import Control.Monad.State
import Data.ByteString (ByteString)
import qualified Data.ByteString as B (append, splitAt)
import Data.Maybe (fromMaybe, isJust)
import Data.Proxy
import Crypto.Noise.Cipher
import Crypto.Noise.Curve
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Types
-- | Monad transformer in which message patterns run: a 'StateT' layer that
-- threads the 'HandshakeState' over an arbitrary base monad @m@.
--
-- The type parameters @c@, @d@ and @h@ are the ones constrained elsewhere in
-- this module by 'Cipher', 'Curve' and 'Hash' respectively.
newtype MessagePatternT c d h m a = MessagePatternT { unMP :: StateT (HandshakeState c d h) m a }
deriving (Functor, Applicative, Monad, MonadIO, MonadState(HandshakeState c d h))
-- | Runs a message pattern against the given handshake state, yielding the
-- pattern's result paired with the updated state in the base monad.
runMessagePatternT :: Monad m
                   => MessagePatternT c d h m a
                   -> HandshakeState c d h
                   -> m (a, HandshakeState c d h)
runMessagePatternT mp = runStateT (unMP mp)
-- | A message pattern with no underlying effects (base monad 'Identity').
-- Used in this module for the pre-message and read patterns.
type MessagePattern c d h a = MessagePatternT c d h Identity a
-- | A message pattern that may perform 'IO'. Used in this module for the
-- write patterns, where 'tokenWE' generates a key pair via 'liftIO'.
type MessagePatternIO c d h a = MessagePatternT c d h IO a
-- | Description of a handshake: its name plus the message patterns that
-- remain to be executed.
data HandshakePattern c d h =
HandshakePattern { _hpName :: ByteString
-- ^ Pattern name; 'mkHPN' embeds it in the protocol name.
, _hpPreMsg :: Maybe (MessagePattern c d h ())
-- ^ Optional pre-message pattern; 'handshakeState' runs it once while
-- building the initial state.
, _hpWriteMsg :: [MessagePatternIO c d h ByteString]
-- ^ Outbound message patterns still to be run; 'writeMessage' runs the head
-- and stores the tail back.
, _hpReadMsg :: [ByteString -> MessagePattern c d h ByteString]
-- ^ Inbound message patterns still to be run; 'readMessage' applies the head
-- to the received bytes and stores the tail back.
}
-- | Everything tracked while a handshake is in progress.
data HandshakeState c d h =
HandshakeState { _hssSymmetricState :: SymmetricState c h
-- ^ Symmetric state updated by every token and by each message payload.
, _hssHandshakePattern :: HandshakePattern c d h
-- ^ The pattern being executed, including its remaining messages.
, _hssLocalStaticKey :: Maybe (KeyPair d)
-- ^ Local static key pair, if one was supplied to 'handshakeState'.
, _hssLocalEphemeralKey :: Maybe (KeyPair d)
-- ^ Local ephemeral key pair; supplied up front or set by 'tokenWE'.
, _hssRemoteStaticKey :: Maybe (PublicKey d)
-- ^ Remote static public key; supplied up front or set by 'tokenRS'.
, _hssRemoteEphemeralKey :: Maybe (PublicKey d)
-- ^ Remote ephemeral public key; supplied up front or set by 'tokenRE'.
}
-- Template Haskell: generate a lens for every underscore-prefixed field above.
-- The rest of this module uses them without the underscore (e.g. @hpName@,
-- @hpPreMsg@, @hssSymmetricState@, @hssRemoteStaticKey@).
$(makeLenses ''HandshakePattern)
$(makeLenses ''HandshakeState)
-- | The tokens from which message patterns are assembled. The per-method
-- notes describe what the 'MessagePatternT' instance in this module does.
class Monad m => MonadHandshake m where
-- | Pre-message token: hashes the local static public key into the
-- symmetric state.
tokenPreLS :: m ()
-- | Pre-message token: hashes the remote static public key into the
-- symmetric state.
tokenPreRS :: m ()
-- | Pre-message token: hashes the local ephemeral public key into the
-- symmetric state.
tokenPreLE :: m ()
-- | Pre-message token: hashes the remote ephemeral public key into the
-- symmetric state.
tokenPreRE :: m ()
-- | Reads the remote ephemeral public key from the front of the buffer and
-- returns the unconsumed remainder.
tokenRE :: ByteString -> m ByteString
-- | Reads the remote static public key (decrypting it through the symmetric
-- state) from the front of the buffer and returns the unconsumed remainder.
tokenRS :: ByteString -> m ByteString
-- | Generates a fresh local ephemeral key pair (hence 'MonadIO') and returns
-- the public key bytes to be sent.
tokenWE :: MonadIO m => m ByteString
-- | Returns the local static public key, passed through the symmetric
-- state's encrypt-and-hash step, as bytes to be sent.
tokenWS :: m ByteString
-- | Mixes DH(local ephemeral, remote ephemeral) into the symmetric state.
tokenDHEE :: m ()
-- | Mixes DH(local ephemeral, remote static) into the symmetric state.
tokenDHES :: m ()
-- | Mixes DH(local static, remote ephemeral) into the symmetric state.
tokenDHSE :: m ()
-- | Mixes DH(local static, remote static) into the symmetric state.
tokenDHSS :: m ()
-- | Token implementations for 'MessagePatternT'. Every token reads the
-- current 'HandshakeState', derives an updated symmetric state (and, for the
-- read/write key tokens, a key field) and stores the result back.
instance (Monad m, Cipher c, Curve d, Hash h) => MonadHandshake (MessagePatternT c d h m) where
  -- Pre-message tokens only hash a public key that is already known.
  tokenPreLS = tokenPreLX hssLocalStaticKey
  tokenPreRS = tokenPreRX hssRemoteStaticKey
  tokenPreLE = tokenPreLX hssLocalEphemeralKey
  tokenPreRE = tokenPreRX hssRemoteEphemeralKey

  -- Take the first 'curveLength' bytes of the buffer as the remote ephemeral
  -- public key, hash them, and hand back whatever follows.
  tokenRE buf = do
    hs <- get

    let (b, rest) = B.splitAt (curveLength (Proxy :: Proxy d)) buf
        reBytes   = convert b
        ss        = hs ^. hssSymmetricState
        ss'       = mixHash reBytes ss
        -- The PSK flag is read from the post-hash state, exactly as 'tokenWE'
        -- does, so the read and write sides of the "e" token stay in step.
        ss''      = if ss' ^. ssHasPSK then mixKey reBytes ss' else ss'

    put $ hs & hssRemoteEphemeralKey .~ Just (curveBytesToPub reBytes)
             & hssSymmetricState     .~ ss''

    return rest

  -- Take the remote static public key from the front of the buffer,
  -- decrypting it through the symmetric state. Refuses to replace a remote
  -- static key that is already set.
  tokenRS buf = do
    hs <- get

    if isJust (hs ^. hssRemoteStaticKey)
      then throw . HandshakeStateFailure $ "unable to overwrite remote static key"
      else do
        let hasKey    = hs ^. hssSymmetricState . ssHasKey
            (b, rest) = B.splitAt (fieldLen hasKey) buf
            ct        = cipherBytesToText . convert $ b
            ss        = hs ^. hssSymmetricState
            (Plaintext pt, ss') = decryptAndHash ct ss

        put $ hs & hssRemoteStaticKey .~ Just (curveBytesToPub pt)
                 & hssSymmetricState  .~ ss'

        return rest
    where
      keyLen = curveLength (Proxy :: Proxy d)
      -- Once the symmetric state holds a key the field is 16 bytes longer
      -- (presumably the AEAD authentication tag -- TODO confirm against the
      -- 'Cipher' implementations).
      fieldLen hasKey
        | hasKey    = keyLen + 16
        | otherwise = keyLen

  -- Generate a fresh ephemeral key pair, remember it, hash the public half
  -- (and mix it into the key when a PSK is in use), and return its bytes.
  tokenWE = do
    ~kp@(_, pk) <- liftIO curveGenKey
    hs <- get

    let pk'  = curvePubToBytes pk
        ss   = hs ^. hssSymmetricState
        ss'  = mixHash pk' ss
        ss'' = if ss' ^. ssHasPSK then mixKey pk' ss' else ss'

    put $ hs & hssLocalEphemeralKey .~ Just kp
             & hssSymmetricState    .~ ss''

    return . convert $ pk'

  -- Pass the local static public key through encrypt-and-hash and return the
  -- resulting bytes. Throws (via 'getLocalStaticKey') if no key is set.
  tokenWS = do
    hs <- get

    let pk        = curvePubToBytes . snd . getLocalStaticKey $ hs
        ss        = hs ^. hssSymmetricState
        (ct, ss') = encryptAndHash ((Plaintext . convert) pk) ss

    put $ hs & hssSymmetricState .~ ss'

    return . convert $ ct

  -- The four DH tokens differ only in which local/remote keys they combine.
  tokenDHEE = tokenDHXX getLocalEphemeralKey getRemoteEphemeralKey
  tokenDHES = tokenDHXX getLocalEphemeralKey getRemoteStaticKey
  tokenDHSE = tokenDHXX getLocalStaticKey    getRemoteEphemeralKey
  tokenDHSS = tokenDHXX getLocalStaticKey    getRemoteStaticKey

-- | Shared body of the DH tokens: runs 'curveDH' on the secret half of the
-- selected local key pair and the selected remote public key, then mixes the
-- result into the symmetric state with 'mixKey'. The selectors throw
-- 'HandshakeStateFailure' when the key they look up is not set.
tokenDHXX :: (MonadState (HandshakeState c d h) m, Cipher c, Curve d, Hash h)
          => (HandshakeState c d h -> KeyPair d)
          -> (HandshakeState c d h -> PublicKey d)
          -> m ()
tokenDHXX localKey remoteKey = do
  hs <- get

  let (sk, _) = localKey hs
      dh      = curveDH sk (remoteKey hs)

  put $ hs & hssSymmetricState %~ mixKey dh
-- | Returns the local static key pair, throwing 'HandshakeStateFailure' if it
-- has not been set.
getLocalStaticKey :: Curve d => HandshakeState c d h -> KeyPair d
getLocalStaticKey hs =
  case hs ^. hssLocalStaticKey of
    Just keyPair -> keyPair
    Nothing      -> throw (HandshakeStateFailure "local static key not set")
-- | Returns the local ephemeral key pair, throwing 'HandshakeStateFailure' if
-- it has not been set.
getLocalEphemeralKey :: Curve d => HandshakeState c d h -> KeyPair d
getLocalEphemeralKey hs =
  case hs ^. hssLocalEphemeralKey of
    Just keyPair -> keyPair
    Nothing      -> throw (HandshakeStateFailure "local ephemeral key not set")
-- | Returns the remote static public key, throwing 'HandshakeStateFailure' if
-- it has not been set.
getRemoteStaticKey :: Curve d => HandshakeState c d h -> PublicKey d
getRemoteStaticKey hs =
  case hs ^. hssRemoteStaticKey of
    Just pubKey -> pubKey
    Nothing     -> throw (HandshakeStateFailure "remote static key not set")
-- | Returns the remote ephemeral public key, throwing 'HandshakeStateFailure'
-- if it has not been set.
getRemoteEphemeralKey :: Curve d => HandshakeState c d h -> PublicKey d
getRemoteEphemeralKey hs =
  case hs ^. hssRemoteEphemeralKey of
    Just pubKey -> pubKey
    Nothing     -> throw (HandshakeStateFailure "remote ephemeral key not set")
-- | Shared body of the local pre-message tokens: looks up a local key pair
-- through the supplied lens and hashes its public half into the symmetric
-- state. Throws 'HandshakeStateFailure' when the key is not set.
tokenPreLX :: (MonadState (HandshakeState c d h) m, Cipher c, Curve d, Hash h)
           => Lens' (HandshakeState c d h) (Maybe (KeyPair d))
           -> m ()
tokenPreLX keyToView = modify $ \hs ->
  let missing   = throw (HandshakeStateFailure "tokenPreLX: local key not set")
      publicKey = snd (fromMaybe missing (hs ^. keyToView))
  in  hs & hssSymmetricState %~ mixHash (curvePubToBytes publicKey)
-- | Shared body of the remote pre-message tokens: looks up a remote public
-- key through the supplied lens and hashes it into the symmetric state.
-- Throws 'HandshakeStateFailure' when the key is not set.
tokenPreRX :: (MonadState (HandshakeState c d h) m, Cipher c, Curve d, Hash h)
           => Lens' (HandshakeState c d h) (Maybe (PublicKey d))
           -> m ()
tokenPreRX keyToView = modify $ \hs ->
  let missing   = throw (HandshakeStateFailure "tokenPreRX: remote key not set")
      publicKey = fromMaybe missing (hs ^. keyToView)
  in  hs & hssSymmetricState %~ mixHash (curvePubToBytes publicKey)
-- | Builds the initial 'HandshakeState'.
--
-- Arguments, in order: the handshake pattern, the prologue, an optional
-- pre-shared key, the local static and local ephemeral key pairs, and the
-- remote static and remote ephemeral public keys.
--
-- The symmetric state is seeded from the protocol name produced by 'mkHPN',
-- the prologue is hashed in, the PSK (when given) is mixed in after it, and
-- finally the pattern's pre-message, if any, is run.
handshakeState :: forall c d h. (Cipher c, Curve d, Hash h)
               => HandshakePattern c d h
               -> Plaintext
               -> Maybe Plaintext
               -> Maybe (KeyPair d)
               -> Maybe (KeyPair d)
               -> Maybe (PublicKey d)
               -> Maybe (PublicKey d)
               -> HandshakeState c d h
handshakeState hp (Plaintext pro) mpsk ls le rs re =
  maybe seeded runPreMsg (hp ^. hpPreMsg)
  where
    -- The protocol name depends on whether a PSK is in use.
    initial      = HandshakeState (symmetricState (mkHPN hp (isJust mpsk))) hp ls le rs re
    withPrologue = initial & hssSymmetricState %~ mixHash pro
    seeded       = case mpsk of
      Just (Plaintext psk) -> withPrologue & hssSymmetricState %~ mixPSK psk
      Nothing              -> withPrologue
    runPreMsg mp = snd . runIdentity $ runMessagePatternT mp seeded
-- | Assembles the full protocol name:
-- @\<prefix\>\<pattern name\>_\<curve\>_\<cipher\>_\<hash\>@, where the prefix
-- is @NoisePSK_@ when the flag is 'True' and @Noise_@ otherwise.
mkHPN :: forall c d h. (Cipher c, Curve d, Hash h)
      => HandshakePattern c d h
      -> Bool
      -> ScrubbedBytes
mkHPN hp psk =
  concatSB [prefix, patternName, sep, curve, sep, cipher, sep, hash]
  where
    prefix
      | psk       = bsToSB' "NoisePSK_"
      | otherwise = bsToSB' "Noise_"
    patternName = bsToSB' (hp ^. hpName)
    sep         = bsToSB' "_"
    curve       = curveName  (Proxy :: Proxy d)
    cipher      = cipherName (Proxy :: Proxy c)
    hash        = hashName   (Proxy :: Proxy h)
-- | Produces the next outbound handshake message.
--
-- Runs the first remaining write pattern, encrypts-and-hashes the payload
-- with the resulting symmetric state, and returns the pattern's bytes
-- followed by the encrypted payload, together with the updated state (the
-- executed pattern is removed from the queue).
--
-- Throws 'HandshakeStateFailure' when no write patterns remain; the previous
-- partial pattern match died with an anonymous pattern-match failure instead.
writeMessage :: (Cipher c, Curve d, Hash h)
             => HandshakeState c d h
             -> Plaintext
             -> IO (ByteString, HandshakeState c d h)
writeMessage hs payload = do
  -- Selecting the pattern inside the action means an empty queue is reported
  -- when the action runs. 'throw' is used because it is the primitive this
  -- module already imports.
  (wmp, wmps) <- case hs ^. hssHandshakePattern . hpWriteMsg of
    (p:ps) -> return (p, ps)
    []     -> throw (HandshakeStateFailure "no handshake messages left to write")

  (d, hs') <- runMessagePatternT wmp hs

  let (ep, ss') = encryptAndHash payload $ hs' ^. hssSymmetricState
      hs''      = hs' & hssSymmetricState .~ ss'
                      & hssHandshakePattern . hpWriteMsg .~ wmps

  return (d `B.append` convert ep, hs'')
-- | Consumes the next inbound handshake message.
--
-- Applies the first remaining read pattern to the received bytes, then
-- decrypts-and-hashes what the pattern left over as the payload. Returns the
-- payload and the updated state (the executed pattern is removed from the
-- queue).
--
-- Throws 'HandshakeStateFailure' when no read patterns remain; the previous
-- partial pattern match died with an anonymous pattern-match failure instead.
-- As before, the result is built lazily, so the exception surfaces when the
-- returned values are forced.
readMessage :: (Cipher c, Curve d, Hash h)
            => HandshakeState c d h
            -> ByteString
            -> (Plaintext, HandshakeState c d h)
readMessage hs buf = (dp, hs'')
  where
    -- Total replacement for the old @(rmp:rmps)@ binding, still lazy.
    (rmp, rmps) =
      case hs ^. hssHandshakePattern . hpReadMsg of
        (p:ps) -> (p, ps)
        []     -> throw (HandshakeStateFailure "no handshake messages left to read")
    (d, hs')  = runIdentity $ runMessagePatternT (rmp buf) hs
    (dp, ss') = decryptAndHash (cipherBytesToText (convert d))
                $ hs' ^. hssSymmetricState
    hs''      = hs' & hssSymmetricState .~ ss'
                    & hssHandshakePattern . hpReadMsg .~ rmps
-- | Like 'writeMessage', but afterwards splits the symmetric state and
-- returns the two resulting cipher states instead of the handshake state.
writeMessageFinal :: (Cipher c, Curve d, Hash h)
                  => HandshakeState c d h
                  -> Plaintext
                  -> IO (ByteString, CipherState c, CipherState c)
writeMessageFinal hs payload =
  writeMessage hs payload >>= \(msg, hsDone) ->
    let (csFirst, csSecond) = split (hsDone ^. hssSymmetricState)
    in  return (msg, csFirst, csSecond)
-- | Like 'readMessage', but afterwards splits the symmetric state and returns
-- the two resulting cipher states instead of the handshake state.
readMessageFinal :: (Cipher c, Curve d, Hash h)
                 => HandshakeState c d h
                 -> ByteString
                 -> (Plaintext, CipherState c, CipherState c)
readMessageFinal hs buf =
  let (payload, hsDone)   = readMessage hs buf
      (csFirst, csSecond) = split (hsDone ^. hssSymmetricState)
  in  (payload, csFirst, csSecond)
-- | Encrypts a payload with the given cipher state, using empty associated
-- data. Returns the ciphertext bytes and the cipher state produced by
-- 'encryptAndIncrement'.
encryptPayload :: Cipher c
               => Plaintext
               -> CipherState c
               -> (ByteString, CipherState c)
encryptPayload pt cs =
  let emptyAD        = AssocData (bsToSB' "")
      (ct, nextCS)   = encryptAndIncrement emptyAD pt cs
  in  (convert (cipherTextToBytes ct), nextCS)
-- | Decrypts a payload with the given cipher state, using empty associated
-- data. Returns the plaintext and the cipher state produced by
-- 'decryptAndIncrement'.
decryptPayload :: Cipher c
               => ByteString
               -> CipherState c
               -> (Plaintext, CipherState c)
decryptPayload ct cs =
  let emptyAD      = AssocData (bsToSB' "")
      cipherText   = cipherBytesToText (convert ct)
      -- Re-tupled through a lazy binding so the result stays as lazy as before.
      (pt, nextCS) = decryptAndIncrement emptyAD cipherText cs
  in  (pt, nextCS)