{-# LANGUAGE OverloadedStrings, ViewPatterns #-}

-------------------------------------------------------------------------------
-- |
-- Module:      Crypto.Encoding.PHKDF
-- Copyright:   (c) 2024 Auth Global
-- License:     Apache2
--
-------------------------------------------------------------------------------

module Crypto.Encoding.PHKDF where

import Data.Bits(Bits, (.&.), shift)
import Data.ByteString(ByteString)
import Data.Int(Int64)
import Data.List(scanl')
import qualified Data.ByteString as B

import Debug.Trace

-- | Build the extended form of a PHKDF end-of-message tag as a list of
--   chunks; concatenating the chunks gives the extended tag.
--
--   Tags of at most 19 bytes are returned unchanged as a single chunk.
--   Longer tags are repeated cyclically, with a single null byte between
--   repetitions, until the length is congruent to 19 (mod 64).  One final
--   byte is then appended whose upper six bits hold the number of padding
--   bytes that were added.
extendTagToList :: ByteString -> [ByteString]
extendTagToList tag
  | tagLen <= 19 = [tag]
  | otherwise    = padded ++ [lengthByte]
  where
    tagLen = B.length tag
    -- Number of padding bytes (0-63) needed to reach 19 (mod 64).
    padLen = (19 - tagLen) `mod` 64
    -- The tag followed by the first padLen bytes of its cyclic extension.
    padded = takeBs (fromIntegral (tagLen + padLen)) (cycle [tag, "\x00"])
    -- Shifting left by two leaves the low two bits of the final byte zero.
    lengthByte = B.singleton (fromIntegral padLen `shift` 2)

-- | Extend a PHKDF end-of-message tag so that the last SHA-256 block
--   contains something interesting.
--
--   Tags less than 160 bits (20 bytes) long are appended directly, without
--   extension, as the final portion of the message. Thus this function
--   is the identity on short inputs.
--
--   After extension, tags that are at least 20 bytes long should be thought
--   of as a bitstring with a single null bit appended at the end to make it
--   a full bytestring.
--
--   Tags 160 bits or longer are first extended, iff necessary, to a full
--   bytestring by adding a single "1" bit followed by zero to six "0" bits.
--
--   The bytestring is then extended by 0-63 bytes as needed to make the
--   overall length equivalent to 19 (mod 64). The first byte of the
--   extension is a null byte, followed by the bytestring, then starting
--   again at the null byte as needed.
--
--   The length of this extension takes up the first 6 bits of the last byte,
--   followed by a "0" bit denoting that the tag is a bytestring, or a "1"
--   denoting that the tag is a proper bitstring whose length is not an exact
--   multiple of 8.  This function only accepts whole bytestrings, so it
--   always writes that bit as "0".
--
--   The final bit is reserved for SHA-256's end-of-message padding, which
--   will set it to 1.
extendTag :: ByteString -> ByteString
extendTag tag = B.concat (extendTagToList tag)

-- | Robustly undo 'extendTag', thus "proving" that all collisions on
--   PHKDF's tag are cryptographically non-trivial, even after extension.
--
--   Inputs of at most 19 bytes are returned unchanged.  For longer inputs
--   the padding length is read from the upper six bits of the last byte,
--   that many bytes plus the length byte are stripped, and the result is
--   accepted only if re-extending it reproduces the input exactly;
--   otherwise the answer is 'Nothing'.
--
--   This is a "proof" in the sense that if
--   @trimExtendedTag (extendTag x) == Just x@ is true for all bytestrings
--   @x@, then all collisions are non-trivial, but we haven't presented a
--   full deductive proof of this property.  It is part of the test suite,
--   tested by quickcheck fuzzing.
--
--   The rest of PHKDF and the G3P's syntax follows this as an iron rule
--   of syntax design. I've not literally written a program to parse out
--   the original arguments, but I've ensured that it is straightforward
--   to do so in principle.
--
--   In the case of variable-length PHKDF, starting from some known buffer
--   position (usually either 0 or 32), first there are zero or more
--   bitstring arguments encoded via TupleHash syntax. Since TupleHash's
--   length encoding cannot start with a null byte, a single null byte
--   is used to signal the end of these input arguments. Then 0-63 end
--   padding bytes are generated in order to bring the buffer position
--   equivalent to 32 (mod 64), then 4 bytes of counter, then the extended
--   version of PHKDF's end-of-message tag, then finally SHA256's end
--   padding.
--
--   This is easy to robustly undo, as I've started to demonstrate in this
--   subroutine. This leads to a simple categorical/combinatorial style proof
--   that all collisions over PHKDF's input arguments and domain tag are
--   cryptographically non-trivial.
trimExtendedTag :: ByteString -> Maybe ByteString
trimExtendedTag extTag
  | extLen <= 19                  = Just extTag
  | extendTag candidate == extTag = Just candidate
  | otherwise                     = Nothing
  where
    extLen = B.length extTag
    -- Only demanded when extLen > 19, so 'B.last' never sees an empty input.
    padLen = B.last extTag `shift` (-2)
    -- Strip the padding bytes and the trailing length byte.
    candidate = B.take (extLen - fromIntegral padLen - 1) extTag

{--

FIXME: as written, this only works on signed arithmetic, unless the modulus @a@
is a power of 2, such as 64

-- | @addWhileLt a b c@ is equivalent to  @while (b < c) { b += a }; return b@
addWhileLt :: Integral a => a -> a -> a -> a
addWhileLt a b c
   | b >= c = b
   | otherwise = c + ((b - c) `mod` a)

--}

-- | @add64WhileLt b c@ is equivalent to
--   @while (b < c) { b += 64 }; return b@.
--
--   No loop is actually run: when @b@ is below @c@, the result is @c@ plus
--   the low six bits of @b - c@.
add64WhileLt :: (Ord a, Num a, Bits a) => a -> a -> a
add64WhileLt b c =
  if b >= c
    then b
    else c + ((b - c) .&. 63)

-- | Equivalent to 'add64WhileLt', except with trace debugging.  This should
--   never be used in production.
--
--   Whenever @b@ has to be advanced, a line of the form
--   @b + k * 64 == d == c + r@ is emitted via 'trace', where @d@ is the
--   result, @k@ is @(d - b)@ shifted right by six bits and @r@ is @d - c@.
add64WhileLt' :: (Ord a, Num a, Bits a, Show a) => a -> a -> a
add64WhileLt' b c =
  if b >= c then b else trace report result
  where
    result = c + ((b - c) .&. 63)
    steps  = (result - b) `shift` (-6)
    report = concat
      [ show b, " + ", show steps, " * 64 == "
      , show result, " == ", show c, " + ", show (result - c)
      ]

-- | Drop the first @n@ bytes from a list of chunks, treating the list as
--   one concatenated byte stream.
--
--   Chunks that are consumed entirely are removed; the chunk in which the
--   cut falls is shortened.  A count of exactly zero returns the remaining
--   list untouched, and running out of chunks yields the empty list.
dropBs :: Int64 -> [ByteString] -> [ByteString]
dropBs _ [] = []
dropBs 0 chunks = chunks
dropBs n (chunk : rest)
  | n < size  = B.drop (fromIntegral n) chunk : rest
  | otherwise = dropBs (n - size) rest
  where
    size = fromIntegral (B.length chunk) :: Int64

-- | Take the first @n@ bytes from a list of chunks, treating the list as
--   one concatenated byte stream.
--
--   Chunks lying wholly before the limit are passed through unchanged and
--   the chunk that reaches the limit is truncated to fit; nothing after it
--   is inspected, so the input list may be infinite.  If the list holds
--   fewer than @n@ bytes, all of it is returned.  A non-positive @n@ yields
--   the empty list.
takeBs :: Int64 -> [ByteString] -> [ByteString]
takeBs _ [] = []
takeBs n (chunk : rest)
  | n <= 0    = []
  | size >= n = [B.take (fromIntegral n) chunk]
  | otherwise = chunk : takeBs (n - size) rest
  where
    size = fromIntegral (B.length chunk) :: Int64

-- | All-or-nothing variant of 'takeBs': returns the first @n@ bytes of the
--   chunk list when at least @n@ bytes are available, and the empty list
--   otherwise.
--
--   Availability is decided from the running totals of the chunk lengths,
--   and the scan stops at the first total that reaches @n@.
takeBs' :: Int64 -> [ByteString] -> [ByteString]
takeBs' n bs
  | any (>= n) runningTotals = takeBs n bs
  | otherwise                = []
  where
    runningTotals = scanl' (+) 0 [fromIntegral (B.length b) | b <- bs]

-- | Take exactly the first @n@ bytes of a bytestring.
--
--   Returns 'Nothing' when the input holds fewer than @n@ bytes, and
--   @'Just' prefix@ otherwise.  A negative @n@ is treated like zero and
--   yields @'Just'@ of the empty bytestring.
takeB' :: Int64 -> ByteString -> Maybe ByteString
takeB' n bs
  -- this fromIntegral is inherently safe
  | fromIntegral (B.length bs) < n = Nothing
  -- Here n <= B.length bs, so every non-negative n fits in an Int.  Negative
  -- requests are clamped to zero before narrowing, so they cannot wrap
  -- around to a positive count where Int is narrower than Int64.
  | otherwise = Just (B.take (fromIntegral (max 0 n)) bs)

-- | Like 'takeB'', but calls 'error' with the message
--   @"not enough bytes"@ instead of returning 'Nothing' when the input is
--   shorter than the requested length.
assertTakeB' :: Int64 -> ByteString -> ByteString
assertTakeB' n bs =
  case takeB' n bs of
    Just prefix -> prefix
    Nothing     -> error "not enough bytes"

-- | A bytestring consisting of 64 null bytes.
nullBuffer :: ByteString
nullBuffer = B.replicate 64 0x00

-- | Partition a bytestring into chunks of up to a given size.
--
--   Every chunk except possibly the last holds exactly @n@ bytes, the last
--   one holds the remainder, and concatenating the chunks gives back the
--   input.  An empty input yields the empty list.
--
--   The chunk size must be positive.  With @n <= 0@ no progress could ever
--   be made on a non-empty input (the result would be an endless list of
--   empty chunks), so that case is rejected with 'error'.
chunkify :: Int -> ByteString -> [ByteString]
chunkify n bs
  | B.null bs = []
  | n <= 0    = error "chunkify: chunk size must be positive"
  | otherwise = go bs
  where
    go rest
      | B.null rest = []
      | otherwise   = chunk : go rest'
      where
        (chunk, rest') = B.splitAt n rest

-- | Partition a cyclically extended bytestring into chunks of
--   a given size, starting at a given offset.
--
--   Note that repetitions of the original string get a single
--   null byte placed between them, so the cyclic extension has a period
--   of one more than the length of the string.  Offsets are reduced modulo
--   that period, and each chunk starts where the previous one ended.
chunkifyCycle
  :: Int64 -- ^ Desired chunk size
  -> ByteString -- ^ String to be cyclically extended.
  -> Int64 -- ^ Starting offset
  -> [ ByteString ] -- ^ Infinite stream of chunks
chunkifyCycle len bs start = map chunkAt offsets
  where
    -- Length of one repetition: the string plus its separating null byte.
    period = fromIntegral (B.length bs) + 1 :: Int64
    wrap pos = pos `mod` period
    -- One copy of the string followed by the next len bytes of the cycle;
    -- long enough to cut a full chunk at any offset below the period.
    window = B.concat (bs : takeBs len (cycle ["\x00", bs]))
    chunkAt pos = assertTakeB' len (B.drop (fromIntegral pos) window)
    -- Successive chunk offsets, each already reduced modulo the period.
    offsets = iterate (\pos -> wrap (pos + len)) (wrap start)