module Data.ByteString.Extended (
    module Data.ByteString
  , constTimeCompare
) where

import           Data.Bits
import           Data.ByteString
import qualified Data.List       as L
import           Prelude         hiding (length, zip, zipWith)

-- | Compare two 'ByteString's for equality without stopping at the first
-- differing byte.
--
-- If the lengths differ, the result is 'False' straight away. The running
-- time therefore does reveal whether the lengths match. When the lengths are
-- equal, every byte pair is XORed and the differences are OR-accumulated with
-- a strict left fold. The whole input is always walked, so the comparison has
-- no data-dependent early exit.
--
-- NOTE(review): this only avoids early exit at the source level. It does not
-- by itself guarantee constant-time machine code under GHC optimisation.
-- Confirm this before relying on it for security-critical checks
-- (e.g. MAC verification).
constTimeCompare :: ByteString -> ByteString -> Bool
constTimeCompare l r = length l == length r && sameBytes l r
  where
    -- Bitwise OR of the XOR of every byte pair: zero iff all pairs match.
    -- 'zipWith' here is Data.ByteString's, yielding a list of XORed bytes.
    sameBytes :: ByteString -> ByteString -> Bool
    sameBytes a b = L.foldl' (.|.) 0 (zipWith xor a b) == 0