module Crypto.TripleSec.Utils where
import Data.Monoid ((<>))
import Control.Monad (when)
import Data.Maybe
import Control.Monad.Except
import Crypto.Error
import Crypto.Cipher.Types hiding (Cipher)
import qualified Crypto.Cipher.XSalsa as XSalsa
import Crypto.TripleSec.Internal (ByteArray)
import qualified Crypto.TripleSec.Internal as I
import Crypto.TripleSec.Types
import Crypto.TripleSec.Constants
-- | Initialise a block cipher from key material the caller already trusts to
-- be acceptable to 'cipherInit' (presumably a key of the right size for @c@).
--
-- This is deliberately partial. If 'cipherInit' rejects the key, the
-- underlying 'CryptoError' is thrown via 'throwCryptoError'. The previous
-- @fromJust . maybeCryptoError@ formulation discarded that error and failed
-- with an uninformative @fromJust Nothing@ instead.
trustedCipherInit :: (ByteArray ba, BlockCipher c) => ba -> c
trustedCipherInit = throwCryptoError . cipherInit
-- | Create an XSalsa stream-cipher state using 20 rounds.
--
-- The two byte arrays are passed straight through to 'XSalsa.initialize'
-- in the order given: first the key, then the nonce, per that function's
-- argument order.
initXSalsa :: ByteArray ba => ba -> ba -> XSalsa.State
initXSalsa key nonce = XSalsa.initialize rounds key nonce
  where
    rounds :: Int
    rounds = 20
-- | Run @input@ through 'XSalsa.combine' with the given state and keep only
-- the combined output bytes.
--
-- The advanced state that 'XSalsa.combine' also returns is dropped, so each
-- call starts from the state it is handed.
xSalsaCombine :: ByteArray ba => XSalsa.State -> ba -> ba
xSalsaCombine state = fst . XSalsa.combine state
-- | Verify that the salt supplied (e.g. read from a ciphertext) equals the
-- 'passwordSalt' stored in the given 'TripleSec' value.
--
-- Throws @DecryptionException MisMatchedCipherSalt@ when they differ;
-- otherwise succeeds with @()@.
checkCipher :: (ByteArray ba, MonadError TripleSecException m)
            => TripleSec ba
            -> ba
            -> m ()
checkCipher cipher providedSalt
  | providedSalt == expectedSalt = pure ()
  | otherwise = throwError (DecryptionException MisMatchedCipherSalt)
  where
    expectedSalt = passwordSalt cipher
-- | Validate the header of a ciphertext and split it into its parts.
--
-- Runs 'checkLength', then 'checkMagicBytes', then 'checkVersionBytes',
-- stopping at the first check that throws. On success it returns whatever
-- 'checkVersionBytes' yields: @(prefix, salt, rest)@.
checkPrefix :: (ByteArray ba, MonadError TripleSecException m)
            => ba
            -> m (ba, ba, ba)
checkPrefix cipherText = do
  checkLength cipherText
  magicSplit <- checkMagicBytes cipherText
  checkVersionBytes magicSplit
-- | Accept a salt only if its length is exactly 'saltLen'.
--
-- Throws @CipherInitException InvalidSaltLength@ otherwise.
checkSalt :: (ByteArray ba, MonadError TripleSecException m)
          => ba
          -> m ()
checkSalt salt
  | I.length salt == saltLen = pure ()
  | otherwise = throwError (CipherInitException InvalidSaltLength)
-- | Require a ciphertext to be strictly longer than 'overhead'.
--
-- A ciphertext whose length is less than or equal to 'overhead' raises
-- @DecryptionException InvalidCipherTextLength@.
checkLength :: (ByteArray ba, MonadError TripleSecException m) => ba -> m ()
checkLength cipherText
  | tooShort = throwError (DecryptionException InvalidCipherTextLength)
  | otherwise = pure ()
  where
    tooShort = I.length cipherText <= overhead
-- | Split off the leading @length magicBytes@ bytes and check that they equal
-- 'packedMagicBytes'.
--
-- Throws @DecryptionException InvalidMagicBytes@ on a mismatch (including
-- when the input is too short to hold them). On success returns
-- @(magicBytesRead, remainder)@.
checkMagicBytes :: (ByteArray ba, MonadError TripleSecException m) => ba -> m (ba, ba)
checkMagicBytes cipherText = do
  when (header /= packedMagicBytes) $
    throwError (DecryptionException InvalidMagicBytes)
  pure (header, remainder)
  where
    (header, remainder) = I.splitAt magicLen cipherText
    magicLen = length magicBytes
-- | Given the magic bytes and the data following them, validate the version
-- field and peel off the salt.
--
-- Processing steps:
--
-- * The next @length versionBytes@ bytes must equal 'packedVersionBytes';
--   otherwise @DecryptionException InvalidVersion@ is thrown.
-- * The following 'saltLen' bytes are taken as the salt. No length check
--   happens here; callers such as 'checkPrefix' run 'checkLength' first.
--
-- On success returns @(prefix, salt, rest)@, where @prefix@ is magic bytes,
-- version bytes and salt concatenated in that order.
checkVersionBytes :: (ByteArray ba, MonadError TripleSecException m) => (ba, ba) -> m (ba, ba, ba)
checkVersionBytes (providedMagicBytes, lessMagicBytes)
  | versionField /= packedVersionBytes =
      throwError (DecryptionException InvalidVersion)
  | otherwise = pure (header, salt, body)
  where
    (versionField, afterVersion) = I.splitAt (length versionBytes) lessMagicBytes
    (salt, body) = I.splitAt saltLen afterVersion
    header = providedMagicBytes <> versionField <> salt