{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
----------------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.HandshakeState
-- Maintainer  : John Galt
-- Stability   : experimental
-- Portability : POSIX

module Crypto.Noise.Internal.HandshakeState
  ( -- * Classes
    MonadHandshake(..),
    -- * Types
    MessagePattern,
    MessagePatternIO,
    HandshakePattern(HandshakePattern),
    HandshakeState,
    -- * Functions
    runMessagePatternT,
    getLocalStaticKey,
    getLocalEphemeralKey,
    getRemoteStaticKey,
    getRemoteEphemeralKey,
    handshakeState,
    writeMessage,
    readMessage,
    writeMessageFinal,
    readMessageFinal,
    encryptPayload,
    decryptPayload
  ) where

import Control.Exception (throw)
import Control.Lens hiding (re)
import Control.Monad.State
import Data.ByteString (ByteString)
import qualified Data.ByteString as B (append, splitAt)
import Data.Maybe (fromMaybe, isJust)
import Data.Proxy

import Crypto.Noise.Cipher
import Crypto.Noise.Curve
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Types

-- | State transformer in which message patterns run. It threads a
--   'HandshakeState' through a computation in the underlying monad @m@.
newtype MessagePatternT c d h m a =
  MessagePatternT
    { unMP :: StateT (HandshakeState c d h) m a
    }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadIO
           , MonadState (HandshakeState c d h)
           )

-- | Runs a message pattern against the given handshake state, yielding the
--   pattern's result together with the updated state.
runMessagePatternT :: Monad m
                   => MessagePatternT c d h m a
                   -> HandshakeState c d h
                   -> m (a, HandshakeState c d h)
runMessagePatternT mp = runStateT (unMP mp)

-- | Represents a series of operations that can be performed on a Noise
--   message.
type MessagePattern c d h a = MessagePatternT c d h Identity a

-- | Represents a series of operations that will result in a Noise message.
--   This must be done in IO to facilitate the generation of ephemeral
--   keys.
type MessagePatternIO c d h a = MessagePatternT c d h IO a

-- | Represents a series of message patterns, the first for writing and the
--   second for reading.
data HandshakePattern c d h =
  HandshakePattern
    { _hpName     :: ByteString
      -- ^ Pattern name; spliced into the handshake protocol name.
    , _hpPreMsg   :: Maybe (MessagePattern c d h ())
      -- ^ Optional pre-message pattern, run once when the state is built.
    , _hpWriteMsg :: [MessagePatternIO c d h ByteString]
      -- ^ Patterns for outgoing messages, consumed front to back.
    , _hpReadMsg  :: [ByteString -> MessagePattern c d h ByteString]
      -- ^ Patterns for incoming messages, consumed front to back. Each one
      --   takes the received buffer and returns the unconsumed remainder.
    }

-- | Contains the state of a handshake.
data HandshakeState c d h =
  HandshakeState
    { _hssSymmetricState     :: SymmetricState c h
    , _hssHandshakePattern   :: HandshakePattern c d h
    , _hssLocalStaticKey     :: Maybe (KeyPair d)
    , _hssLocalEphemeralKey  :: Maybe (KeyPair d)
    , _hssRemoteStaticKey    :: Maybe (PublicKey d)
    , _hssRemoteEphemeralKey :: Maybe (PublicKey d)
    }

$(makeLenses ''HandshakePattern)
$(makeLenses ''HandshakeState)

-- | The tokens from which message patterns are assembled. The behaviour
--   described on each method is that of the 'MessagePatternT' instance.
class Monad m => MonadHandshake m where
  -- | Pre-message: hashes the local static public key.
  tokenPreLS :: m ()
  -- | Pre-message: hashes the remote static public key.
  tokenPreRS :: m ()
  -- | Pre-message: hashes the local ephemeral public key.
  tokenPreLE :: m ()
  -- | Pre-message: hashes the remote ephemeral public key.
  tokenPreRE :: m ()
  -- | Reads the remote ephemeral public key from the front of the buffer
  --   and returns the rest of the buffer.
  tokenRE    :: ByteString -> m ByteString
  -- | Reads (and decrypts) the remote static public key from the front of
  --   the buffer and returns the rest of the buffer.
  tokenRS    :: ByteString -> m ByteString
  -- | Generates a fresh local ephemeral key pair and returns the bytes of
  --   its public key.
  tokenWE    :: MonadIO m => m ByteString
  -- | Returns the local static public key as processed by 'encryptAndHash'.
  tokenWS    :: m ByteString
  -- | Mixes DH(local ephemeral, remote ephemeral) into the key.
  tokenDHEE  :: m ()
  -- | Mixes DH(local ephemeral, remote static) into the key.
  tokenDHES  :: m ()
  -- | Mixes DH(local static, remote ephemeral) into the key.
  tokenDHSE  :: m ()
  -- | Mixes DH(local static, remote static) into the key.
  tokenDHSS  :: m ()

