{-# LANGUAGE BangPatterns #-}
module Crypto.Cipher.Xtea
( SymmetricKey (..)
, Endianness (..)
, EncryptionError (..)
, encryptBlock
, encrypt
, encrypt'
, DecryptionError (..)
, decryptBlock
, decrypt
, decrypt'
) where
import Control.Monad ( replicateM )
import Data.Binary.Get ( Get, getWord32be, getWord32le, runGetOrFail )
import Data.Binary.Put ( Put, putWord32be, putWord32le, runPut )
import Data.Bits ( shiftL, shiftR, xor, (.&.) )
import Data.ByteString ( ByteString )
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Word ( Word32 )
import Prelude hiding ( sum )
-- | A 128-bit XTEA key, stored as four strict, unpacked 32-bit words.
--
-- 'unsafeGetSymmetricKeyBlock' selects these words by index 0 to 3, in the
-- order the fields appear below.
data SymmetricKey
  = SymmetricKey
      {-# UNPACK #-} !Word32 -- ^ key word 0
      {-# UNPACK #-} !Word32 -- ^ key word 1
      {-# UNPACK #-} !Word32 -- ^ key word 2
      {-# UNPACK #-} !Word32 -- ^ key word 3
-- | Return key word @i@ of a 'SymmetricKey'. Index 0 is the first
-- constructor field and index 3 is the last.
--
-- This function is partial: any index outside @0..3@ calls 'error'. The
-- encryption and decryption loops in this module always mask the index
-- with @.&. 3@ before calling it, so that branch should not be reached.
unsafeGetSymmetricKeyBlock :: SymmetricKey -> Word32 -> Word32
unsafeGetSymmetricKeyBlock (SymmetricKey w0 w1 w2 w3) i
  | i == 0 = w0
  | i == 1 = w1
  | i == 2 = w2
  | i == 3 = w3
  | otherwise =
      error ("impossible: requested index " <> show i <> " is out of range")
-- | Size of one XTEA block in bytes: two 32-bit words of 4 bytes each.
xteaBlockSize :: Int
xteaBlockSize = 2 * 4
-- | XTEA key-schedule constant. The round loops add it to, or subtract it
-- from, the running sum once per iteration. It is the standard XTEA value
-- @floor (2^32 / phi)@, where @phi@ is the golden ratio.
delta :: Word32
delta = 0x9e3779b9
-- | Number of loop iterations 'encryptBlock' and 'decryptBlock' run per
-- block. Each iteration updates both 32-bit halves once.
rounds :: Word32
rounds = 32
-- | Byte order used both to read input bytes as 32-bit words and to write
-- the resulting words back out.
data Endianness
  = -- | Least-significant byte first.
    LittleEndian
  | -- | Most-significant byte first.
    BigEndian
  deriving stock (Eq, Show)
-- | Decode a strict 'ByteString' into XTEA blocks. Each block is a pair of
-- 32-bit words read with the given 'Endianness': first word, then second.
--
-- Returns 'Nothing' when the input length is not a multiple of
-- 'xteaBlockSize', or if the binary decoder reports a failure. An empty
-- input decodes to @Just []@.
byteStringToXteaBlocks :: Endianness -> ByteString -> Maybe [(Word32, Word32)]
byteStringToXteaBlocks endianness bs
  | leftover /= 0 = Nothing
  | otherwise =
      either (const Nothing) (\(_, _, blocks) -> Just blocks) $
        runGetOrFail (replicateM blockCount readBlock) (LBS.fromStrict bs)
  where
    -- Number of whole blocks, plus any trailing bytes too few to fill one.
    blockCount, leftover :: Int
    (blockCount, leftover) = BS.length bs `divMod` xteaBlockSize

    readWord :: Endianness -> Get Word32
    readWord LittleEndian = getWord32le
    readWord BigEndian = getWord32be

    readBlock :: Get (Word32, Word32)
    readBlock = do
      w0 <- readWord endianness
      w1 <- readWord endianness
      pure (w0, w1)
-- | Encode XTEA blocks into a strict 'ByteString'. Each pair's first word
-- is written before its second word, using the given 'Endianness'. This is
-- the same word order 'byteStringToXteaBlocks' reads.
xteaBlocksToByteString :: Endianness -> [(Word32, Word32)] -> ByteString
xteaBlocksToByteString endianness blocks =
  LBS.toStrict . runPut $ mapM_ writeBlock blocks
  where
    writeWord :: Endianness -> Word32 -> Put
    writeWord LittleEndian = putWord32le
    writeWord BigEndian = putWord32be

    writeBlock :: (Word32, Word32) -> Put
    writeBlock (w0, w1) = writeWord endianness w0 *> writeWord endianness w1
-- | Encrypt one 64-bit block, given as two 32-bit halves, with XTEA.
--
-- Runs 'rounds' iterations. The running sum starts at 0 and grows by
-- 'delta' each iteration. In each iteration:
--
-- * the first half is updated from the second half and key word
--   @sum .&. 3@;
-- * the second half is then updated from the new first half and a key word
--   chosen from the incremented sum, shifted right by 11 and masked with 3.
--
-- All arithmetic wraps modulo 2^32. 'decryptBlock' reverses these steps.
encryptBlock :: SymmetricKey -> (Word32, Word32) -> (Word32, Word32)
encryptBlock key (startingV0, startingV1) = loop rounds 0 startingV0 startingV1
  where
    -- Mixing function applied to one half before it is combined with the
    -- key material.
    mix :: Word32 -> Word32
    mix v = ((v `shiftL` 4) `xor` (v `shiftR` 5)) + v

    loop :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32)
    loop 0 _ !a !b = (a, b)
    loop !remaining !s !a !b =
      let a' = a + (mix b `xor` (s + unsafeGetSymmetricKeyBlock key (s .&. 3)))
          s' = s + delta
          b' = b + (mix a' `xor` (s' + unsafeGetSymmetricKeyBlock key ((s' `shiftR` 11) .&. 3)))
       in loop (remaining - 1) s' a' b'
-- | Reasons XTEA encryption of a 'ByteString' can fail.
data EncryptionError
  = -- | The input could not be split into whole 8-byte blocks. The field
    -- holds the input length in bytes.
    EncryptionInvalidInputLengthError !Int
  deriving stock (Eq, Show)
-- | Encrypt a 'ByteString' with XTEA, reading and writing 32-bit words in
-- the given 'Endianness'. Each block is encrypted independently with
-- 'encryptBlock'; no state is carried from one block to the next.
--
-- Returns 'EncryptionInvalidInputLengthError', carrying the input length,
-- when the input cannot be split into whole blocks (its length is not a
-- multiple of 'xteaBlockSize').
encrypt' :: Endianness -> SymmetricKey -> ByteString -> Either EncryptionError ByteString
encrypt' endianness key plaintext =
  maybe
    (Left (EncryptionInvalidInputLengthError (BS.length plaintext)))
    (Right . xteaBlocksToByteString endianness . map (encryptBlock key))
    (byteStringToXteaBlocks endianness plaintext)
-- | Encrypt a 'ByteString' with XTEA using 'BigEndian' word order.
-- Equivalent to @encrypt' BigEndian@.
encrypt :: SymmetricKey -> ByteString -> Either EncryptionError ByteString
encrypt key plaintext = encrypt' BigEndian key plaintext
-- | Decrypt one 64-bit block, given as two 32-bit halves, with XTEA.
--
-- This reverses 'encryptBlock'. The running sum starts at
-- @delta * rounds@, which is the value encryption ends with (modulo 2^32),
-- and shrinks by 'delta' each iteration. In each iteration:
--
-- * the second half is un-mixed first, using the first half and a key word
--   chosen from the current sum, shifted right by 11 and masked with 3;
-- * the first half is then un-mixed using the recovered second half and
--   key word @sum .&. 3@ of the decremented sum.
decryptBlock :: SymmetricKey -> (Word32, Word32) -> (Word32, Word32)
decryptBlock key (startingV0, startingV1) =
  loop rounds (delta * rounds) startingV0 startingV1
  where
    -- Same mixing function as used during encryption.
    mix :: Word32 -> Word32
    mix v = ((v `shiftL` 4) `xor` (v `shiftR` 5)) + v

    loop :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32)
    loop 0 _ !a !b = (a, b)
    loop !remaining !s !a !b =
      let b' = b - (mix a `xor` (s + unsafeGetSymmetricKeyBlock key ((s `shiftR` 11) .&. 3)))
          s' = s - delta
          a' = a - (mix b' `xor` (s' + unsafeGetSymmetricKeyBlock key (s' .&. 3)))
       in loop (remaining - 1) s' a' b'
