{-# LANGUAGE OverloadedStrings, TemplateHaskell, ScopedTypeVariables #-}
{-# OPTIONS_HADDOCK hide #-}
----------------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.NoiseState
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX

module Crypto.Noise.Internal.NoiseState where

import Control.Monad.Catch.Pure
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.State  (MonadState(..), runStateT, get, put)
import Control.Monad.Free.Church
import Control.Lens
import Data.ByteString      (ByteString)
import Data.Maybe           (isJust)
import Data.Monoid          ((<>))
import Data.Proxy           (Proxy(..))
import Prelude hiding       (concat, splitAt, length)

import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Internal.Handshake
import Crypto.Noise.Internal.HandshakePattern
import Crypto.Noise.Internal.Types
import Data.ByteArray.Extend

-- | Represents the complete state of a Noise conversation.
--
--   A value is created by 'noiseState' and advanced one handshake message at
--   a time by 'runHandshake'.
data NoiseState c d h =
  NoiseState { _nsHandshakeState       :: HandshakeState c d h
               -- ^ Handshake state that 'runHandshake' threads through the
               --   suspended handshake and stores back after every step.
             , _nsHandshakeSuspension  :: ScrubbedBytes -> Handshake c d h ()
               -- ^ Continuation of the paused handshake; applying it to the
               --   next message resumes the pattern interpreter.
             , _nsSendingCipherState   :: Maybe (CipherState c)
               -- ^ 'Nothing' when created by 'noiseState'; set by
               --   'runHandshake' once the handshake has run to completion.
             , _nsReceivingCipherState :: Maybe (CipherState c)
               -- ^ 'Nothing' when created by 'noiseState'; set by
               --   'runHandshake' once the handshake has run to completion.
             }

-- Template Haskell splice generating the lenses used throughout this module
-- for the underscore-prefixed fields of 'NoiseState' (e.g. 'nsHandshakeState'
-- for '_nsHandshakeState').
$(makeLenses ''NoiseState)

-- | Builds the baseline 'HandshakeOpts' for the given pattern and role.
--
--   Every optional setting starts out unset: the prologue is the empty
--   string, no pre-shared key is configured (PSK-mode is off), and every
--   local and remote key slot is 'Nothing'.
defaultHandshakeOpts :: HandshakePattern
                     -> HandshakeRole
                     -> HandshakeOpts d
defaultHandshakeOpts pat localRole = HandshakeOpts
  { _hoPattern             = pat
  , _hoRole                = localRole
  , _hoPrologue            = ""
  , _hoPreSharedKey        = Nothing
    -- Local key pairs.
  , _hoLocalStatic         = Nothing
  , _hoLocalSemiEphemeral  = Nothing
  , _hoLocalEphemeral      = Nothing
    -- Remote public keys.
  , _hoRemoteStatic        = Nothing
  , _hoRemoteSemiEphemeral = Nothing
  , _hoRemoteEphemeral     = Nothing
  }

-- | Assembles the handshake name: a @Noise_@ prefix (@NoisePSK_@ when the
--   flag is 'True'), followed by the pattern name, the DH name, the cipher
--   name and the hash name, joined by underscores. The proxy argument is
--   never inspected; it only selects the cipher, DH and hash types.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> Bool
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName patternName usePSK _ =
  mconcat [ prefix, convert patternName
          , "_", dhPart
          , "_", cipherPart
          , "_", hashPart
          ]
  where
    prefix | usePSK    = "NoisePSK_"
           | otherwise = "Noise_"
    dhPart     = dhName     (Proxy :: Proxy d)
    cipherPart = cipherName (Proxy :: Proxy c)
    hashPart   = hashName   (Proxy :: Proxy h)

-- | Maps each handshake role to the opposite one.
invertRole :: HandshakeRole -> HandshakeRole
invertRole role = case role of
  InitiatorRole -> ResponderRole
  ResponderRole -> InitiatorRole

-- | Builds the initial 'HandshakeState' from a set of options.
--
--   The symmetric state is seeded with the handshake name, the prologue is
--   then passed through 'mixHash', and finally the pre-shared key (if one is
--   configured) is applied with 'mixPSK'. The message buffer starts empty.
--
--   Calls 'error' when a pre-shared key is configured but is not exactly 32
--   bytes long.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakeState c d h
handshakeState ho
  | pskLengthOk = HandshakeState { _hsSymmetricState = keyed
                                 , _hsOpts           = ho
                                 , _hsMsgBuffer      = mempty
                                 }
  | otherwise   = error "pre-shared key must be 32 bytes in length"
  where
    configuredPSK = ho ^. hoPreSharedKey

    -- An absent key is always acceptable; a present one must be 32 bytes.
    pskLengthOk = case configuredPSK of
      Nothing  -> True
      Just psk -> length psk == 32

    protoName = mkHandshakeName (ho ^. hoPattern . hpName)
                                (isJust configuredPSK)
                                (Proxy :: Proxy (c, d, h))

    seeded    = symmetricState protoName
    prologued = mixHash (ho ^. hoPrologue) seeded
    keyed     = case configuredPSK of
      Nothing  -> prologued
      Just psk -> mixPSK psk prologued

-- | Advances the handshake by one step.
--
--   The stored suspension is resumed with the supplied message and run
--   against the stored handshake state. If the interpreter suspends again,
--   the bytes it yielded are returned and the new continuation is saved. If
--   it runs to completion, the symmetric state is 'split' into two cipher
--   states (the initiator sends with the first and receives with the second;
--   the responder does the opposite) and the final message buffer is
--   returned. Any exception raised by the interpreter is re-thrown in @m@.
runHandshake :: (MonadThrow m, Cipher c, Hash h)
             => ScrubbedBytes
             -> NoiseState c d h
             -> m (ScrubbedBytes, NoiseState c d h)
runHandshake msg ns =
  case runCatch (runStateT step (ns ^. nsHandshakeState)) of
    Left err                     -> throwM err
    Right ((output, ns'), hsNew) -> return (output, ns' & nsHandshakeState .~ hsNew)

  where
    step = do
      outcome <- resume . runHandshake' $ (ns ^. nsHandshakeSuspension) msg
      case outcome of
        -- Interpreter paused: hand back what it yielded, remember how to
        -- continue.
        Left (Request yielded k) ->
          return (yielded, ns & nsHandshakeSuspension .~ (Handshake . k))

        -- Interpreter finished: derive the transport cipher states.
        Right _ -> do
          hs <- get

          let (c1, c2) = split (hs ^. hsSymmetricState)
              (tx, rx) = if hs ^. hsOpts . hoRole == InitiatorRole
                           then (c1, c2)
                           else (c2, c1)

          return ( hs ^. hsMsgBuffer
                 , ns & nsSendingCipherState   .~ Just tx
                      & nsReceivingCipherState .~ Just rx
                 )

-- | Creates a 'NoiseState' from a set of handshake options.
--
--   The pattern interpreter is run once, up to its first suspension, so that
--   the returned state already holds a continuation waiting for a message.
--   Both cipher states start as 'Nothing'. Calls 'error' if the interpreter
--   throws or finishes without ever suspending.
noiseState :: forall c d h. (Cipher c, DH d, Hash h)
           => HandshakeOpts d
           -> NoiseState c d h
noiseState ho =
  NoiseState { _nsHandshakeState       = primedState
             , _nsHandshakeSuspension  = firstSuspension
             , _nsSendingCipherState   = Nothing
             , _nsReceivingCipherState = Nothing
             }

  where
    initialState :: HandshakeState c d h
    initialState = handshakeState ho

    interpreter = iterM evalPattern (ho ^. hoPattern . hpActions)

    firstStep = runCatch
              . flip runStateT initialState
              . resume
              . runHandshake'
              $ interpreter

    (firstSuspension, primedState) = case firstStep of
      Right (Left (Request _ k), st) -> (Handshake . k, st)
      Right _  -> error "handshake pattern interpreter ended pre-maturely"
      Left err -> error $ "handshake pattern interpreter threw exception: " <> show err

-- | Interprets one message pattern (one handshake message).
--
--   The current message buffer is yielded to the caller and the interpreter
--   suspends until it is resumed with new input. When the message belongs to
--   the local role, the input is treated as the payload: the buffer is
--   cleared, the tokens are evaluated, and the encrypted payload is appended
--   to the buffer. Otherwise the input is treated as the received message:
--   it becomes the buffer, the tokens consume from it, and whatever remains
--   is decrypted and left in the buffer as the payload. The rest of the
--   pattern is run afterwards.
processPatternOp :: (Cipher c, DH d, Hash h)
                 => HandshakeRole
                 -> F TokenF ()
                 -> Handshake c d h ()
                 -> Handshake c d h ()
processPatternOp opRole t next = do
  outbound <- use hsMsgBuffer
  input    <- Handshake (request outbound)
  resumed  <- get

  let runTokens   = iterM (evalMsgToken opRole) t
      weAreSender = opRole == resumed ^. hsOpts . hoRole

  if weAreSender
    then do
      put $ resumed & hsMsgBuffer .~ mempty
      runTokens

      afterTokens <- get

      (ciphertext, sym) <- either throwM return
                         . encryptAndHash (convert input)
                         $ afterTokens ^. hsSymmetricState

      put $ afterTokens & hsMsgBuffer      %~ (`mappend` convert ciphertext)
                        & hsSymmetricState .~ sym
    else do
      put $ resumed & hsMsgBuffer .~ input
      runTokens

      afterTokens <- get

      let leftover = cipherBytesToText . convert $ afterTokens ^. hsMsgBuffer
          attempt  = decryptAndHash leftover (afterTokens ^. hsSymmetricState)

      (plaintext, sym) <- either (\_ -> throwM (HandshakeError "handshake payload failed to decrypt"))
                                 return
                                 attempt

      put $ afterTokens & hsMsgBuffer      .~ convert plaintext
                        & hsSymmetricState .~ sym

  next

-- | Interprets one layer of a handshake pattern. Pre-message layers run
--   their tokens through 'evalPreMsgToken'; message layers are handed to
--   'processPatternOp'. In both cases the role is fixed by the constructor.
evalPattern :: (Cipher c, DH d, Hash h)
            => HandshakePatternF (Handshake c d h ())
            -> Handshake c d h ()
evalPattern layer = case layer of
  PreInitiator t next -> preMessage InitiatorRole t next
  PreResponder t next -> preMessage ResponderRole t next
  Initiator    t next -> processPatternOp InitiatorRole t next
  Responder    t next -> processPatternOp ResponderRole t next
  where
    -- Evaluate every pre-message token for the given role, then continue.
    preMessage role t next = iterM (evalPreMsgToken role) t >> next

-- | Interprets a single message-pattern token against the handshake state.
--
--   The first argument is the role of the party that sends the message the
--   token belongs to. For the 'E' and 'S' tokens, when that role equals the
--   local role ('hoRole') the token is written: its bytes are appended to
--   'hsMsgBuffer'. Otherwise the token is read: its bytes are taken from the
--   front of 'hsMsgBuffer' and the remainder is left in the buffer. The DH
--   tokens leave the buffer alone and only update the symmetric state via
--   'mixKey'. The continuation carried by the token is run last.
--
--   Missing keys surface as 'InvalidHandshakeOptions' (raised by the key
--   getters); undecodable or undecryptable remote keys surface as
--   'HandshakeError'.
evalMsgToken :: forall c d h. (Cipher c, DH d, Hash h)
             => HandshakeRole
             -> TokenF (Handshake c d h ())
             -> Handshake c d h ()
-- Ephemeral key token.
evalMsgToken opRole (E next) = do
  hs <- get

  if opRole == hs ^. hsOpts . hoRole then do
    -- Writing: feed our ephemeral public key to 'mixHash' (and to 'mixKey'
    -- as well when 'ssHasPSK' is set), then append it to the buffer.
    (_, pk) <- getLocalEphemeral hs
    let pk'     = dhPubToBytes pk
        ss      = hs ^. hsSymmetricState
        ss'     = mixHash pk' ss
        ss''    = if ss' ^. ssHasPSK then mixKey pk' ss' else ss'

    put $ hs & hsSymmetricState .~ ss''
             & hsMsgBuffer      %~ (flip mappend . convert) pk'

  else do
    -- Reading: the first 'dhLength' bytes of the buffer are the remote
    -- ephemeral key; they get the same 'mixHash'/'mixKey' treatment as on
    -- the writing side.
    let (b, rest) = splitAt (dhLength (Proxy :: Proxy d)) $ hs ^. hsMsgBuffer
        reBytes   = convert b
        ss        = hs ^. hsSymmetricState
        ss'       = mixHash reBytes ss
        ss''      = if ss ^. ssHasPSK then mixKey reBytes ss' else ss'
        theirKey  = dhBytesToPub reBytes

    -- The state is only updated once the bytes decode to a public key.
    theirKey' <- maybe (throwM . HandshakeError $ "invalid remote ephemeral key") return theirKey

    put $ hs & hsOpts . hoRemoteEphemeral .~ Just theirKey'
             & hsSymmetricState           .~ ss''
             & hsMsgBuffer                .~ rest

  next

-- Static key token.
evalMsgToken opRole (S next) = do
  hs <- get

  if opRole == hs ^. hsOpts. hoRole then do
    -- Writing: run our static public key through 'encryptAndHash' and append
    -- the result to the buffer.
    pk <- dhPubToBytes . snd <$> getLocalStatic hs
    let ss  = hs ^. hsSymmetricState
        enc = encryptAndHash (convert pk) ss

    (ct, ss') <- either throwM return enc

    put $ hs & hsSymmetricState .~ ss'
             & hsMsgBuffer      %~ (flip mappend . convert) ct
  else
    -- Reading: a remote static key that is already set is never replaced.
    if isJust (hs ^. hsOpts ^. hoRemoteStatic)
      then throwM . InvalidHandshakeOptions $ "unable to overwrite remote static key"
      else do
        let hasKey    = hs ^. hsSymmetricState . ssHasKey
            len       = dhLength (Proxy :: Proxy d)
            -- The magic 16 here represents the length of the auth tag.
            d         = if hasKey then len + 16 else len
            (b, rest) = splitAt d $ hs ^. hsMsgBuffer
            ct        = cipherBytesToText . convert $ b
            ss        = hs ^. hsSymmetricState
            dec       = decryptAndHash ct ss

        -- The underlying decryption error is discarded in favour of a
        -- 'HandshakeError' with a fixed message.
        (pt, ss')     <- either (const . throwM . HandshakeError $ "failed to decrypt remote static key") return dec
        theirKey'     <- maybe (throwM . HandshakeError $ "invalid remote static key provided") return $ dhBytesToPub pt

        put $ hs & hsOpts . hoRemoteStatic .~ Just theirKey'
                 & hsSymmetricState        .~ ss'
                 & hsMsgBuffer             .~ rest

  next

-- DH between the local ephemeral secret and the remote ephemeral public key.
-- The computation is the same for both parties, so the role is ignored.
evalMsgToken _ (Dhee next) = do
  hs <- get

  ~(sk, _) <- getLocalEphemeral hs
  rpk      <- getRemoteEphemeral hs

  let ss'  = mixKey (dhPerform sk rpk) $ hs ^. hsSymmetricState

  put $ hs & hsSymmetricState .~ ss'

  next

-- Seen from the sender: DH between the local ephemeral secret and the remote
-- static public key.
evalMsgToken opRole (Dhes next) = do
  hs <- get

  if opRole == hs ^. hsOpts . hoRole then do
    let ss = hs ^. hsSymmetricState

    rpk <- getRemoteStatic hs

    ~(sk, _) <- getLocalEphemeral hs
    let dh  = dhPerform sk rpk
        ss' = mixKey dh ss

    put $ hs & hsSymmetricState .~ ss'

    next
  -- The receiver performs the mirror-image computation. The role is inverted
  -- so that the 'Dhse' equation takes its first branch instead of bouncing
  -- back here.
  else evalMsgToken (invertRole opRole) $ Dhse next

-- Seen from the sender: DH between the local static secret and the remote
-- ephemeral public key.
evalMsgToken opRole (Dhse next) = do
  hs <- get

  if opRole == hs ^. hsOpts . hoRole then do
    let ss = hs ^. hsSymmetricState

    ~(sk, _) <- getLocalStatic hs

    rpk <- getRemoteEphemeral hs
    let dh  = dhPerform sk rpk
        ss' = mixKey dh ss

    put $ hs & hsSymmetricState .~ ss'

    next
  -- Mirror image of the 'Dhes' delegation above.
  else evalMsgToken (invertRole opRole) $ Dhes next

-- DH between the local static secret and the remote static public key. The
-- computation is the same for both parties, so the role is ignored.
evalMsgToken _ (Dhss next) = do
  hs <- get

  let ss = hs ^. hsSymmetricState

  ~(sk, _) <- getLocalStatic hs
  rpk      <- getRemoteStatic hs
  let dh  = dhPerform sk rpk
      ss' = mixKey dh ss

  put $ hs & hsSymmetricState .~ ss'

  next

-- | Interprets a single pre-message token.
--
--   Only 'E' and 'S' are accepted. The relevant public key (our own when the
--   token's role matches the local role, the remote party's otherwise) is
--   passed through 'mixHash'; nothing is read from or written to the message
--   buffer. For 'E' the semi-ephemeral keys are used, for 'S' the static
--   keys. Any other token calls 'error'.
evalPreMsgToken :: (Cipher c, DH d, Hash h)
                => HandshakeRole
                -> TokenF (Handshake c d h ())
                -> Handshake c d h ()
evalPreMsgToken opRole (E next) = do
  st <- get

  let isOurs = opRole == st ^. hsOpts . hoRole

  key <- if isOurs
           then snd <$> getLocalSemiEphemeral st
           else getRemoteSemiEphemeral st

  put $ st & hsSymmetricState %~ mixHash (dhPubToBytes key)

  next

evalPreMsgToken opRole (S next) = do
  st <- get

  let isOurs = opRole == st ^. hsOpts . hoRole

  key <- if isOurs
           then snd <$> getLocalStatic st
           else getRemoteStatic st

  put $ st & hsSymmetricState %~ mixHash (dhPubToBytes key)

  next

evalPreMsgToken _ _ = error "invalid pre-message pattern token"

-- | Fetches the local static key pair from the handshake options, throwing
--   'InvalidHandshakeOptions' when it has not been set.
getLocalStatic :: HandshakeState c d h
               -> Handshake c d h (KeyPair d)
getLocalStatic hs = case hs ^. hsOpts . hoLocalStatic of
  Just keyPair -> return keyPair
  Nothing      -> throwM (InvalidHandshakeOptions "local static key not set")

-- | Fetches the local semi-ephemeral key pair from the handshake options,
--   throwing 'InvalidHandshakeOptions' when it has not been set.
getLocalSemiEphemeral :: HandshakeState c d h
                      -> Handshake c d h (KeyPair d)
getLocalSemiEphemeral hs = case hs ^. hsOpts . hoLocalSemiEphemeral of
  Just keyPair -> return keyPair
  Nothing      -> throwM (InvalidHandshakeOptions "local semi-ephemeral key not set")

-- | Fetches the local ephemeral key pair from the handshake options,
--   throwing 'InvalidHandshakeOptions' when it has not been set.
getLocalEphemeral :: HandshakeState c d h
                  -> Handshake c d h (KeyPair d)
getLocalEphemeral hs = case hs ^. hsOpts . hoLocalEphemeral of
  Just keyPair -> return keyPair
  Nothing      -> throwM (InvalidHandshakeOptions "local ephemeral key not set")

-- | Fetches the remote static public key from the handshake options,
--   throwing 'InvalidHandshakeOptions' when it has not been set.
getRemoteStatic :: HandshakeState c d h
                -> Handshake c d h (PublicKey d)
getRemoteStatic hs = case hs ^. hsOpts . hoRemoteStatic of
  Just publicKey -> return publicKey
  Nothing        -> throwM (InvalidHandshakeOptions "remote static key not set")

-- | Fetches the remote semi-ephemeral public key from the handshake options,
--   throwing 'InvalidHandshakeOptions' when it has not been set.
getRemoteSemiEphemeral :: HandshakeState c d h
                       -> Handshake c d h (PublicKey d)
getRemoteSemiEphemeral hs = case hs ^. hsOpts . hoRemoteSemiEphemeral of
  Just publicKey -> return publicKey
  Nothing        -> throwM (InvalidHandshakeOptions "remote semi-ephemeral key not set")

-- | Fetches the remote ephemeral public key from the handshake options,
--   throwing 'InvalidHandshakeOptions' when it has not been set.
getRemoteEphemeral :: HandshakeState c d h
                   -> Handshake c d h (PublicKey d)
getRemoteEphemeral hs = case hs ^. hsOpts . hoRemoteEphemeral of
  Just publicKey -> return publicKey
  Nothing        -> throwM (InvalidHandshakeOptions "remote ephemeral key not set")