{-# LANGUAGE OverloadedStrings, ViewPatterns #-}
module Crypto.Encoding.PHKDF where
import Data.Bits(Bits, (.&.), shift)
import Data.ByteString(ByteString)
import Data.Int(Int64)
import Data.List(scanl')
import qualified Data.ByteString as B
import Debug.Trace
-- | Pieces whose concatenation is the extended form of a tag.
--
-- A tag of at most 19 bytes is returned as the single piece @[tag]@.
-- A longer tag of length @n@ is handled in two steps:
--
-- * The tag is repeated with a @0x00@ byte after each copy, and that stream
--   is cut off after @n + x@ bytes, where @x = (19 - n) `mod` 64@. This makes
--   the padded prefix 19 bytes modulo 64.
-- * One trailer byte holding @x@ shifted left by 2 bits is appended, so the
--   whole result is 20 bytes modulo 64.
extendTagToList :: ByteString -> [ByteString]
extendTagToList tag
  | tagLen <= 19 = [tag]
  | otherwise = padded ++ [trailer]
  where
    tagLen = B.length tag
    padLen = (19 - tagLen) `mod` 64
    padded = takeBs (fromIntegral (tagLen + padLen)) (cycle [tag, "\x00"])
    -- padLen < 64, so padLen * 4 still fits in the single trailer byte
    trailer = B.singleton (fromIntegral padLen `shift` 2)
-- | Extended form of a tag as one 'ByteString'.
--
-- This is the concatenation of the pieces produced by 'extendTagToList'.
extendTag :: ByteString -> ByteString
extendTag tag = B.concat (extendTagToList tag)
-- | Recover the original tag from its extended form (the inverse of
-- 'extendTag').
--
-- * An input of at most 19 bytes is returned unchanged.
-- * Otherwise the last byte, shifted right by 2 bits, is read as the pad
--   length @x@. The candidate tag is the input minus its last @x + 1@ bytes.
-- * The candidate is accepted only if re-extending it reproduces the input
--   exactly; any mismatch yields 'Nothing'.
trimExtendedTag :: ByteString -> Maybe ByteString
trimExtendedTag extTag
  | extLen <= 19 = Just extTag
  | extendTag candidate == extTag = Just candidate
  | otherwise = Nothing
  where
    extLen = B.length extTag
    -- only demanded when extLen > 19, so B.last never sees an empty string
    padLen = fromIntegral (B.last extTag) `shift` (-2) :: Int
    candidate = B.take (extLen - padLen - 1) extTag
-- | Raise the first argument by whole multiples of 64 until it is at least
-- the second argument.
--
-- A first argument that already reaches the bound is returned unchanged.
-- Otherwise the result is the smallest value @>= bound@ that is congruent to
-- @value@ modulo 64. This assumes two's-complement 'Bits' semantics for the
-- @.&. 63@ on a possibly negative or wrapped difference.
add64WhileLt :: (Ord a, Num a, Bits a) => a -> a -> a
add64WhileLt value bound
  | value < bound = bound + offset
  | otherwise = value
  where
    -- (value - bound) mod 64, taken from the low six bits
    offset = (value - bound) .&. 63
-- | Tracing variant of 'add64WhileLt', intended as a debugging aid.
--
-- It returns the same result. When the first argument has to be raised, it
-- also emits a message via 'trace' of the form @value + k * 64 == d == bound + r@.
add64WhileLt' :: (Ord a, Num a, Bits a, Show a) => a -> a -> a
add64WhileLt' value bound =
  if value >= bound
    then value
    else trace report raised
  where
    raised = bound + ((value - bound) .&. 63)
    -- how many 64s were added to value
    multiples = (raised - value) `shift` (-6)
    report =
      concat
        [ show value, " + ", show multiples, " * 64 == "
        , show raised, " == ", show bound, " + ", show (raised - bound)
        ]
-- | Drop the first @n@ bytes from a list of chunks.
--
-- Whole chunks are discarded while the count covers them; the chunk in which
-- the cut falls is trimmed. A count of zero returns the list unchanged,
-- including any leading empty chunks.
dropBs :: Int64 -> [ByteString] -> [ByteString]
dropBs n chunks = case chunks of
  [] -> []
  (chunk : rest)
    | n == 0 -> chunks
    | n >= size -> dropBs (n - size) rest
    | otherwise -> B.drop (fromIntegral n) chunk : rest
    where
      size = fromIntegral (B.length chunk)
-- | Take the first @n@ bytes from a list of chunks.
--
-- The result is a list of chunks totalling at most @n@ bytes. The final
-- chunk is truncated if necessary, and a non-positive @n@ yields @[]@.
-- Input is consumed only until @n@ bytes have been collected, so an infinite
-- list (e.g. from 'cycle') is fine as long as its chunks eventually supply
-- the bytes.
takeBs :: Int64 -> [ByteString] -> [ByteString]
takeBs budget = collect budget
  where
    collect _ [] = []
    collect remaining (chunk : rest)
      | remaining <= 0 = []
      | size < remaining = chunk : collect (remaining - size) rest
      | otherwise = [B.take (fromIntegral remaining) chunk]
      where
        size = fromIntegral (B.length chunk)
-- | All-or-nothing version of 'takeBs'.
--
-- The first @n@ bytes are returned only when the chunks hold at least @n@
-- bytes in total; otherwise the result is @[]@.
takeBs' :: Int64 -> [ByteString] -> [ByteString]
takeBs' n bs
  | enough = takeBs n bs
  | otherwise = []
  where
    -- prefix sums of chunk lengths, starting from 0; 'any' stops at the
    -- first prefix that reaches n
    runningTotals = scanl' (\acc chunk -> acc + fromIntegral (B.length chunk)) 0 bs
    enough = any (>= n) runningTotals
-- | The first @n@ bytes of a 'ByteString'.
--
-- Returns 'Nothing' when the input is shorter than @n@ bytes.
takeB' :: Int64 -> ByteString -> Maybe ByteString
takeB' n bs
  | available < n = Nothing
  | otherwise = Just (B.take (fromIntegral n) bs)
  where
    available = fromIntegral (B.length bs)
-- | Like 'takeB'', but partial.
--
-- Calls 'error' with @"not enough bytes"@ when the input is shorter than the
-- requested length.
assertTakeB' :: Int64 -> ByteString -> ByteString
assertTakeB' n bs = case takeB' n bs of
  Just prefix -> prefix
  Nothing -> error "not enough bytes"
-- | 64 zero bytes.
nullBuffer :: ByteString
nullBuffer = B.pack (replicate 64 0)
-- | Split a 'ByteString' into consecutive chunks of @n@ bytes.
--
-- The last chunk holds whatever remains and may be shorter. Empty input
-- yields @[]@.
--
-- A non-positive chunk size cannot make progress: @B.splitAt n@ would return
-- an empty prefix and the unchanged input forever, producing an infinite list
-- of empty chunks. In that case the whole input is returned as a single chunk
-- instead.
chunkify :: Int -> ByteString -> [ByteString]
chunkify n
  | n <= 0 = \bs -> [bs | not (B.null bs)]
  | otherwise = go
  where
    go bs
      | B.null bs = []
      | otherwise =
          let (chunk, rest) = B.splitAt n bs
           in chunk : go rest
-- | Infinite list of @len@-byte windows over a cyclic stream.
--
-- The stream is @bs@ followed by a single @0x00@ byte, repeated forever.
-- The third argument is the starting offset into that stream. Offsets are
-- reduced modulo @B.length bs + 1@ (the length of one period), and each
-- window starts @len@ bytes after the previous one.
chunkifyCycle :: Int64 -> ByteString -> Int64 -> [ByteString]
chunkifyCycle len bs = windowsFrom
  where
    period = fromIntegral (B.length bs) + 1
    -- One full copy of bs followed by len more bytes of the cycle. For
    -- non-negative len, this means a window starting anywhere in the first
    -- period fits, so assertTakeB' does not fail.
    ext = B.concat (bs : takeBs len (cycle ["\x00", bs]))
    windowsFrom start =
      let pos = start `mod` period
       in assertTakeB' len (B.drop (fromIntegral pos) ext) : windowsFrom (pos + len)