-- |
-- Module      : Crypto.Data.AFIS
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Haskell implementation of the Anti-forensic information splitter
-- available in LUKS. <http://clemens.endorphin.org/AFsplitter>
--
-- The algorithm expands an arbitrary secret into many times its size. All of
-- the expanded data is needed to recover the key (merge), which makes it much
-- easier to permanently destroy a key stored on disk.
--
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Data.AFIS
    ( split
    , merge
    ) where

import           Crypto.Hash
import           Crypto.Random.Types
import           Crypto.Internal.Compat
import           Control.Monad (forM_, foldM)
import           Data.Word
import           Data.Bits
import           Foreign.Storable
import           Foreign.Ptr

import           Crypto.Internal.ByteArray (ByteArray, Bytes, MemView(..))
import qualified Crypto.Internal.ByteArray as B

import           Data.Memory.PtrMethods (memSet, memCopy)

-- | Split data to diffused data, using a random generator and
-- a hash algorithm.
--
-- The diffused data consists of @expandTimes@ blocks, each as long as the
-- input. The first @expandTimes - 1@ blocks are random bytes drawn from the
-- generator. The last block is the original data xor-ed with an accumulator
-- built from those random blocks:
--
-- > +--------+
-- > |  orig  |
-- > +--------+
-- >
-- > +--------+ +--------+     +------------+
-- > | rand1  | | rand2  | ... | orig ^ acc |
-- > +--------+ +--------+     +------------+
--
-- The accumulator starts as all zeros and is updated once per random block:
--
-- > acc(0)   = 0
-- > acc(k+1) = diffuse (acc(k) ^ rand(k))
--
-- @diffuse@ cuts its input into digest-sized chunks (the last one may be
-- shorter) and replaces chunk @i@ with the leading bytes of
-- @hash (BE32(i) ++ chunk_i)@.
--
-- Calls 'error' when @expandTimes <= 1@.
split :: (ByteArray ba, HashAlgorithm hash, DRG rng)
      => hash      -- ^ Hash algorithm to use as diffuser
      -> rng       -- ^ Random generator to use
      -> Int       -- ^ Number of times to diffuse the data (must be > 1).
      -> ba        -- ^ original data to diffuse.
      -> (ba, rng) -- ^ The diffused data (@expandTimes@ times the input
                   --   length) and the updated generator
{-# NOINLINE split #-}
split hashAlg rng expandTimes src
    | expandTimes <= 1 = error "invalid expandTimes value"
    | otherwise        = unsafeDoIO $ do
        (rng', bs) <- B.allocRet diffusedLen runOp
        return (bs, rng')
  where
    diffusedLen = blockSize * expandTimes
    blockSize   = B.length src

    -- Fill the output buffer: random blocks first, then the masked original
    -- in the last block.
    runOp dstPtr = do
        let lastBlock = dstPtr `plusPtr` (blockSize * (expandTimes - 1))
        memSet lastBlock 0 blockSize
        let randomBlockPtrs = map (plusPtr dstPtr . (*) blockSize) [0 .. expandTimes - 2]
        rng' <- foldM fillRandomBlock rng randomBlockPtrs
        mapM_ (addRandomBlock lastBlock) randomBlockPtrs
        B.withByteArray src $ \srcPtr -> xorMem srcPtr lastBlock blockSize
        return rng'

    -- acc := diffuse (acc ^ block), with acc held in the last block.
    addRandomBlock lastBlock blockPtr = do
        xorMem blockPtr lastBlock blockSize
        diffuse hashAlg lastBlock blockSize

    -- Copy blockSize fresh random bytes into blockPtr, threading the generator.
    fillRandomBlock g blockPtr = do
        let (rand :: Bytes, g') = randomBytesGenerate blockSize g
        B.withByteArray rand $ \randPtr -> memCopy blockPtr randPtr blockSize
        return g'

-- | Merge previously diffused data back to the original data.
--
-- The diffused data is cut into @expandTimes@ equal blocks. The first
-- @expandTimes - 1@ blocks are folded into an accumulator that starts as
-- zeros (@acc := diffuse (acc ^ block)@). The result is then xor-ed with the
-- last block, which undoes 'split'.
--
-- Calls 'error' when @expandTimes <= 0@, when the input length is not a
-- multiple of @expandTimes@, or when the input is empty.
merge :: (ByteArray ba, HashAlgorithm hash)
      => hash  -- ^ Hash algorithm used as diffuser
      -> Int   -- ^ Number of times to un-diffuse the data
      -> ba    -- ^ Diffused data
      -> ba    -- ^ Original data
{-# NOINLINE merge #-}
merge hashAlg expandTimes bs
    -- Checked before anything forces the 'where' binding, which divides by
    -- expandTimes: a zero would otherwise surface as "divide by zero", and a
    -- negative value would be misreported as "diffused data null".
    | expandTimes <= 0  = error "invalid expandTimes value"
    | r /= 0            = error "diffused data not a multiple of expandTimes"
    | originalSize <= 0 = error "diffused data null"
    | otherwise         = B.allocAndFreeze originalSize $ \dstPtr ->
        B.withByteArray bs $ \srcPtr -> do
            memSet dstPtr 0 originalSize
            forM_ [0 .. expandTimes - 2] $ \i -> do
                xorMem (srcPtr `plusPtr` (i * originalSize)) dstPtr originalSize
                diffuse hashAlg dstPtr originalSize
            xorMem (srcPtr `plusPtr` ((expandTimes - 1) * originalSize)) dstPtr originalSize
  where
    (originalSize, r) = len `quotRem` expandTimes
    len               = B.length bs

-- | In-place xor with an input: @dst = src `xor` dst@ over @sz@ bytes.
--
-- Processes 8 or 4 bytes per step when the size and both pointers allow
-- it, and falls back to single bytes otherwise. The result is the same
-- whichever width is used. A non-positive @sz@ is a no-op.
xorMem :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
xorMem src dst sz
    | wideOk 8  = loop 8 (castPtr src :: Ptr Word64) (castPtr dst) sz
    | wideOk 4  = loop 4 (castPtr src :: Ptr Word32) (castPtr dst) sz
    | otherwise = loop 1 src dst sz
  where
    -- Width w is usable when it divides sz, so the loop ends exactly at sz,
    -- and when both pointers are w-aligned. Storable's peek/poke may require
    -- aligned addresses on some architectures, and the pointers handed in
    -- here can come from arbitrary byte arrays.
    wideOk :: Int -> Bool
    wideOk w = sz `rem` w == 0 && isAligned w src && isAligned w dst

    isAligned :: Int -> Ptr Word8 -> Bool
    isAligned w p = alignPtr p w == p

    loop :: (Storable b, Bits b) => Int -> Ptr b -> Ptr b -> Int -> IO ()
    loop incr s d n
        | n <= 0    = return ()
        | otherwise = do
            x <- peek s
            y <- peek d
            poke d (x `xor` y)
            loop incr (s `plusPtr` incr) (d `plusPtr` incr) (n - incr)

-- | Diffuse a buffer in place with a hash function.
--
-- The buffer is treated as consecutive digest-sized chunks. The final chunk
-- is shorter when @sz@ is not a multiple of the digest size. Chunk number
-- @i@ is overwritten with the leading bytes of @hash (BE32(i) ++ chunk)@,
-- where @BE32(i)@ is @i@ encoded as 4 big-endian bytes.
diffuse :: HashAlgorithm hash
        => hash      -- ^ Hash function to use as diffuser
        -> Ptr Word8 -- ^ buffer to diffuse, modify in place
        -> Int       -- ^ length of buffer to diffuse
        -> IO ()
diffuse hashAlg src sz = do
    forM_ [0 .. fullChunks - 1] $ \idx -> rehashChunk idx digestSize
    -- The trailing partial chunk is numbered right after the full chunks
    -- ('max 0' only makes a difference for a negative sz).
    if tailSize /= 0
        then rehashChunk (max 0 fullChunks) tailSize
        else return ()
  where
    digestSize             = hashDigestSize hashAlg
    (fullChunks, tailSize) = sz `quotRem` digestSize

    -- Overwrite the chunkLen bytes of chunk idx with the first chunkLen bytes
    -- of its digest. The digest is forced before the copy, so it is computed
    -- from the chunk's old contents.
    rehashChunk idx chunkLen = do
        let chunkPtr = src `plusPtr` (idx * digestSize)
        digest <- hashChunk idx chunkPtr chunkLen
        B.withByteArray digest $ \digestPtr -> memCopy chunkPtr digestPtr chunkLen

    -- Hash [ BE32(n), (p .. p+len) ], evaluated strictly ('$!') because the
    -- input is read through a raw pointer.
    hashChunk n p len = do
        let withIndex = hashUpdate (hashInitWith hashAlg) (be32 n)
            withData  = hashUpdate withIndex (MemView p len)
        return $! hashFinalize withData

    -- 4-byte big-endian encoding of n (only the low 32 bits are kept).
    be32 :: Int -> Bytes
    be32 n = B.allocAndFreeze 4 $ \ptr ->
        forM_ (zip [0 ..] [24, 16, 8, 0]) $ \(off, sh) ->
            poke (ptr `plusPtr` off) (fromIntegral (n `shiftR` sh) :: Word8)