{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE Haskell2010   #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE UnboxedTuples #-}

{-# LANGUAGE Trustworthy   #-}

-- |
-- Copyright: © 2020  Herbert Valerio Riedel
-- SPDX-License-Identifier: GPL-2.0-or-later
--
-- Apply XOR-masks to 'BS.ByteString's and memory regions.
--
module Data.XOR
    ( -- * Apply 32-bit XOR mask
      xor32StrictByteString
    , xor32StrictByteString'
    , xor32LazyByteString
    , xor32ShortByteString
    , xor32CStringLen

      -- * Apply 8-bit XOR mask
    , xor8StrictByteString
    , xor8LazyByteString
    , xor8ShortByteString
    , xor8CStringLen

    ) where

-- base
import           Control.Exception              (assert)
import           Control.Monad                  (void)
import           Control.Monad.ST               (ST, runST)
import           Data.Bits
import           Data.Tuple                     (swap)
import           Endianness                     (ByteOrder (..), Word32, Word8, byteSwap32,
                                                 targetByteOrder)
import           Foreign.C                      (CStringLen)
import           Foreign.ForeignPtr             (withForeignPtr)
import           Foreign.Marshal.Utils          (copyBytes)
import           Foreign.Ptr                    (Ptr, alignPtr, castPtr, minusPtr, plusPtr)
import           Foreign.Storable               (peek, poke)
import           System.IO.Unsafe               (unsafeDupablePerformIO)

import qualified GHC.Exts                       as X
import qualified GHC.ST                         as X
import qualified GHC.Word                       as X

-- bytestring
import qualified Data.ByteString                as BS
import           Data.ByteString.Internal       (mallocByteString)
import qualified Data.ByteString.Internal       as BS (ByteString (..))
import qualified Data.ByteString.Lazy.Internal  as BL (ByteString (..))
import qualified Data.ByteString.Short          as SBS
import           Data.ByteString.Short.Internal (ShortByteString (SBS))

----------------------------------------------------------------------------

{- high-level reference impl

-- about 6-7 times slower
xor32StrictByteString'ref :: Word32 -> BS.ByteString -> BS.ByteString
xor32StrictByteString'ref 0    = id
xor32StrictByteString'ref msk0 = snd . BS.mapAccumL go msk0
  where
    go :: Word32 -> Word8 -> (Word32,Word8)
    go msk b = let b'   = fromIntegral msk' `xor` b
                   msk' = rotateL msk 8
               in b' `seq` (msk',b')

-- about 3 times slower
xor8StrictByteString'ref :: Word8 -> BS.ByteString -> BS.ByteString
xor8StrictByteString'ref 0    = id
xor8StrictByteString'ref msk0 = BS.map (xor msk0)

-}

-- | Apply 32-bit XOR mask (considered as four octets in big-endian order) to 'BS.ByteString'.
--
-- >>> xor32StrictByteString 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- "Hello"
--
-- In other words, the 32-bit word @0x37fa213d@ is taken as the infinite series of octets @('cycle' [0x37,0xfa,0x21,0x3d])@ and 'xor'ed with the respective octets from the input 'BS.ByteString'.
--
-- The 'xor' laws give rise to the following laws:
--
-- prop> xor32StrictByteString m (xor32StrictByteString m x) == x
--
-- prop> xor32StrictByteString 0 x == x
--
-- prop> xor32StrictByteString m (xor32StrictByteString n x) == xor32StrictByteString (m `xor` n) x
--
-- This function is semantically equivalent to the (less efficient) implementation shown below
--
-- > xor32StrictByteString'ref :: Word32 -> BS.ByteString -> BS.ByteString
-- > xor32StrictByteString'ref 0    = id
-- > xor32StrictByteString'ref msk0 = snd . BS.mapAccumL go msk0
-- >   where
-- >     go :: Word32 -> Word8 -> (Word32,Word8)
-- >     go msk b = let b'   = fromIntegral (msk' .&. 0xff) `xor` b
-- >                    msk' = rotateL msk 8
-- >                in (msk',b')
--
-- The 'xor32StrictByteString' implementation is about 6-7 times faster than the naive implementation above.
xor32StrictByteString :: Word32 -> BS.ByteString -> BS.ByteString
-- A zero mask is the identity, so the input is returned without copying.
xor32StrictByteString 0 bs = bs
xor32StrictByteString msk bs
  -- Nothing to mask: hand back the (empty) input as-is.
  | BS.null bs = bs
  -- General case: mask a fresh copy and drop the rotated mask.
  | otherwise  = fst (xor32StrictByteString'' msk bs)

-- | Convenience version of 'xor32StrictByteString' which also returns the rotated XOR-mask useful for chained masking.
--
-- The returned mask is the input mask rotated left by one octet for every
-- octet consumed; for the five-octet input below that is @0xfa213d37@
-- (shown in decimal by the 'Show' instance of 'Word32').
--
-- >>> xor32StrictByteString' 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- (4196482359,"Hello")
--
xor32StrictByteString' :: Word32 -> BS.ByteString -> (Word32,BS.ByteString)
-- A zero mask stays zero under rotation and leaves the data untouched.
xor32StrictByteString' 0 bs = (0, bs)
xor32StrictByteString' msk bs
  -- No octets consumed, so the mask is returned unrotated.
  | BS.null bs = (msk, bs)
  -- The worker yields (bytes, mask); callers of this API expect (mask, bytes).
  | otherwise  = swap (xor32StrictByteString'' msk bs)

-- | Variant of 'xor32StrictByteString' for masking lazy 'BL.ByteString's.
--
-- The mask rotation produced by one chunk is carried over into the next
-- chunk, so the result does not depend on how the input is chunked.
--
-- >>> xor32LazyByteString 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- "Hello"
--
xor32LazyByteString :: Word32 -> BL.ByteString -> BL.ByteString
xor32LazyByteString 0    = id
xor32LazyByteString msk0 = go msk0
  where
    -- Walk the chunk list, threading the rotated mask from chunk to chunk.
    go :: Word32 -> BL.ByteString -> BL.ByteString
    go _   BL.Empty        = BL.Empty
    go msk (BL.Chunk x xs) = BL.Chunk x' (go msk' xs)
      where
        -- Kept as a lazy pattern so each chunk is only masked on demand.
        (x', msk') = xor32StrictByteString'' msk x

