module Crypto.Noise.Internal.SymmetricState where
import Control.Exception.Safe
import Control.Lens
import Data.ByteArray (ScrubbedBytes, length, replicate)
import Data.Proxy
import Prelude hiding (length, replicate)
import Crypto.Noise.Cipher
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
-- | State carried through a Noise handshake: a cipher state, a chaining
-- key, and the handshake hash.
--
-- The handshake hash is stored as 'Left' raw bytes while it still holds
-- the zero-padded protocol name produced by 'symmetricState', and as
-- 'Right' a digest once it has been computed by hashing (every 'mixHash'
-- stores a 'Right').
data SymmetricState c h = SymmetricState
  { _ssCipher :: CipherState c
    -- ^ Cipher state used by 'encryptAndHash' / 'decryptAndHash'.
  , _ssck     :: ChainingKey h
    -- ^ Chaining key, replaced by 'mixKey' and 'mixKeyAndHash'.
  , _ssh      :: Either ScrubbedBytes (Digest h)
    -- ^ Handshake hash (raw padded bytes or a digest, see above).
  }

$(makeLenses ''SymmetricState)
-- | Build the initial 'SymmetricState' for a protocol name, mirroring the
-- Noise @InitializeSymmetric@ operation.
--
-- * If the protocol name is at most @hashLength@ bytes long, the handshake
--   hash is the name right-padded with zero bytes to exactly @hashLength@
--   bytes (stored as 'Left').
-- * Otherwise the handshake hash is the hash of the name (stored as
--   'Right').
--
-- The chaining key is initialised from those same handshake-hash bytes,
-- and the cipher state starts without a key.
symmetricState :: forall c h. (Cipher c, Hash h)
               => ScrubbedBytes
               -> SymmetricState c h
symmetricState protoName = SymmetricState cs ck h
  where
    hashLen    = hashLength (Proxy :: Proxy h)
    nameLen    = length protoName
    shouldHash = nameLen > hashLen
    -- Pad short names with (hashLen - nameLen) zero bytes. This is only
    -- used when nameLen <= hashLen, so the count is never negative.
    padded     = protoName `mappend` replicate (hashLen - nameLen) 0
    h          = if shouldHash
                   then Right $ hash protoName
                   else Left padded
    ck         = hashBytesToCK . sshBytes $ h
    cs         = cipherState Nothing
-- | Mix input key material into the state (Noise @MixKey@).
--
-- Derives two outputs with 'hashHKDF' from the current chaining key and
-- @keyMat@. The first becomes the new chaining key; the second keys a
-- fresh 'CipherState'. The handshake hash is not touched.
mixKey :: (Cipher c, Hash h)
       => ScrubbedBytes
       -> SymmetricState c h
       -> SymmetricState c h
mixKey keyMat ss =
  let [newCK, tempK] = hashHKDF (ss ^. ssck) keyMat 2
      newCipher      = cipherState (Just (cipherBytesToSym tempK))
  in  set ssck (hashBytesToCK newCK) (set ssCipher newCipher ss)
-- | Absorb @d@ into the handshake hash (Noise @MixHash@).
--
-- The new handshake hash is the hash of the current handshake-hash bytes
-- followed by @d@. It is always stored as a 'Right' digest.
mixHash :: Hash h
        => ScrubbedBytes
        -> SymmetricState c h
        -> SymmetricState c h
mixHash d ss = over ssh absorb ss
  where
    absorb prev = Right (hash (sshBytes prev `mappend` d))
-- | Mix input key material into both the chaining key and the handshake
-- hash (Noise @MixKeyAndHash@).
--
-- Derives three outputs with 'hashHKDF' from the current chaining key and
-- @keyMat@:
--
-- 1. the new chaining key,
-- 2. bytes fed to 'mixHash',
-- 3. the key for a fresh 'CipherState'.
mixKeyAndHash :: (Cipher c, Hash h)
              => ScrubbedBytes
              -> SymmetricState c h
              -> SymmetricState c h
mixKeyAndHash keyMat ss =
  let [newCK, tempH, tempK] = hashHKDF (ss ^. ssck) keyMat 3
      hashed                = mixHash tempH ss
  in  hashed & ssck     .~ hashBytesToCK newCK
             & ssCipher .~ cipherState (Just (cipherBytesToSym tempK))
-- | Encrypt @pt@ and fold the resulting ciphertext into the handshake hash
-- (Noise @EncryptAndHash@).
--
-- The current handshake-hash bytes, taken before mixing, are passed as
-- associated data to 'encryptWithAd'. The returned state carries the
-- updated cipher state, and its handshake hash has absorbed the
-- ciphertext. Errors raised by 'encryptWithAd' propagate through @m@.
encryptAndHash :: (MonadThrow m, Cipher c, Hash h)
               => Plaintext
               -> SymmetricState c h
               -> m (Ciphertext c, SymmetricState c h)
encryptAndHash pt ss = do
  let ad = sshBytes (ss ^. ssh)
  (ct, cs') <- encryptWithAd ad pt (ss ^. ssCipher)
  return (ct, set ssCipher cs' (mixHash (cipherTextToBytes ct) ss))
-- | Decrypt @ct@ and fold the ciphertext into the handshake hash (Noise
-- @DecryptAndHash@).
--
-- The current handshake-hash bytes, taken before mixing, are passed as
-- associated data to 'decryptWithAd'. The ciphertext, not the plaintext,
-- is mixed into the hash, so both sides of a handshake stay in sync.
-- Errors raised by 'decryptWithAd' propagate through @m@.
decryptAndHash :: (MonadThrow m, Cipher c, Hash h)
               => Ciphertext c
               -> SymmetricState c h
               -> m (Plaintext, SymmetricState c h)
decryptAndHash ct ss = do
  let ad = sshBytes (ss ^. ssh)
  (pt, cs') <- decryptWithAd ad ct (ss ^. ssCipher)
  return (pt, set ssCipher cs' (mixHash (cipherTextToBytes ct) ss))
-- | Derive the two transport cipher states once the handshake is complete
-- (Noise @Split@).
--
-- Runs 'hashHKDF' over the chaining key with empty input key material to
-- obtain two outputs. Each output keys one 'CipherState', returned in
-- derivation order.
split :: (Cipher c, Hash h)
      => SymmetricState c h
      -> (CipherState c, CipherState c)
split ss = (keyed k1, keyed k2)
  where
    [k1, k2] = hashHKDF (ss ^. ssck) mempty 2
    keyed k  = cipherState (Just (cipherBytesToSym k))
-- | Raw bytes of a handshake hash.
--
-- A 'Left' value (padded protocol name) is returned unchanged; a 'Right'
-- digest is converted with 'hashToBytes'.
sshBytes :: Hash h
         => Either ScrubbedBytes (Digest h)
         -> ScrubbedBytes
sshBytes = either id hashToBytes