module Crypto.Noise.Internal.SymmetricState
(
SymmetricState(SymmetricState),
ssCipher,
ssHasKey,
ssHasPSK,
ssh,
symmetricState,
mixKey,
mixPSK,
mixHash,
encryptAndHash,
decryptAndHash,
split
) where
import Control.Lens
import Data.ByteArray as BA (length, replicate)
import Data.ByteString (empty)
import Data.Proxy
import Crypto.Noise.Cipher
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Types
-- | Handshake-time symmetric state: the current cipher state, flags
-- recording which kinds of key material have been mixed in, the chaining
-- key, and the running handshake hash.
data SymmetricState c h = SymmetricState
  { _ssCipher :: CipherState c
    -- ^ Cipher state. 'symmetricState' fills it with a placeholder; it is
    -- only used by 'encryptAndHash' / 'decryptAndHash' once '_ssHasKey'
    -- is 'True'.
  , _ssHasKey :: Bool
    -- ^ 'False' initially; set to 'True' by 'mixKey'.
  , _ssHasPSK :: Bool
    -- ^ 'False' initially; set to 'True' by 'mixPSK'.
  , _ssck     :: ChainingKey h
    -- ^ Chaining key, replaced by the first 'hashHKDF' output in
    -- 'mixKey' and 'mixPSK'.
  , _ssh      :: Either ScrubbedBytes (Digest h)
    -- ^ Handshake hash. 'Left' holds the zero-padded protocol name when
    -- it was short enough to be used unhashed; 'Right' holds a digest.
  }

$(makeLenses ''SymmetricState)
-- | Build the initial 'SymmetricState' from a handshake (protocol) name.
--
-- If the name is longer than the hash output length it is hashed and
-- stored as a digest ('Right'). Otherwise it is right-padded with zero
-- bytes up to exactly the hash output length and stored raw ('Left').
-- The chaining key starts out as those same bytes.
--
-- No cipher key exists yet: both flags start 'False', and the cipher state
-- is a placeholder that errors if forced. 'encryptAndHash' and
-- 'decryptAndHash' only touch it once 'mixKey' has set '_ssHasKey'.
symmetricState :: forall c h. (Cipher c, Hash h) => ScrubbedBytes -> SymmetricState c h
symmetricState hsn = SymmetricState cs False False ck hsn'
  where
    hashLen    = hashLength (Proxy :: Proxy h)
    nameLen    = BA.length hsn
    shouldHash = nameLen > hashLen
    -- Short names are zero-padded by (hashLen - nameLen) bytes so the
    -- stored value is exactly hashLen bytes long.
    hsn'
      | shouldHash = Right $ hash hsn
      | otherwise  = Left $ hsn `append` BA.replicate (hashLen - nameLen) 0
    ck = hashBytesToCK . sshBytes $ hsn'
    -- Placeholder key and nonce; replaced wholesale by 'mixKey'.
    cs = CipherState
           (error "Crypto.Noise.Internal.SymmetricState.symmetricState: cipher key used before mixKey")
           (error "Crypto.Noise.Internal.SymmetricState.symmetricState: nonce used before mixKey")
-- | Mix input key material into the state.
--
-- Runs 'hashHKDF' over the current chaining key and the input. The first
-- output becomes the new chaining key. The second output keys a fresh
-- cipher state starting at 'cipherZeroNonce'. The state is then marked
-- as keyed.
mixKey :: (Cipher c, Hash h) => ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixKey d ss =
  ss { _ssCipher = CipherState (cipherBytesToSym newKey) cipherZeroNonce
     , _ssHasKey = True
     , _ssck     = newCK
     }
  where
    (newCK, newKey) = hashHKDF (ss ^. ssck) d
-- | Mix a pre-shared key into the state.
--
-- Runs 'hashHKDF' over the current chaining key and the PSK. The first
-- output replaces the chaining key. The second output is absorbed into
-- the handshake hash via 'mixHash'. Finally the PSK flag is set.
mixPSK :: (Cipher c, Hash h) => ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixPSK psk ss = set ssHasPSK True (mixHash extra (set ssck newCK ss))
  where
    (newCK, extra) = hashHKDF (ss ^. ssck) psk
-- | Absorb data into the handshake hash.
--
-- The new hash is @hash (current_hash_bytes <> d)@ and is always stored
-- as a digest ('Right'), whatever form the previous value had.
mixHash :: (Cipher c, Hash h) => ScrubbedBytes -> SymmetricState c h -> SymmetricState c h
mixHash d ss = over ssh absorb ss
  where
    absorb current = Right (hash (sshBytes current `append` d))
-- | Encrypt (when keyed) and absorb the result into the handshake hash.
--
-- When a cipher key is present:
--
-- * the plaintext is encrypted with the current handshake hash as
--   associated data;
-- * the updated cipher state returned by 'encryptAndIncrement' is stored;
-- * the ciphertext bytes are mixed into the hash and returned.
--
-- Without a key, the plaintext bytes are returned unchanged and mixed
-- into the hash.
encryptAndHash :: (Cipher c, Hash h) => Plaintext -> SymmetricState c h -> (ScrubbedBytes, SymmetricState c h)
encryptAndHash (Plaintext pt) ss =
  if ss ^. ssHasKey
    then (ctBytes, mixHash ctBytes ss & ssCipher .~ cs')
    else (pt, mixHash pt ss)
  where
    ad        = AssocData (sshBytes (ss ^. ssh))
    -- Only forced in the keyed branch, so the placeholder cipher state
    -- from 'symmetricState' is never touched.
    (ct, cs') = encryptAndIncrement ad (Plaintext pt) (ss ^. ssCipher)
    ctBytes   = cipherTextToBytes ct
-- | Decrypt (when keyed) and absorb the ciphertext into the handshake hash.
--
-- In both cases the ciphertext bytes, not the plaintext, are mixed into
-- the hash.
--
-- * When a cipher key is present, the ciphertext is decrypted with the
--   current handshake hash as associated data, and the updated cipher
--   state from 'decryptAndIncrement' is stored.
-- * Without a key, the ciphertext bytes are returned as the plaintext.
decryptAndHash :: (Cipher c, Hash h) => Ciphertext c -> SymmetricState c h -> (Plaintext, SymmetricState c h)
decryptAndHash ct ss =
  if ss ^. ssHasKey
    then (pt, hashed & ssCipher .~ cs')
    else (Plaintext ctBytes, hashed)
  where
    ctBytes   = cipherTextToBytes ct
    hashed    = mixHash ctBytes ss
    -- Lazily bound: evaluated only in the keyed branch.
    (pt, cs') = decryptAndIncrement (AssocData (sshBytes (ss ^. ssh))) ct (ss ^. ssCipher)
-- | Derive the pair of transport cipher states at the end of a handshake.
--
-- Runs 'hashHKDF' over the chaining key with empty input. Each output is
-- turned into a cipher key with a zero nonce. The first output is a
-- 'ChainingKey' and is converted to bytes first.
split :: (Cipher c, Hash h) => SymmetricState c h -> (CipherState c, CipherState c)
split ss = (mkState (hashCKToBytes firstK), mkState secondK)
  where
    (firstK, secondK) = hashHKDF (ss ^. ssck) (convert empty)
    mkState k         = CipherState (cipherBytesToSym k) cipherZeroNonce
-- | Raw bytes of the handshake hash, whether it is still the (padded)
-- protocol name ('Left') or a digest ('Right').
sshBytes :: Hash h => Either ScrubbedBytes (Digest h) -> ScrubbedBytes
sshBytes = either id hashToBytes