instance (Monad m, Cipher c, Curve d, Hash h)
      => MonadHandshake (MessagePatternT c d h m) where
  tokenPreLS = tokenPreLX hssLocalStaticKey

  tokenPreRS = tokenPreRX hssRemoteStaticKey

  tokenPreLE = tokenPreLX hssLocalEphemeralKey

  tokenPreRE = tokenPreRX hssRemoteEphemeralKey

  tokenRE buf = do
    hs <- get
    let (b, rest) = B.splitAt (curveLength (Proxy :: Proxy d)) buf
        reBytes   = convert b
        ss        = hs ^. hssSymmetricState
        ss'       = mixHash reBytes ss
        -- NOTE(review): the PSK flag is read from the state *before*
        -- mixHash here, whereas tokenWE reads it after. Presumably
        -- equivalent because mixHash does not touch ssHasPSK -- confirm.
        ss''      = if ss ^. ssHasPSK then mixKey reBytes ss' else ss'
    put $ hs & hssRemoteEphemeralKey .~ Just (curveBytesToPub reBytes)
             & hssSymmetricState     .~ ss''
    return rest

  tokenRS buf = do
    hs <- get
    if isJust (hs ^. hssRemoteStaticKey)
      then throw . HandshakeStateFailure $ "unable to overwrite remote static key"
      else do
        let hasKey    = hs ^. hssSymmetricState . ssHasKey
            (b, rest) = B.splitAt (d hasKey) buf
            ct        = cipherBytesToText . convert $ b
            ss        = hs ^. hssSymmetricState
            (Plaintext pt, ss') = decryptAndHash ct ss
        put $ hs & hssRemoteStaticKey .~ Just (curveBytesToPub pt)
                 & hssSymmetricState  .~ ss'
        return rest
    where
      len = curveLength (Proxy :: Proxy d)
      -- Number of bytes to consume. Once a key is present the field is 16
      -- bytes longer (presumably the AEAD authentication tag -- TODO confirm).
      d hk
        | hk        = len + 16
        | otherwise = len

  tokenWE = do
    ~kp@(_, pk) <- liftIO curveGenKey
    hs <- get
    let pk'  = curvePubToBytes pk
        ss   = hs ^. hssSymmetricState
        ss'  = mixHash pk' ss
        ss'' = if ss' ^. ssHasPSK then mixKey pk' ss' else ss'
    put $ hs & hssLocalEphemeralKey .~ Just kp
             & hssSymmetricState    .~ ss''
    return . convert $ pk'

  tokenWS = do
    hs <- get
    let pk        = curvePubToBytes . snd . getLocalStaticKey $ hs
        ss        = hs ^. hssSymmetricState
        (ct, ss') = encryptAndHash ((Plaintext . convert) pk) ss
    put $ hs & hssSymmetricState .~ ss'
    return . convert $ ct

  tokenDHEE = tokenDHXX getLocalEphemeralKey getRemoteEphemeralKey

  tokenDHES = tokenDHXX getLocalEphemeralKey getRemoteStaticKey

  tokenDHSE = tokenDHXX getLocalStaticKey getRemoteEphemeralKey

  tokenDHSS = tokenDHXX getLocalStaticKey getRemoteStaticKey

-- | Shared implementation of the four DH tokens: performs a DH between the
--   secret half of the selected local key pair and the selected remote
--   public key, then mixes the result into the symmetric state's key.
tokenDHXX :: (MonadState (HandshakeState c d h) m, Cipher c, Curve d, Hash h)
          => (HandshakeState c d h -> KeyPair d)
          -- ^ Selects the local key pair
          -> (HandshakeState c d h -> PublicKey d)
          -- ^ Selects the remote public key
          -> m ()
tokenDHXX localKeyPair remotePubKey = do
  hs <- get
  let (sk, _) = localKeyPair hs
      dh      = curveDH sk (remotePubKey hs)
  put $ hs & hssSymmetricState %~ mixKey dh

-- | Returns the local static key pair. Throws 'HandshakeStateFailure' if
--   it has not been set.
getLocalStaticKey :: Curve d => HandshakeState c d h -> KeyPair d
getLocalStaticKey hs =
  fromMaybe (throw (HandshakeStateFailure "local static key not set"))
            (hs ^. hssLocalStaticKey)

-- | Returns the local ephemeral key pair. Throws 'HandshakeStateFailure' if
--   it has not been set.
getLocalEphemeralKey :: Curve d => HandshakeState c d h -> KeyPair d
getLocalEphemeralKey hs =
  fromMaybe (throw (HandshakeStateFailure "local ephemeral key not set"))
            (hs ^. hssLocalEphemeralKey)

-- | Returns the remote party's public static key.
-- This is useful when the static key has been transmitted to you and you want
-- to save it for future use. Throws 'HandshakeStateFailure' if no remote
-- static key has been set.
getRemoteStaticKey :: Curve d => HandshakeState c d h -> PublicKey d
getRemoteStaticKey hs =
  fromMaybe (throw (HandshakeStateFailure "remote static key not set"))
            (hs ^. hssRemoteStaticKey)

-- | Returns the remote party's public ephemeral key. Throws
--   'HandshakeStateFailure' if it has not been set.
getRemoteEphemeralKey :: Curve d => HandshakeState c d h -> PublicKey d
getRemoteEphemeralKey hs =
  fromMaybe (throw (HandshakeStateFailure "remote ephemeral key not set"))
            (hs ^. hssRemoteEphemeralKey)

-- | Pre-message helper for local keys: hashes the public half of the key
--   pair selected by the lens into the symmetric state. Throws
--   'HandshakeStateFailure' if that key is not set.
tokenPreLX :: (MonadState (HandshakeState c d h) m, Cipher c, Curve d, Hash h)
           => Lens' (HandshakeState c d h) (Maybe (KeyPair d))
           -> m ()
tokenPreLX keyToView = do
  hs <- get
  let ss      = hs ^. hssSymmetricState
      (_, pk) = fromMaybe (throw (HandshakeStateFailure "tokenPreLX: local key not set"))
                          (hs ^. keyToView)
      ss'     = mixHash (curvePubToBytes pk) ss
  put $ hs & hssSymmetricState .~ ss'

-- | Pre-message helper for remote keys: hashes the public key selected by
--   the lens into the symmetric state. Throws 'HandshakeStateFailure' if
--   that key is not set.
tokenPreRX :: (MonadState (HandshakeState c d h) m, Cipher c, Curve d, Hash h)
           => Lens' (HandshakeState c d h) (Maybe (PublicKey d))
           -> m ()
tokenPreRX keyToView = do
  hs <- get
  let ss  = hs ^. hssSymmetricState
      pk  = fromMaybe (throw (HandshakeStateFailure "tokenPreRX: remote key not set"))
                      (hs ^. keyToView)
      ss' = mixHash (curvePubToBytes pk) ss
  put $ hs & hssSymmetricState .~ ss'

-- | Constructs a 'HandshakeState'. The keys you need to provide are
--   dependent on the type of handshake you are using. If you fail to
--   provide a key that your handshake type depends on, or you provide
--   a static key which is supposed to be set during the exchange, you will
--   receive a 'HandshakeStateFailure' exception.
handshakeState :: forall c d h. (Cipher c, Curve d, Hash h)
               => HandshakePattern c d h
               -- ^ The handshake pattern to use
               -> Plaintext
               -- ^ Prologue
               -> Maybe Plaintext
               -- ^ Pre-shared key
               -> Maybe (KeyPair d)
               -- ^ Local static key
               -> Maybe (KeyPair d)
               -- ^ Local ephemeral key
               -> Maybe (PublicKey d)
               -- ^ Remote public static key
               -> Maybe (PublicKey d)
               -- ^ Remote public ephemeral key
               -> HandshakeState c d h
handshakeState hp (Plaintext pro) mpsk ls le rs re =
  maybe hs' runPreMsg $ hp ^. hpPreMsg
  where
    -- The protocol name depends on whether a pre-shared key is in use.
    hs = HandshakeState (symmetricState (mkHPN hp (isJust mpsk))) hp ls le rs re

    hsPro x = x & hssSymmetricState %~ mixHash pro

    -- The PSK, when supplied, is mixed in after the prologue.
    hsPSK x = case mpsk of
      Just (Plaintext psk) -> x & hssSymmetricState %~ mixPSK psk
      Nothing              -> x

    hs' = hsPSK . hsPro $ hs

    -- Run the pattern's pre-message (if any) on the initialised state.
    runPreMsg mp = snd . runIdentity $ runMessagePatternT mp hs'

-- | Builds the handshake protocol name:
--   @\<prefix\>\<pattern\>_\<curve\>_\<cipher\>_\<hash\>@, where the prefix
--   is @NoisePSK_@ when the flag is 'True' and @Noise_@ otherwise.
mkHPN :: forall c d h. (Cipher c, Curve d, Hash h)
      => HandshakePattern c d h
      -> Bool
      -> ScrubbedBytes
mkHPN hp psk =
  concatSB [prefix, bsToSB' (hp ^. hpName), u, a, u, b, u, c]
  where
    prefix = bsToSB' (if psk then "NoisePSK_" else "Noise_")
    a      = curveName  (Proxy :: Proxy d)
    b      = cipherName (Proxy :: Proxy c)
    c      = hashName   (Proxy :: Proxy h)
    u      = bsToSB' "_"

-- | Creates a handshake message. The plaintext can be left empty if no
--   plaintext is to be transmitted. All subsequent handshake processing
--   must use the returned state. Throws 'HandshakeStateFailure' if the
--   handshake pattern has no write patterns left.
writeMessage :: (Cipher c, Curve d, Hash h)
             => HandshakeState c d h
             -- ^ The handshake state
             -> Plaintext
             -- ^ Optional message to transmit
             -> IO (ByteString, HandshakeState c d h)
writeMessage hs payload = do
  -- Fail with the module's own exception rather than an anonymous
  -- pattern-match error once the pattern list is exhausted.
  let (wmp, wmps) = case hs ^. hssHandshakePattern . hpWriteMsg of
        (x:xs) -> (x, xs)
        []     -> throw (HandshakeStateFailure "writeMessage: handshake pattern has no remaining write messages")
  (d, hs') <- runMessagePatternT wmp hs
  let (ep, ss') = encryptAndHash payload $ hs' ^. hssSymmetricState
      hs''      = hs' & hssSymmetricState .~ ss'
                      & hssHandshakePattern . hpWriteMsg .~ wmps
  return (d `B.append` convert ep, hs'')

-- | Reads a handshake message. All subsequent handshake processing must
--   use the returned state. Throws 'HandshakeStateFailure' if the
--   handshake pattern has no read patterns left.
readMessage :: (Cipher c, Curve d, Hash h)
            => HandshakeState c d h
            -- ^ The handshake state
            -> ByteString
            -- ^ The handshake message received
            -> (Plaintext, HandshakeState c d h)
readMessage hs buf = (dp, hs'')
  where
    -- Kept as a lazy binding so the failure surfaces only when the result
    -- is actually demanded, exactly as before.
    (rmp, rmps) = case hs ^. hssHandshakePattern . hpReadMsg of
      (x:xs) -> (x, xs)
      []     -> throw (HandshakeStateFailure "readMessage: handshake pattern has no remaining read messages")
    (d, hs')  = runIdentity $ runMessagePatternT (rmp buf) hs
    (dp, ss') = decryptAndHash (cipherBytesToText (convert d))
                $ hs' ^. hssSymmetricState
    hs''      = hs' & hssSymmetricState .~ ss'
                    & hssHandshakePattern . hpReadMsg .~ rmps

-- | The final call of a handshake negotiation. Used to generate a pair of
--   CipherStates, one for each transmission direction.
writeMessageFinal :: (Cipher c, Curve d, Hash h)
                  => HandshakeState c d h
                  -- ^ The handshake state
                  -> Plaintext
                  -- ^ Optional message to transmit
                  -> IO (ByteString, CipherState c, CipherState c)
writeMessageFinal hs payload = do
  (d, hs') <- writeMessage hs payload
  let (cs1, cs2) = split $ hs' ^. hssSymmetricState
  return (d, cs1, cs2)

-- | The final call of a handshake negotiation. Used to generate a pair of
--   CipherStates, one for each transmission direction.
readMessageFinal :: (Cipher c, Curve d, Hash h)
                 => HandshakeState c d h
                 -- ^ The handshake state
                 -> ByteString
                 -- ^ The handshake message received
                 -> (Plaintext, CipherState c, CipherState c)
readMessageFinal hs buf = (pt, cs1, cs2)
  where
    (pt, hs')  = readMessage hs buf
    (cs1, cs2) = split $ hs' ^. hssSymmetricState

-- | Encrypts a payload. The returned 'CipherState' must be used for all
--   subsequent calls. Empty associated data is used.
encryptPayload :: Cipher c
               => Plaintext
               -- ^ The data to encrypt
               -> CipherState c
               -- ^ The CipherState to use for encryption
               -> (ByteString, CipherState c)
encryptPayload pt cs = ((convert . cipherTextToBytes) ct, cs')
  where
    (ct, cs') = encryptAndIncrement ad pt cs
    ad        = AssocData . bsToSB' $ ""

-- | Decrypts a payload. The returned 'CipherState' must be used for all
--   subsequent calls. Empty associated data is used.
decryptPayload :: Cipher c
               => ByteString
               -- ^ The data to decrypt
               -> CipherState c
               -- ^ The CipherState to use for decryption
               -> (Plaintext, CipherState c)
decryptPayload ct cs = (pt, cs')
  where
    (pt, cs') = decryptAndIncrement ad ((cipherBytesToText . convert) ct) cs
    ad        = AssocData . bsToSB' $ ""