{-# INLINE xor32StrictByteString'' #-}
-- internal
--
-- Worker shared by the strict and lazy entry points: copies the input into a
-- freshly allocated buffer, applies the mask to the copy in place, and
-- returns the new 'BS.ByteString' together with the mask as returned by the
-- byte-wise tail pass (i.e. rotated by the number of trailing octets).
xor32StrictByteString'' :: Word32 -> BS.ByteString -> (BS.ByteString,Word32)
xor32StrictByteString'' msk0 (BS.PS x s l)
    = unsafeCreate' l $ \p8 ->
        withForeignPtr x $ \f -> do
          -- Work on a private copy; the source buffer is never written to.
          copyBytes p8 (f `plusPtr` s) l

          case remPtr p8 4 of
            0 -> do
              -- Destination is word-aligned: mask whole words first, then
              -- the (at most three) trailing octets one by one.
              let trailer = l `rem` 4
                  lbytes  = l - trailer
              xor32PtrAligned msk0 (castPtr p8) lbytes
              xor32PtrNonAligned msk0 (p8 `plusPtr` lbytes) trailer
            _ ->
              -- misaligned bytestring...
              --
              -- This should not happen, as newly allocated
              -- bytestrings ought to be word-aligned; but if the
              -- impossible does happen we have a semantically sound
              -- codepath to jump to...
              xor32Ptr msk0 p8 l



-- | Apply 32-bit XOR mask (considered as four octets in big-endian order) to 'SBS.ShortByteString'. See also 'xor32StrictByteString'.
--
-- >>> xor32ShortByteString 0x37fa213d "\x7f\x9f\x4d\x51\x58"
-- "Hello"
--
xor32ShortByteString :: Word32 -> SBS.ShortByteString -> SBS.ShortByteString
xor32ShortByteString 0 sbs = sbs
xor32ShortByteString mask0be sbs
  | SBS.null sbs = sbs
  | otherwise    = runST $ do
      tmp <- newSBS len

      -- Bulk phase: XOR the first 'len4' whole 32-bit words using the mask
      -- in host byte order.
      let loop4 i
            | i == len4 = return ()
            | otherwise = do
                writeWord32Array tmp i (indexWord32Array sbs i `xor` mask0)
                loop4 (i + 1)

      loop4 0

      -- Tail phase helper: XOR the octet at offset 'ofs' with the mask
      -- octet obtained by shifting the big-endian mask right by 'sh' bits.
      let writeXor8 ofs sh =
            writeWord8Array tmp ofs
              (indexWord8Array sbs ofs `xor` fromIntegral (mask0be `shiftR` sh))

      -- The 0..3 trailing octets consume the mask octets from the most
      -- significant one downwards.
      case len1 of
        0 -> return ()
        1 -> writeXor8 (len - 1) 24
        2 -> do
          writeXor8 (len - 2) 24
          writeXor8 (len - 1) 16
        3 -> do
          writeXor8 (len - 3) 24
          writeXor8 (len - 2) 16
          writeXor8 (len - 1) 8
        _ -> error "xor32ShortByteString: impossible: remainder of division by 4 outside 0..3"

      unsafeFreezeSBS tmp
  where
    len :: Int
    len = SBS.length sbs

    -- Number of whole 32-bit words, and number of trailing octets.
    len4, len1 :: Int
    (len4, len1) = quotRem len 4

    -- The mask as it has to look when read/written as a native 'Word32',
    -- so that its big-endian octet order lines up with memory order.
    mask0 :: Word32
    mask0 = case targetByteOrder of
              LittleEndian -> byteSwap32 mask0be
              BigEndian    -> mask0be


{-# INLINEABLE xor32CStringLen #-}
-- | Apply 32-bit XOR mask (considered as four octets in big-endian order) to memory region expressed as base-pointer and size. The returned value is the input mask rotated by the word-size remainder of the memory region size (useful for chained xor-masking of multiple memory-fragments).
--
-- The memory region is modified in place.
xor32CStringLen :: Word32 -> CStringLen -> IO Word32
xor32CStringLen m (p, l) = xor32Ptr m (castPtr p) l

{-# INLINEABLE xor32Ptr #-}
-- Apply a 32-bit XOR mask (four octets, big-endian order) in place to the
-- @n@ octets starting at @p0@, and return the mask rotated left by one octet
-- for every octet of the region that is not covered by whole mask words.
--
-- The region is split into an unaligned head (@n0@ octets), a 4-byte-aligned
-- middle part (@n1@ octets, handled one 'Word32' at a time) and an unaligned
-- tail (@n2@ octets).
--
-- A zero mask or a zero size is a no-op.  A negative size is rejected via
-- 'fail' (unless the mask is zero, in which case nothing is touched).
xor32Ptr :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32Ptr 0      !_  !_ = return 0
xor32Ptr !mask0 !_   0 = return mask0
xor32Ptr !mask0 !p0 !n
  -- The negative-size check has to come before the @n < 4@ shortcut:
  -- otherwise negative sizes are dispatched to 'xor32PtrNonAligned', whose
  -- byte loop terminates only when it reaches @p0 + n@ and would therefore
  -- run off through memory.
  | n < 0 = fail "xor32Ptr: negative size argument not supported"
  | n < 4 = xor32PtrNonAligned mask0 p0 n
  -- 'assert' yields its second argument ('False') when the condition holds,
  -- so this guard never fires; it only exists to check the invariants.
  | assert (p0 <= p1 && p1 <= p2 && p2 <= p3 && n0 < 4 && n2 < 4) False = undefined -- assert invariants
  -- No whole aligned word inside the region: do everything byte-wise.
  | n1 == 0 = xor32PtrNonAligned mask0 p0 n
  -- Region starts aligned: no head, mask needs no pre-rotation.
  | n0 == 0 = do
      xor32PtrAligned    mask0 p1 n1
      xor32PtrNonAligned mask0 p2 n2
  -- General case: the head rotates the mask, which is then used for the
  -- aligned middle part and the tail.
  | otherwise = do
      mask1 <- xor32PtrNonAligned mask0 p0 n0
      xor32PtrAligned    mask1 p1 n1
      xor32PtrNonAligned mask1 p2 n2
  where
    -- Invariants: p0 <= p1 <= p2 <= p3
    --             0  <= n0  < 4
    --             0  <= n1
    --             0  <= n2  < 4
    --             n  == n0+n1+n2 >= 4

    -- First 4-byte-aligned address at or after the start of the region.
    p1 :: Ptr b
    p1 = castPtr (alignPtr p0 d)
    -- Last 4-byte-aligned address at or before the end of the region.
    p2 :: Ptr a
    p2 = alignPtrDown p3 d
    -- One past the last octet of the region.
    p3 :: Ptr b
    p3 = plusPtr p0 n
    -- Alignment (and word size) in octets.
    d :: Int
    d  = 4

    n0, n1, n2 :: Int
    n0 = p1 `minusPtr` p0
    n1 = p2 `minusPtr` p1
    n2 = p3 `minusPtr` p2

-- internal
--
-- Byte-wise masking of @n@ octets starting at the given pointer, in place.
-- The most significant octet of the mask is applied to the first octet in
-- memory, and so on.  Returns the mask rotated left by @8*n@ bits.
--
-- Sizes 0..3 are unrolled; larger sizes use a generic loop.
--
-- Precondition: @n >= 0@.  The generic loop stops only when it reaches
-- @p0 + n@, so a negative size would not terminate sensibly.
xor32PtrNonAligned :: Word32 -> Ptr Word8 -> Int -> IO Word32
xor32PtrNonAligned mask0 _ 0 = return mask0
xor32PtrNonAligned mask0 p 1 = do
  -- After rotating by one octet the former top octet sits in the low 8 bits.
  let mask1 = rotateL mask0 8
  xor8Ptr1 (fromIntegral mask1) p
  return mask1
xor32PtrNonAligned mask0 p 2 = do
  xor8Ptr1 (fromIntegral (mask0 `shiftR` 24)) p
  let mask1 = mask0 `rotateL` 16
  xor8Ptr1 (fromIntegral mask1) (p `plusPtr` 1)
  return mask1
xor32PtrNonAligned mask0 p 3 = do
  xor8Ptr1 (fromIntegral (mask0 `shiftR` 24)) p
  xor8Ptr1 (fromIntegral (mask0 `shiftR` 16)) (p `plusPtr` 1)
  let mask1 = mask0 `rotateL` 24
  xor8Ptr1 (fromIntegral mask1) (p `plusPtr` 2)
  return mask1
xor32PtrNonAligned mask0 p0 n = go mask0 p0
  where
    -- One past the last octet to process.
    p' :: Ptr Word8
    p' = p0 `plusPtr` n

    -- Rotate first, then use the low octet of the rotated mask.
    go :: Word32 -> Ptr Word8 -> IO Word32
    go m p
      | p == p'   = return m
      | otherwise = do
          let m' = rotateL m 8
          xor8Ptr1 (fromIntegral m') p
          go m' (p `plusPtr` 1)

-- internal
--
-- Word-wise masking of @n@ octets starting at @p0@, in place.
--
-- Preconditions (checked by 'assert' only): @p0@ is 4-byte aligned and @n@
-- is a multiple of 4.  The loop advances four octets at a time and stops
-- when it reaches exactly @p0 + n@.
xor32PtrAligned :: Word32 -> Ptr Word32 -> Int -> IO ()
xor32PtrAligned _ _ 0 = return ()
xor32PtrAligned mask0be p0 n
  = assert (p0 `remPtr` 4 == 0 && n `rem` 4 == 0) $ go p0
  where
    -- One past the last word to process.
    p' :: Ptr Word32
    p' = p0 `plusPtr` n

    go :: Ptr Word32 -> IO ()
    go p
      | p == p'   = return ()
      | otherwise = do
          xor32Ptr1 mask0 p
          go (p `plusPtr` 4)

    -- The mask in host byte order, so that a native 'Word32' store lays the
    -- mask octets out in big-endian order in memory.
    mask0 :: Word32
    mask0 = case targetByteOrder of
              LittleEndian -> byteSwap32 mask0be
              BigEndian    -> mask0be

----------------------------------------------------------------------------

-- Remainder of the address held by a pointer when divided by the given
-- 'Int' (thin wrapper around the 'X.remAddr#' primop).
remPtr :: Ptr a -> Int -> Int
remPtr (X.Ptr addr#) (X.I# d#) = X.I# (X.remAddr# addr# d#)

-- Round a pointer down to the nearest address that is a multiple of the
-- given alignment; a pointer that is already aligned is returned unchanged.
alignPtrDown :: Ptr a -> Int -> Ptr a
alignPtrDown p i
  | excess == 0 = p
  | otherwise   = p `plusPtr` negate excess
  where
    -- How far the address lies above the previous multiple of 'i'.
    excess :: Int
    excess = remPtr p i

-- XOR the single octet stored at the pointer with the given mask, in place.
xor8Ptr1 :: Word8 -> Ptr Word8 -> IO ()
xor8Ptr1 msk ptr = do
  x <- peek ptr
  poke ptr (xor msk x)

-- xor16Ptr1 :: Word16 -> Ptr Word16 -> IO ()
-- xor16Ptr1 msk ptr = do { x <- peek ptr; poke ptr (xor msk x) }

-- XOR the single 'Word32' stored at the pointer with the given mask, in
-- place.  Both the load and the store use the 'Storable' instance of
-- 'Word32', i.e. host byte order.
xor32Ptr1 :: Word32 -> Ptr Word32 -> IO ()
xor32Ptr1 msk ptr = do
  x <- peek ptr
  poke ptr (xor msk x)

{-# INLINE unsafeCreate' #-}
-- Allocate a 'BS.ByteString' buffer of the given size, run the supplied
-- action on a pointer to its first octet, and return the resulting
-- 'BS.ByteString' (offset 0, full length) together with the action's result.
--
-- NOTE(review): this goes through 'unsafeDupablePerformIO', so the fill
-- action is presumably required to depend only on its inputs and to be safe
-- to run more than once -- confirm for any new caller.
unsafeCreate' :: Int -> (Ptr Word8 -> IO a) -> (BS.ByteString, a)
unsafeCreate' l0 f0 = unsafeDupablePerformIO (create' l0 f0)
  where
    {-# INLINE create' #-}
    create' :: Int -> (Ptr Word8 -> IO a) -> IO (BS.ByteString, a)
    create' l f = do
        fp  <- mallocByteString l
        -- 'withForeignPtr' keeps the buffer alive while the action runs.
        res <- withForeignPtr fp f
        return (BS.PS fp 0 l, res)

----------------------------------------------------------------------------
-- single octet masks -- trivially mapped to 32-bit versions

-- Replicate an octet into all four octets of a 'Word32',
-- e.g. @0x20@ becomes @0x20202020@.
expandW8ToW32 :: Word8 -> Word32
expandW8ToW32 x = lo16 .|. (lo16 `shiftL` 16)
  where
    -- The octet widened to 32 bits.
    w :: Word32
    w = fromIntegral x

    -- The octet replicated into the two low octets.
    lo16 :: Word32
    lo16 = w .|. (w `shiftL` 8)


-- | Apply 8-bit XOR mask to each octet of a 'BS.ByteString'.
--
-- >>> xor8StrictByteString 0x20 "Hello"
-- "hELLO"
--
-- This function is a faster implementation of the semantically equivalent function shown below:
--
-- > xor8StrictByteString'ref :: Word8 -> BS.ByteString -> BS.ByteString
-- > xor8StrictByteString'ref 0    = id
-- > xor8StrictByteString'ref msk0 = BS.map (xor msk0)
--
xor8StrictByteString :: Word8 -> BS.ByteString -> BS.ByteString
-- Implemented by replicating the octet into a 32-bit mask.
xor8StrictByteString x = xor32StrictByteString (expandW8ToW32 x)

-- | Apply 8-bit XOR mask to each octet of a lazy 'BL.ByteString'.
--
-- See also 'xor8StrictByteString'
xor8LazyByteString :: Word8 -> BL.ByteString -> BL.ByteString
-- Implemented by replicating the octet into a 32-bit mask.
xor8LazyByteString x = xor32LazyByteString (expandW8ToW32 x)

-- | Apply 8-bit XOR mask to each octet of a 'SBS.ShortByteString'.
--
-- See also 'xor8StrictByteString'
xor8ShortByteString :: Word8 -> SBS.ShortByteString -> SBS.ShortByteString
-- Implemented by replicating the octet into a 32-bit mask.
xor8ShortByteString x = xor32ShortByteString (expandW8ToW32 x)

-- | Apply 8-bit XOR mask to each octet of a memory region expressed as start address and length in bytes.
--
-- The memory region is modified in place.
--
-- See also 'xor8StrictByteString'
xor8CStringLen :: Word8 -> CStringLen -> IO ()
-- The rotated mask returned by the 32-bit worker is of no use here and is
-- discarded.
xor8CStringLen x (p, l) = void (xor32Ptr (expandW8ToW32 x) (castPtr p) l)

----------------------------------------------------------------------------
-- The missing mutable ShortByteString abstraction

-- Mutable counterpart of 'ShortByteString': a thin box around a
-- 'X.MutableByteArray#' tagged with the state thread @s@.  Created by
-- 'newSBS' and turned into an immutable 'ShortByteString' by
-- 'unsafeFreezeSBS'.
data MShortByteString s = MSBS (X.MutableByteArray# s)

-- Allocate a new mutable byte array of the given size in octets.  The
-- contents are whatever 'X.newByteArray#' leaves there; callers are expected
-- to overwrite every octet before freezing.
newSBS :: Int -> ST s (MShortByteString s)
newSBS (X.I# len#) = X.ST $ \s0 ->
  case X.newByteArray# len# s0 of
    (# s1, mba# #) -> (# s1, MSBS mba# #)

-- Read the octet at the given octet offset.  No bounds check is performed.
indexWord8Array :: ShortByteString -> Int -> Word8
indexWord8Array (SBS ba#) (X.I# i#) = X.W8# (X.indexWord8Array# ba# i#)

-- Store an octet at the given octet offset.  No bounds check is performed.
writeWord8Array :: MShortByteString s -> Int -> Word8 -> ST s ()
writeWord8Array (MSBS mba#) (X.I# i#) (X.W8# w#) = X.ST $ \s0 ->
  case X.writeWord8Array# mba# i# w# s0 of
    s1 -> (# s1, () #)

-- Read the 32-bit word at the given index, where the index counts 32-bit
-- words (not octets) as per 'X.indexWord32Array#'.  The word is read in host
-- byte order.  No bounds check is performed.
indexWord32Array :: ShortByteString -> Int -> Word32
indexWord32Array (SBS ba#) (X.I# i#) = X.W32# (X.indexWord32Array# ba# i#)

-- Store a 32-bit word at the given index, where the index counts 32-bit
-- words (not octets) as per 'X.writeWord32Array#'.  No bounds check is
-- performed.
writeWord32Array :: MShortByteString s -> Int -> Word32 -> ST s ()
writeWord32Array (MSBS mba#) (X.I# i#) (X.W32# w#) = X.ST $ \s0 ->
  case X.writeWord32Array# mba# i# w# s0 of
    s1 -> (# s1, () #)

-- Turn the mutable buffer into an immutable 'ShortByteString' without
-- copying (via 'X.unsafeFreezeByteArray#').  The mutable buffer must not be
-- written to afterwards.
unsafeFreezeSBS :: MShortByteString s -> ST s ShortByteString
unsafeFreezeSBS (MSBS mba#) = X.ST $ \s0 ->
  case X.unsafeFreezeByteArray# mba# s0 of
    (# s1, ba# #) -> (# s1, SBS ba# #)