-- | Reasons XTEA decryption of a 'ByteString' can fail.
data DecryptionError
  = -- | The input could not be split into whole 8-byte blocks. The field
    -- holds the input length in bytes.
    DecryptionInvalidInputLengthError !Int
  deriving stock (Eq, Show)
-- | Decrypt a 'ByteString' with XTEA, reading and writing 32-bit words in
-- the given 'Endianness'. Each block is decrypted independently with
-- 'decryptBlock'; no state is carried from one block to the next.
--
-- Returns 'DecryptionInvalidInputLengthError', carrying the input length,
-- when the input cannot be split into whole blocks (its length is not a
-- multiple of 'xteaBlockSize').
decrypt' :: Endianness -> SymmetricKey -> ByteString -> Either DecryptionError ByteString
decrypt' endianness key ciphertext =
  maybe
    (Left (DecryptionInvalidInputLengthError (BS.length ciphertext)))
    (Right . xteaBlocksToByteString endianness . map (decryptBlock key))
    (byteStringToXteaBlocks endianness ciphertext)
-- | Decrypt a 'ByteString' with XTEA using 'BigEndian' word order.
-- Equivalent to @decrypt' BigEndian@.
decrypt :: SymmetricKey -> ByteString -> Either DecryptionError ByteString
decrypt key ciphertext = decrypt' BigEndian key ciphertext