{-|
Module      : Z.Crypto.KDF
Description : Key Derivation Functions
Copyright   : Dong Han, 2021
              AnJie Dong, 2021
License     : BSD
Maintainer  : winterland1989@gmail.com
Stability   : experimental
Portability : non-portable

KDF(Key Derivation Function) and PBKDF(Password Based Key Derivation Function).

-}
module Z.Crypto.KDF (
  -- * KDF
    KDFType(..)
  , HashType(..)
  , MACType(..)
  , kdf
  , kdf'
  -- * PBKDF
  , PBKDFType(..)
  , pbkdf
  , pbkdfTimed
  -- * Internal helpers
  , kdfTypeToCBytes
  , pbkdfTypeToParam
  ) where

import           Z.Botan.Exception
import           Z.Botan.FFI
import           Z.Crypto.Hash     (HashType (..), hashTypeToCBytes)
import           Z.Crypto.MAC      (MACType (..), macTypeToCBytes)
import           Z.Data.CBytes     (CBytes, withCBytes, withCBytesUnsafe)
import qualified Z.Data.CBytes     as CB
import qualified Z.Data.Vector     as V
import           Z.Foreign

-----------------------------
-- Key Derivation Function --
-----------------------------

-- | Key derivation functions are used to turn some amount of shared secret material into uniform random keys
-- suitable for use with symmetric algorithms. An example of an input which is useful for a KDF is a shared
-- secret created using Diffie-Hellman key agreement.
--
-- Each constructor is rendered to a Botan algorithm name by 'kdfTypeToCBytes'.
data KDFType
    = HKDF MACType
    -- ^ Defined in RFC 5869, HKDF uses HMAC to process inputs.
    -- HKDF is the combined Extract+Expand operation.
    -- Use the combined HKDF unless you need compatibility with some other system.
    | HKDF_Extract MACType
    -- ^ Only the HKDF-Extract step of RFC 5869.
    | HKDF_Expand MACType
    -- ^ Only the HKDF-Expand step of RFC 5869.
    | KDF2 HashType
    -- ^ KDF2 comes from IEEE 1363. It uses a hash function.
    | KDF1_18033 HashType
    -- ^ KDF1 from ISO 18033-2. Very similar to (but incompatible with) KDF2.
    | KDF1 HashType
    -- ^ KDF1 from IEEE 1363. It can only produce an output at most the length of the hash function used.
    | TLS_PRF
    -- ^ The TLS 1.0/1.1 PRF (Botan algorithm name @TLS-PRF@).
    | TLS_12_PRF MACType
    -- ^ The TLS 1.2 PRF, parameterized by a MAC (Botan algorithm name @TLS-12-PRF@).
    | SP800_108_Counter MACType
    -- ^ KDF from NIST SP 800-108 in counter mode.
    | SP800_108_Feedback MACType
    -- ^ KDF from NIST SP 800-108 in feedback mode.
    | SP800_108_Pipeline MACType
    -- ^ KDF from NIST SP 800-108 in double-pipeline mode.
    | SP800_56AHash HashType
    -- ^ NIST SP 800-56A KDF using a hash function.
    | SP800_56AMAC MACType
    -- ^ NIST SP 800-56A KDF using HMAC.
    | SP800_56C MACType
    -- ^ NIST SP 800-56C KDF using HMAC.

-- | Render a 'KDFType' as the algorithm-name string passed to Botan's KDF lookup.
--
-- Parameterized KDFs are rendered as @NAME(inner)@, where @inner@ is produced by
-- 'macTypeToCBytes' or 'hashTypeToCBytes'; 'TLS_PRF' has no parameter.
kdfTypeToCBytes :: KDFType -> CBytes
kdfTypeToCBytes kt = case kt of
    HKDF mt               -> macAlgo  "HKDF(" mt
    HKDF_Extract mt       -> macAlgo  "HKDF-Extract(" mt
    HKDF_Expand mt        -> macAlgo  "HKDF-Expand(" mt
    KDF2 ht               -> hashAlgo "KDF2(" ht
    KDF1_18033 ht         -> hashAlgo "KDF1-18033(" ht
    KDF1 ht               -> hashAlgo "KDF1(" ht
    TLS_PRF               -> "TLS-PRF"
    TLS_12_PRF mt         -> macAlgo  "TLS-12-PRF(" mt
    SP800_108_Counter mt  -> macAlgo  "SP800-108-Counter(" mt
    SP800_108_Feedback mt -> macAlgo  "SP800-108-Feedback(" mt
    SP800_108_Pipeline mt -> macAlgo  "SP800-108-Pipeline(" mt
    -- Both SP 800-56A variants share one Botan name; the inner
    -- algorithm (hash or MAC) selects the implementation.
    SP800_56AHash ht      -> hashAlgo "SP800-56A(" ht
    SP800_56AMAC mt       -> macAlgo  "SP800-56A(" mt
    SP800_56C mt          -> macAlgo  "SP800-56C(" mt
  where
    -- @prefix@ already contains the opening parenthesis.
    macAlgo prefix mt  = CB.concat [prefix, macTypeToCBytes mt, ")"]
    hashAlgo prefix ht = CB.concat [prefix, hashTypeToCBytes ht, ")"]

