{-# LANGUAGE OverloadedStrings #-}
-- |
-- Module      : Network.TLS.KeySchedule
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.KeySchedule
    ( hkdfExtract
    , hkdfExpandLabel
    , deriveSecret
    ) where

import qualified Crypto.Hash as H
import Crypto.KDF.HKDF
import Data.ByteArray (convert)
import qualified Data.ByteString as BS
import Network.TLS.Crypto
import Network.TLS.Wire
import Network.TLS.Imports

----------------------------------------------------------------

-- | @HKDF-Extract@ (RFC 5869, section 2.2).  Derives a pseudorandom key
-- (PRK) from a salt and input keying material (IKM), using the given hash
-- as the HMAC hash function.
--
-- Only 'SHA1', 'SHA256', 'SHA384' and 'SHA512' are supported; any other
-- 'Hash' value calls 'error'.
hkdfExtract :: Hash       -- ^ hash function underlying the HMAC
            -> ByteString -- ^ salt
            -> ByteString -- ^ input keying material (IKM)
            -> ByteString -- ^ pseudorandom key (PRK)
-- The @PRK a@ annotation on each equation picks the concrete hash algorithm
-- at the type level; 'convert' then turns the PRK into a plain ByteString.
hkdfExtract SHA1   salt ikm = convert (extract salt ikm :: PRK H.SHA1)
hkdfExtract SHA256 salt ikm = convert (extract salt ikm :: PRK H.SHA256)
hkdfExtract SHA384 salt ikm = convert (extract salt ikm :: PRK H.SHA384)
hkdfExtract SHA512 salt ikm = convert (extract salt ikm :: PRK H.SHA512)
hkdfExtract _      _    _   = error "hkdfExtract: unsupported hash"

----------------------------------------------------------------

-- | @Derive-Secret@ from the TLS 1.3 key schedule: 'hkdfExpandLabel' with an
-- output length equal to the digest size of the given hash.
--
-- The last argument is passed through unchanged as the label context.
-- Callers are expected to supply the already-computed hash of the handshake
-- messages (RFC 8446 writes this as @Transcript-Hash(Messages)@), not the
-- raw messages themselves.
deriveSecret :: Hash       -- ^ hash function to use
             -> ByteString -- ^ secret to derive from
             -> ByteString -- ^ label, without the @"tls13 "@ prefix
             -> ByteString -- ^ hash of the handshake messages (context)
             -> ByteString
deriveSecret h secret label hashedMsgs =
    hkdfExpandLabel h secret label hashedMsgs outlen
  where
    -- Derived secrets are always exactly one digest long.
    outlen :: Int
    outlen = hashDigestSize h

----------------------------------------------------------------

-- | @HKDF-Expand-Label@ from the TLS 1.3 key schedule.  Expands @secret@
-- into @outlen@ bytes of output keying material.  The HKDF @info@ input is
-- the serialized @HkdfLabel@ structure, which RFC 8446 defines as:
--
-- > struct {
-- >     uint16 length = Length;
-- >     opaque label<7..255> = "tls13 " + Label;
-- >     opaque context<0..255> = Context;
-- > } HkdfLabel;
--
-- No extract step is performed here: @secret@ is handed to @expand'@, which
-- treats it as an already-extracted PRK.
hkdfExpandLabel :: Hash       -- ^ hash function underlying the HMAC
                -> ByteString -- ^ secret, used directly as the PRK
                -> ByteString -- ^ label, without the @"tls13 "@ prefix
                -> ByteString -- ^ context
                -> Int        -- ^ number of output bytes
                -> ByteString
hkdfExpandLabel h secret label ctx outlen = expand' h secret hkdfLabel outlen
  where
    -- Wire encoding of HkdfLabel: a 16-bit length followed by two
    -- one-byte-length-prefixed opaque fields.
    -- NOTE(review): @fromIntegral outlen@ narrows an Int to Word16, so
    -- lengths above 65535 would wrap silently; the opaque fields are not
    -- range-checked here either (confirm putOpaque8's behaviour for inputs
    -- longer than 255 bytes).
    hkdfLabel :: ByteString
    hkdfLabel = runPut $ do
        putWord16 $ fromIntegral outlen
        putOpaque8 ("tls13 " `BS.append` label)
        putOpaque8 ctx

-- | @HKDF-Expand@ (RFC 5869, section 2.3) for the given hash.  @secret@ is
-- wrapped with 'extractSkip', i.e. used as the PRK without an extract step.
--
-- Only 'SHA1', 'SHA256', 'SHA384' and 'SHA512' are supported; any other
-- 'Hash' value calls 'error'.
expand' :: Hash       -- ^ hash function underlying the HMAC
        -> ByteString -- ^ secret, used directly as the PRK
        -> ByteString -- ^ HKDF @info@ (here, the encoded HkdfLabel)
        -> Int        -- ^ number of output bytes
        -> ByteString
-- As in 'hkdfExtract', the @PRK a@ annotation selects the hash algorithm.
expand' SHA1   secret label len = expand (extractSkip secret :: PRK H.SHA1)   label len
expand' SHA256 secret label len = expand (extractSkip secret :: PRK H.SHA256) label len
expand' SHA384 secret label len = expand (extractSkip secret :: PRK H.SHA384) label len
expand' SHA512 secret label len = expand (extractSkip secret :: PRK H.SHA512) label len
expand' _      _      _     _   = error "expand': unsupported hash"

----------------------------------------------------------------