{-# LANGUAGE TemplateHaskell #-}
--------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.CipherState
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Internal.CipherState where

import Control.Exception.Safe
import Control.Lens

import Crypto.Noise.Cipher
import Crypto.Noise.Exception

-- | State of a single Noise cipher: an optional symmetric key paired with
--   the nonce that will be used for the next keyed operation.
--
--   NOTE(review): the derived 'Show' instance renders the key via whatever
--   'Show' instance 'SymmetricKey' has; confirm that instance does not print
--   raw key material before logging a CipherState.
data CipherState c = CipherState
  { _csk :: Maybe (SymmetricKey c) -- ^ Symmetric key, or 'Nothing' if no key has been established.
  , _csn :: Nonce c                -- ^ Nonce for the next encryption/decryption.
  } deriving Show

-- Generates the @csk@ and @csn@ lenses used throughout this module.
makeLenses ''CipherState

-- | Builds a fresh CipherState from an optional symmetric key. The nonce
--   always starts at 'cipherZeroNonce'.
cipherState :: Cipher c
            => Maybe (SymmetricKey c)
            -> CipherState c
cipherState key =
  CipherState { _csk = key
              , _csn = cipherZeroNonce
              }

-- | Encrypts the provided plaintext, authenticating the associated data @ad@.
--
--   * If this CipherState has a key, the plaintext is encrypted with the
--     current nonce and the returned CipherState has its nonce incremented.
--     If the current nonce is not below 'cipherMaxNonce',
--     'MessageLimitReached' is thrown instead.
--
--   * If this CipherState has no key, the plaintext is returned unencrypted
--     (via 'cipherBytesToText') and the CipherState is returned unchanged.
--     The Noise specification only advances the nonce when a key is present.
encryptWithAd :: (MonadThrow m, Cipher c)
              => AssocData
              -> Plaintext
              -> CipherState c
              -> m (Ciphertext c, CipherState c)
encryptWithAd ad plaintext cs =
  case cs ^. csk of
    -- No key: pass the plaintext through and leave the nonce alone.
    Nothing -> pure (cipherBytesToText plaintext, cs)
    Just k
      | validNonce cs ->
          pure (cipherEncrypt k n ad plaintext, cs & csn %~ cipherIncNonce)
      | otherwise     -> throwM MessageLimitReached
  where
    n = cs ^. csn

-- | Decrypts the provided ciphertext, verifying the associated data @ad@.
--
--   * If this CipherState has a key, the ciphertext is decrypted with the
--     current nonce and the returned CipherState has its nonce incremented.
--     If decryption fails, 'DecryptionError' is thrown and no new state is
--     produced. If the current nonce is not below 'cipherMaxNonce',
--     'MessageLimitReached' is thrown instead.
--
--   * If this CipherState has no key, the ciphertext bytes are returned as-is
--     (via 'cipherTextToBytes') and the CipherState is returned unchanged.
--     The Noise specification only advances the nonce when a key is present.
decryptWithAd :: (MonadThrow m, Cipher c)
              => AssocData
              -> Ciphertext c
              -> CipherState c
              -> m (Plaintext, CipherState c)
decryptWithAd ad ct cs =
  case cs ^. csk of
    -- No key: pass the bytes through and leave the nonce alone.
    Nothing -> pure (cipherTextToBytes ct, cs)
    Just k
      | validNonce cs ->
          -- Only a successful decryption advances the nonce.
          maybe (throwM DecryptionError)
                (\pt -> pure (pt, cs & csn %~ cipherIncNonce))
                (cipherDecrypt k n ad ct)
      | otherwise     -> throwM MessageLimitReached
  where
    n = cs ^. csn

-- | Rekeys the CipherState by replacing its key with the result of
--   'cipherRekey'. If a key has not been established yet, the CipherState
--   is returned unmodified. The nonce is never changed.
rekey :: Cipher c
      => CipherState c
      -> CipherState c
rekey cs = cs & csk %~ fmap cipherRekey

-- | 'True' when the CipherState's nonce is strictly below 'cipherMaxNonce',
--   i.e. it may still be used for another encryption or decryption.
validNonce :: Cipher c
           => CipherState c
           -> Bool
validNonce CipherState { _csn = nonce } = nonce < cipherMaxNonce