-- |
-- Module      : Network.TLS.MAC
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.MAC
    ( macSSL
    , hmac
    , prf_MD5
    , prf_SHA1
    , prf_SHA256
    , prf_TLS
    , prf_MD5SHA1
    ) where

import Network.TLS.Crypto
import Network.TLS.Types
import Network.TLS.Imports
import qualified Data.ByteArray as B (xor)
import qualified Data.ByteString as B

-- | A keyed MAC function: the first argument is the secret (key), the second
-- is the message, and the result is the authentication tag.
type HMAC = ByteString -> ByteString -> ByteString

-- | SSLv3-style MAC, computed as
--
-- @hash (secret ++ pad2 ++ hash (secret ++ pad1 ++ msg))@
--
-- where @pad1@ is a run of 0x36 bytes and @pad2@ a run of 0x5c bytes, each
-- 48 bytes long for 'MD5' and 40 bytes long for 'SHA1'. Any other hash
-- algorithm is rejected with an 'error'.
macSSL :: Hash -> HMAC
macSSL alg secret msg = digest $! B.concat [secret, pad 0x5c, innerDigest]
  where
    digest :: ByteString -> ByteString
    digest = hash alg

    -- Inner hash over the secret, the 0x36 padding and the message.
    innerDigest :: ByteString
    innerDigest = digest $! B.concat [secret, pad 0x36, msg]

    pad :: Word8 -> ByteString
    pad = B.replicate padLen

    -- Padding length depends on the hash; only MD5 and SHA1 are defined.
    padLen :: Int
    padLen = case alg of
        MD5   -> 48
        SHA1  -> 40
        other -> error ("internal error: macSSL called with " ++ show other)

-- | HMAC over the given hash algorithm.
--
-- A secret longer than the algorithm's block size is first replaced by its
-- digest. The resulting key is zero-padded to one full block and XORed with
-- 0x36 (inner pad) and 0x5c (outer pad):
--
-- @hash (outerPad ++ hash (innerPad ++ msg))@
hmac :: Hash -> HMAC
hmac alg secret msg =
    digest $! outerPad `B.append` (digest $! innerPad `B.append` msg)
  where
    digest :: ByteString -> ByteString
    digest = hash alg

    blockSize :: Int
    blockSize = hashBlockSize alg

    -- Over-long keys are shortened by hashing them.
    shortKey :: ByteString
    shortKey
        | B.length secret > blockSize = digest secret
        | otherwise                   = secret

    -- Right-pad the key with zero bytes up to exactly one block.
    paddedKey :: ByteString
    paddedKey =
        shortKey `B.append` B.replicate (blockSize - B.length shortKey) 0

    innerPad :: ByteString
    innerPad = B.map (xor 0x36) paddedKey

    outerPad :: ByteString
    outerPad = B.map (xor 0x5c) paddedKey

-- | Expansion loop shared by the PRFs.
--
-- Starting from @aprev@, each step computes @A(i) = mac secret A(i-1)@ and
-- emits the chunk @mac secret (A(i) ++ seed)@. It stops once the chunks
-- cover @len@ bytes, truncating the last chunk so the total length is @len@.
-- The result always holds at least one chunk; when @len <= 0@ that chunk is
-- empty.
hmacIter :: HMAC -> ByteString -> ByteString -> ByteString -> Int -> [ByteString]
hmacIter mac secret seed aprev len = go aprev len
  where
    go aPrev remaining
        | size >= remaining = [B.take remaining block]
        | otherwise         = block : go aNext (remaining - size)
      where
        aNext = mac secret aPrev
        block = mac secret (aNext `B.append` seed)
        size  = B.length block

-- | Expands @secret@ and @seed@ into @len@ bytes (empty when @len <= 0@)
-- by iterating HMAC-SHA1, with the seed also used as the initial @A(0)@.
prf_SHA1 :: ByteString -> ByteString -> Int -> ByteString
prf_SHA1 secret seed len = B.concat chunks
  where
    chunks = hmacIter (hmac SHA1) secret seed seed len

-- | Expands @secret@ and @seed@ into @len@ bytes (empty when @len <= 0@)
-- by iterating HMAC-MD5, with the seed also used as the initial @A(0)@.
prf_MD5 :: ByteString -> ByteString -> Int -> ByteString
prf_MD5 secret seed len = B.concat md5Chunks
  where
    md5Chunks = hmacIter (hmac MD5) secret seed seed len

-- | Combined MD5/SHA1 PRF.
--
-- The secret is split into two halves, each @ceiling (length / 2)@ bytes
-- long, so for an odd-length secret the middle byte belongs to both halves.
-- The first half is expanded with 'prf_MD5', the second with 'prf_SHA1',
-- and the two outputs are XORed byte by byte.
prf_MD5SHA1 :: ByteString -> ByteString -> Int -> ByteString
prf_MD5SHA1 secret seed len =
    prf_MD5 firstHalf seed len `B.xor` prf_SHA1 secondHalf seed len
  where
    (half, extra) = B.length secret `divMod` 2
    firstHalf     = B.take (half + extra) secret
    secondHalf    = B.drop half secret

-- | Expands @secret@ and @seed@ into @len@ bytes (empty when @len <= 0@)
-- by iterating HMAC-SHA256, with the seed also used as the initial @A(0)@.
prf_SHA256 :: ByteString -> ByteString -> Int -> ByteString
prf_SHA256 secret seed len =
    B.concat (hmacIter (hmac SHA256) secret seed seed len)

-- | Generic PRF: expands @secret@ and @seed@ into @len@ bytes by iterating
-- HMAC over the hash algorithm @halg@, with the seed also used as the
-- initial @A(0)@.
--
-- The 'Version' argument is currently ignored. It is kept so that the PRF
-- can later depend on the protocol version, not only on the cipher's PRF
-- hash algorithm.
prf_TLS :: Version -> Hash -> ByteString -> ByteString -> Int -> ByteString
prf_TLS _version halg secret seed len =
    B.concat (hmacIter (hmac halg) secret seed seed len)