-- |
-- Module      : Crypto.PubKey.MaskGenFunction
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
{-# LANGUAGE BangPatterns #-}
module Crypto.PubKey.MaskGenFunction
    ( MaskGenAlgorithm
    , mgf1
    ) where

import           Crypto.Number.Serialize (i2ospOf_)
import           Crypto.Hash
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes)
import qualified Crypto.Internal.ByteArray as B

-- | The shape shared by every mask generation function: it takes some
-- seed material plus the number of bytes wanted, and yields the mask.
type MaskGenAlgorithm s out
    =  s      -- ^ seed material
    -> Int    -- ^ requested output length
    -> out

-- | Mask generation function MGF1 (PKCS #1, RFC 8017 appendix B.2.1).
--
-- The seed is hashed once. Then, for each counter value @0, 1, ...@, the
-- counter is encoded into 4 octets with 'i2ospOf_', appended to the seeded
-- hash context, and the digest is finalized. The digests are concatenated
-- and the result is truncated to the requested length.
--
-- A length of zero (or a negative length) yields an empty output.
mgf1 :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hashAlg)
     => hashAlg  -- ^ hash algorithm used to expand the seed
     -> seed     -- ^ seed
     -> Int      -- ^ number of output bytes to generate
     -> output
mgf1 hashAlg seed len =
    -- The seed prefix is absorbed exactly once; every block reuses this
    -- context and only appends its own counter.
    let !seededCtx = hashUpdate (hashInitWith hashAlg) seed
     in B.take len $ B.concat $
            map (hashCounter seededCtx) [0 .. fromIntegral (maxCounter - 1)]
  where
    digestLen :: Int
    digestLen = hashDigestSize hashAlg

    -- Number of digest blocks needed, i.e. ceil (len / digestLen); the last
    -- block is only partially used when 'left' is non-zero.
    (chunks, left) = len `divMod` digestLen

    maxCounter :: Int
    maxCounter = if left > 0 then chunks + 1 else chunks

    -- Digest of (seed || counter), with the counter serialized on 4 octets.
    hashCounter :: HashAlgorithm a => Context a -> Integer -> Digest a
    hashCounter ctx counter =
        hashFinalize $ hashUpdate ctx (i2ospOf_ 4 counter :: Bytes)