-- | Derive a key using the given KDF algorithm.
--
-- Throws a Botan exception (via 'throwBotanIfMinus_') if the underlying call fails.
kdf :: HasCallStack
    => KDFType    -- ^ KDF algorithm
    -> Int        -- ^ length of output key
    -> V.Bytes    -- ^ secret
    -> V.Bytes    -- ^ salt
    -> V.Bytes    -- ^ label
    -> IO V.Bytes
{-# INLINABLE kdf #-}
kdf algo siz secret salt label =
    withCBytesUnsafe (kdfTypeToCBytes algo) $ \ algoBA ->
        withPrimVectorUnsafe secret $ \ secretBA secretOff secretLen ->
            withPrimVectorUnsafe salt $ \ saltBA saltOff saltLen ->
                withPrimVectorUnsafe label $ \ labelBA labelOff labelLen ->
                    fst <$> allocPrimVectorUnsafe siz (\ buf -> do
                        -- some KDFs xor into the output buffer, so clear it first
                        clearMBA buf siz
                        throwBotanIfMinus_ $
                            hs_botan_kdf algoBA buf siz
                                secretBA secretOff secretLen
                                saltBA saltOff saltLen
                                labelBA labelOff labelLen)

-- | Derive a key using the given KDF algorithm, with default empty salt and label.
--
-- Equivalent to @'kdf' algo siz secret mempty mempty@.
kdf' :: HasCallStack
     => KDFType    -- ^ KDF algorithm
     -> Int        -- ^ length of output key
     -> V.Bytes    -- ^ secret
     -> IO V.Bytes
{-# INLINABLE kdf' #-}
kdf' algo siz secret = kdf algo siz secret mempty mempty

--------------------------------------------
-- Password-Based Key Derivation Function --
--------------------------------------------

-- | Often one needs to convert a human readable password into a cryptographic key. It is useful to slow down
-- the computation in order to reduce the speed of brute force search, thus these algorithms are parameterized
-- in some way which allows their required computation to be tuned.
--
-- 'pbkdfTypeToParam' translates each constructor into Botan's algorithm name and parameters.
data PBKDFType
    = PBKDF2 MACType Int
    -- ^ @PBKDF2 mac iterations@. PBKDF2 is the “standard” password derivation scheme,
    -- widely implemented in many different libraries.
    | Scrypt Int Int Int
    -- ^ @Scrypt n r p@. Scrypt is a relatively newer design which is “memory hard”:
    -- in addition to requiring large amounts of CPU power it uses a large block of memory to compute the hash.
    -- This makes brute force attacks using ASICs substantially more expensive.
    | Argon2d Int Int Int
    -- ^ @Argon2d iterations memory parallelism@ (memory in KiB, as Botan expects).
    -- Argon2 is the winner of the PHC (Password Hashing Competition) and provides a tunable memory hard PBKDF.
    | Argon2i Int Int Int
    -- ^ @Argon2i iterations memory parallelism@, see 'Argon2d'.
    | Argon2id Int Int Int
    -- ^ @Argon2id iterations memory parallelism@, see 'Argon2d'.
    | Bcrypt Int
    -- ^ @Bcrypt iterations@, Botan's @Bcrypt-PBKDF@.
    | OpenPGP_S2K HashType Int
    -- ^ @OpenPGP_S2K hash iterations@. The OpenPGP algorithm is weak and strange,
    -- and should be avoided unless implementing OpenPGP.

-- | Translate a 'PBKDFType' into the arguments of Botan's password-hash call:
-- @(algorithm name, param1, param2, param3)@. Unused parameters are @0@.
--
-- The parameters are emitted in Botan's order:
--
--   * PBKDF2, OpenPGP-S2K, Bcrypt-PBKDF: iterations
--   * Scrypt: N, r, p
--   * Argon2 family: memory, iterations, parallelism. This is NOT the field
--     order of the Argon2 constructors (iterations, memory, parallelism), so the
--     first two fields are swapped here.
pbkdfTypeToParam :: PBKDFType -> (CBytes, Int, Int, Int)
pbkdfTypeToParam (PBKDF2 mt i)      = (CB.concat ["PBKDF2(", macTypeToCBytes mt, ")"], i, 0, 0)
pbkdfTypeToParam (Scrypt n r p)     = ("Scrypt", n, r, p)
pbkdfTypeToParam (Argon2d i m p)    = ("Argon2d", m, i, p)
pbkdfTypeToParam (Argon2i i m p)    = ("Argon2i", m, i, p)
pbkdfTypeToParam (Argon2id i m p)   = ("Argon2id", m, i, p)
pbkdfTypeToParam (Bcrypt i)         = ("Bcrypt-PBKDF", i, 0, 0)
pbkdfTypeToParam (OpenPGP_S2K ht i) = (CB.concat ["OpenPGP-S2K(", hashTypeToCBytes ht, ")"], i, 0, 0)

-- | Derive a key from a passphrase using the given PBKDF algorithm and its parameters.
--
-- Throws a Botan exception (via 'throwBotanIfMinus_') if the underlying call fails.
pbkdf :: HasCallStack
      => PBKDFType  -- ^ PBKDF algorithm type
      -> Int        -- ^ length of output key
      -> CBytes     -- ^ passphrase
      -> V.Bytes    -- ^ salt
      -> IO V.Bytes
{-# INLINABLE pbkdf #-}
pbkdf typ siz pwd salt =
    withCBytesUnsafe algo $ \ algoBA ->
        withCBytesUnsafe pwd $ \ pwdBA ->
            withPrimVectorUnsafe salt $ \ saltBA saltOff saltLen ->
                fst <$> allocPrimVectorUnsafe siz (\ buf -> do
                    clearMBA buf siz
                    throwBotanIfMinus_ $
                        hs_botan_pwdhash algoBA
                            i1 i2 i3
                            buf siz
                            pwdBA (CB.length pwd)
                            saltBA saltOff saltLen)
  where
    -- algorithm name plus its three numeric parameters, in Botan's order
    (algo, i1, i2, i3) = pbkdfTypeToParam typ

-- | Derive a key from a passphrase using the given PBKDF algorithm.
--
-- The numeric parameters carried by the 'PBKDFType' are ignored; only the
-- algorithm name is used, and the PBKDF is run for the given number of milliseconds.
--
-- Throws a Botan exception (via 'throwBotanIfMinus_') if the underlying call fails.
pbkdfTimed :: HasCallStack
           => PBKDFType  -- ^ PBKDF algorithm type
           -> Int        -- ^ run until this many milliseconds have passed
           -> Int        -- ^ length of output key
           -> CBytes     -- ^ passphrase
           -> V.Bytes    -- ^ salt
           -> IO V.Bytes
{-# INLINABLE pbkdfTimed #-}
pbkdfTimed typ msec siz pwd s
    -- For long runs (> 0.1s) use the "_safe" foreign call so that other Haskell
    -- threads and the GC are not blocked while it runs. This path passes
    -- pointers (withCBytes / withPrimVectorSafe / allocPrimVectorSafe)
    -- rather than unpinned byte arrays.
    | msec > 100 =
        withCBytes algo $ \ algo' ->
            withCBytes pwd $ \ pwd' ->
                withPrimVectorSafe s $ \ s' sLen ->
                    fst <$> allocPrimVectorSafe siz (\ buf -> do
                        clearPtr buf siz
                        throwBotanIfMinus_ $
                            hs_botan_pwdhash_timed_safe
                                algo' msec buf siz
                                pwd' (CB.length pwd)
                                s' 0 sLen)
    -- Short runs use the cheaper unsafe call on byte arrays.
    | otherwise =
        withCBytesUnsafe algo $ \ algo' ->
            withCBytesUnsafe pwd $ \ pwd' ->
                withPrimVectorUnsafe s $ \ s' sOff sLen ->
                    fst <$> allocPrimVectorUnsafe siz (\ buf -> do
                        clearMBA buf siz
                        throwBotanIfMinus_ $
                            hs_botan_pwdhash_timed
                                algo' msec buf siz
                                pwd' (CB.length pwd)
                                s' sOff sLen)
  where
    (algo, _, _, _) = pbkdfTypeToParam typ