-- |
-- Module      : Crypto.PubKey.MaskGenFunction
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
{-# LANGUAGE BangPatterns #-}
module Crypto.PubKey.MaskGenFunction
    ( MaskGenAlgorithm
    , mgf1
    ) where

import           Crypto.Number.Serialize (i2ospOf_)
import           Crypto.Hash
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes)
import qualified Crypto.Internal.ByteArray as B

-- | Represent a mask generation algorithm: a function that takes a seed
-- and a requested length and produces the generated output.
--
-- 'mgf1', once applied to a hash algorithm, has this shape.
type MaskGenAlgorithm seed output =
       seed   -- ^ seed
    -> Int    -- ^ length to generate
    -> output

-- | Mask generation algorithm MGF1
--
-- The output is the concatenation of the digests of @seed || counter@
-- for @counter = 0, 1, 2, ...@, where each counter is appended as the
-- 4-byte string produced by @i2ospOf_ 4@. Just enough digests are
-- generated to cover @len@, and the result is then cut down to @len@.
mgf1 :: (ByteArrayAccess seed, ByteArray output, HashAlgorithm hashAlg)
     => hashAlg -- ^ hash algorithm used to produce each block
     -> seed    -- ^ seed
     -> Int     -- ^ length to generate
     -> output
mgf1 hashAlg seed len =
    -- The seed is absorbed once, strictly; every block then continues from
    -- this same context and only adds its own counter.
    let !seededCtx = hashUpdate (hashInitWith hashAlg) seed
     in B.take len $ B.concat $ map (hashCounter seededCtx) [0 .. fromIntegral (maxCounter - 1)]
  where
    -- Size of one digest of the chosen hash algorithm.
    digestLen     = hashDigestSize hashAlg
    -- Number of whole digests that fit in @len@, and the remainder.
    (chunks,left) = len `divMod` digestLen
    -- Number of digests to generate: one extra when @len@ is not a
    -- multiple of the digest size. The counter list above is empty when
    -- this is not positive.
    maxCounter    = if left > 0 then chunks + 1 else chunks

    -- Digest of the already-seeded context extended with the 4-byte
    -- encoding of the given counter.
    hashCounter :: HashAlgorithm a => Context a -> Integer -> Digest a
    hashCounter ctx counter = hashFinalize $ hashUpdate ctx (i2ospOf_ 4 counter :: Bytes)