-- |
-- Module      : Crypto.MAC.HMAC
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Provide the HMAC (Hash-based Message Authentication Code) base algorithm.
-- <http://en.wikipedia.org/wiki/HMAC>
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.MAC.HMAC
    ( hmac
    , HMAC(..)
    -- * Incremental
    , Context(..)
    , initialize
    , update
    , updates
    , finalize
    ) where

import           Crypto.Hash hiding (Context)
import qualified Crypto.Hash as Hash (Context)
import           Crypto.Hash.IO
import           Crypto.Internal.ByteArray (ScrubbedBytes, ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B
import           Data.Memory.PtrMethods
import           Crypto.Internal.Compat

-- | An HMAC value. The type parameter @a@ is a phantom recording the hash
-- algorithm that was used to produce the MAC.
--
-- The 'Eq' instance is constant time.  No 'Show' instance is provided, to
-- avoid printing a MAC by mistake.
--
-- 'ByteArrayAccess' is obtained from the wrapped 'Digest' through
-- @GeneralizedNewtypeDeriving@.
newtype HMAC a = HMAC { hmacGetDigest :: Digest a }
    deriving (ByteArrayAccess)

-- | Equality on MACs is delegated to 'B.constEq' on the underlying digests
-- rather than a derived, short-circuiting comparison.
instance Eq (HMAC a) where
    HMAC b1 == HMAC b2 = B.constEq b1 b2

-- | Compute a MAC in one shot using the hash algorithm selected by the
-- result type.
--
-- This is 'initialize' with the key, a single 'updates' with the message,
-- followed by 'finalize'.
hmac :: (ByteArrayAccess key, ByteArrayAccess message, HashAlgorithm a)
     => key     -- ^ Secret key
     -> message -- ^ Message to MAC
     -> HMAC a
hmac secret msg = finalize $ updates (initialize secret) [msg]

-- | Represent an ongoing HMAC state, that can be appended to with 'update'
-- (or 'updates') and turned into an 'HMAC' with 'finalize'.
--
-- The first field is the outer hash context and the second field is the
-- inner hash context; only the inner one is advanced by 'update'.
data Context hashalg = Context !(Hash.Context hashalg) !(Hash.Context hashalg)

-- | Initialize a new incremental HMAC context from a secret key.
--
-- The key is first normalised against the block size of the hash algorithm:
--
-- * a key of exactly one block is used as is;
--
-- * a shorter key is copied into a zero-filled block;
--
-- * a longer key is hashed, and the digest is zero-padded to one block when
--   it is shorter than a block.
--
-- The resulting block is XORed with @0x36@ (inner pad) and @0x5c@ (outer
-- pad), and each padded block is fed into a fresh hash context.  The returned
-- 'Context' stores the outer context first and the inner context second.
initialize :: (ByteArrayAccess key, HashAlgorithm a)
           => key       -- ^ Secret key
           -> Context a
initialize secret = unsafeDoIO (doHashAlg undefined)
  where
    -- The argument only selects the hash algorithm at the type level;
    -- 'initialize' passes 'undefined', so it is assumed that none of the
    -- hash* functions below ever evaluate it.
    doHashAlg :: HashAlgorithm a => a -> IO (Context a)
    doHashAlg alg = do
        -- 'withKey' runs an action on a pointer to the normalised key.
        !withKey <- case B.length secret `compare` blockSize of
            EQ -> return $ B.withByteArray secret
            LT -> B.withByteArray <$> zeroPad secret
            GT -> do
                -- hash the secret key
                ctx <- hashMutableInitWith alg
                hashMutableUpdate ctx secret
                digest <- hashMutableFinalize ctx
                hashMutableReset ctx
                -- pad it if necessary
                if digestSize < blockSize
                    then B.withByteArray <$> zeroPad digest
                    else return $ B.withByteArray digest
        (inner, outer) <- withKey $ \keyPtr ->
            (,) <$> B.alloc blockSize (\p -> memXorWith p 0x36 keyPtr blockSize)
                <*> B.alloc blockSize (\p -> memXorWith p 0x5c keyPtr blockSize)
        return $ Context (hashUpdates initCtx [outer :: ScrubbedBytes])
                         (hashUpdates initCtx [inner :: ScrubbedBytes])
      where
        blockSize  = hashBlockSize alg
        digestSize = hashDigestSize alg
        initCtx    = hashInitWith alg

        -- Copy a byte array that is shorter than one block into a fresh,
        -- zero-filled, block-sized scrubbed buffer.
        zeroPad :: ByteArrayAccess ba => ba -> IO ScrubbedBytes
        zeroPad src = B.alloc blockSize $ \k -> do
            memSet k 0 blockSize
            B.withByteArray src $ \s -> memCopy k s (B.length src)
-- NOTE(review): presumably NOINLINE keeps the 'unsafeDoIO' computation from
-- being duplicated by inlining -- confirm against 'Crypto.Internal.Compat'.
{-# NOINLINE initialize #-}

-- | Incrementally update an HMAC context with one more message chunk.
--
-- Only the inner hash context absorbs the message; the outer context is
-- carried along unchanged until 'finalize'.
update :: (ByteArrayAccess message, HashAlgorithm a)
       => Context a  -- ^ Current HMAC context
       -> message    -- ^ Message to append to the MAC
       -> Context a  -- ^ Updated HMAC context
update (Context octx ictx) msg =
    Context octx (hashUpdate ictx msg)

-- | Incrementally update an HMAC context with multiple inputs.
--
-- The messages are fed, in list order, to the inner hash context; the outer
-- context is left untouched.
updates :: (ByteArrayAccess message, HashAlgorithm a)
        => Context a -- ^ Current HMAC context
        -> [message] -- ^ Messages to append to the MAC
        -> Context a -- ^ Updated HMAC context
updates (Context octx ictx) msgs =
    Context octx (hashUpdates ictx msgs)

-- | Finalize an HMAC context and return the HMAC.
--
-- The inner context is finalized first; its digest is then fed to the outer
-- context, whose final digest is the MAC.
finalize :: HashAlgorithm a
         => Context a
         -> HMAC a
finalize (Context octx ictx) =
    HMAC $ hashFinalize $ hashUpdates octx [hashFinalize ictx]