{-# LANGUAGE ExistentialQuantification #-}
{-# OPTIONS_HADDOCK hide #-}

module Network.TLS.Cipher (
    CipherKeyExchangeType (..),
    Bulk (..),
    BulkFunctions (..),
    BulkDirection (..),
    BulkState (..),
    BulkStream (..),
    BulkBlock,
    BulkAEAD,
    bulkInit,
    Hash (..),
    Cipher (..),
    CipherID,
    cipherKeyBlockSize,
    BulkKey,
    BulkIV,
    BulkNonce,
    BulkAdditionalData,
    cipherAllowedForVersion,
    hasMAC,
    hasRecordIV,
) where

import Crypto.Cipher.Types (AuthTag)
import Network.TLS.Crypto (Hash (..), hashDigestSize)
import Network.TLS.Types (CipherID, Version (..))

import qualified Data.ByteString as B

-- Byte-string synonyms used throughout the bulk-cipher API below. Being
-- plain aliases, the type checker does not tell a key from an IV.
-- FIXME convert to newtype

-- | Key material passed to a cipher's initialisation function.
type BulkKey = B.ByteString

-- | Initialisation vector consumed and returned by a 'BulkBlock' function.
type BulkIV = B.ByteString

-- | Nonce given as the first argument of a 'BulkAEAD' function.
type BulkNonce = B.ByteString

-- | Additional data given to a 'BulkAEAD' function alongside the input.
type BulkAdditionalData = B.ByteString

-- | Running cipher state, tagged by the kind of bulk cipher it wraps.
-- 'bulkInit' builds one of the first three constructors.
data BulkState
    = BulkStateStream BulkStream -- ^ wraps a stream-cipher step function
    | BulkStateBlock BulkBlock -- ^ wraps a block-cipher function
    | BulkStateAEAD BulkAEAD -- ^ wraps an AEAD function
    | BulkStateUninitialized -- ^ carries no cipher functions at all

-- | Shows only which constructor is held: the wrapped cipher functions are
-- opaque and have no textual form.
instance Show BulkState where
    show (BulkStateStream _) = "BulkStateStream"
    show (BulkStateBlock _) = "BulkStateBlock"
    show (BulkStateAEAD _) = "BulkStateAEAD"
    show BulkStateUninitialized = "BulkStateUninitialized"

-- | A stream cipher as a step function: given an input chunk it returns the
-- output chunk together with the 'BulkStream' to use for the next chunk.
newtype BulkStream
    = BulkStream (B.ByteString -> (B.ByteString, BulkStream))

-- | A block cipher as a pure function of the current IV and the input,
-- yielding the output and a new IV (presumably chained into the next call).
type BulkBlock =
    BulkIV -> B.ByteString -> (B.ByteString, BulkIV)

-- | An AEAD cipher as a pure function of nonce, input and additional data,
-- yielding the output together with an authentication tag.
type BulkAEAD =
    BulkNonce
    -> B.ByteString
    -> BulkAdditionalData
    -> (B.ByteString, AuthTag)

-- | Whether a bulk cipher is being set up to encrypt or to decrypt; passed to
-- the initialisation functions held in 'BulkFunctions'.
data BulkDirection = BulkEncrypt | BulkDecrypt
    deriving (Show, Eq)

-- | Instantiate a 'Bulk' cipher for one direction with the given key.
--
-- The initialisation function stored in 'bulkF' is applied to the direction
-- and key, and the result is wrapped in the 'BulkState' constructor matching
-- the cipher kind. This never yields 'BulkStateUninitialized'.
bulkInit :: Bulk -> BulkDirection -> BulkKey -> BulkState
bulkInit bulk direction key =
    case bulkF bulk of
        BulkBlockF ini -> BulkStateBlock (ini direction key)
        BulkStreamF ini -> BulkStateStream (ini direction key)
        BulkAeadF ini -> BulkStateAEAD (ini direction key)

-- | The three supported kinds of bulk cipher, each stored as an
-- initialisation function from a direction and a key to the running cipher.
data BulkFunctions
    = BulkBlockF (BulkDirection -> BulkKey -> BulkBlock) -- ^ block cipher
    | BulkStreamF (BulkDirection -> BulkKey -> BulkStream) -- ^ stream cipher
    | BulkAeadF (BulkDirection -> BulkKey -> BulkAEAD) -- ^ AEAD cipher

-- | Whether this kind of cipher uses a separate MAC: 'True' for block and
-- stream ciphers, 'False' for AEAD ciphers (whose functions return their own
-- 'AuthTag').
hasMAC :: BulkFunctions -> Bool
hasMAC (BulkBlockF _) = True
hasMAC (BulkStreamF _) = True
hasMAC (BulkAeadF _) = False

-- | Whether this kind of cipher uses a record IV. It is defined to coincide
-- with 'hasMAC', so it is 'True' for block and stream ciphers and 'False' for
-- AEAD ciphers.
-- NOTE(review): stream ciphers also answer 'True' here; confirm callers rely
-- on this rather than on 'bulkIVSize'.
hasRecordIV :: BulkFunctions -> Bool
hasRecordIV = hasMAC

-- | Key-exchange method associated with a cipher.
data CipherKeyExchangeType
    = CipherKeyExchange_RSA
    | CipherKeyExchange_DH_Anon
    | CipherKeyExchange_DHE_RSA
    | CipherKeyExchange_ECDHE_RSA
    | CipherKeyExchange_DHE_DSA
    | CipherKeyExchange_DH_DSA
    | CipherKeyExchange_DH_RSA
    | CipherKeyExchange_ECDH_ECDSA
    | CipherKeyExchange_ECDH_RSA
    | CipherKeyExchange_ECDHE_ECDSA
    | CipherKeyExchange_TLS13 -- ^ not expressed in the cipher suite
    deriving (Show, Eq)

-- | Description of a bulk (symmetric) cipher: its size parameters together
-- with the functions that instantiate it.
data Bulk = Bulk
    { bulkName :: String
    -- ^ Name of the cipher; the 'Show' instance prints only this.
    , bulkKeySize :: Int
    -- ^ Size of the key handed to the cipher.
    , bulkIVSize :: Int
    -- ^ Size of the IV.
    , bulkExplicitIV :: Int
    -- ^ Explicit size for IV for AEAD cipher, 0 otherwise.
    , bulkAuthTagLen :: Int
    -- ^ Authentication tag length in bytes for AEAD cipher, 0 otherwise.
    , bulkBlockSize :: Int
    -- ^ Block size of the cipher.
    , bulkF :: BulkFunctions
    -- ^ Initialisation function; its constructor fixes the cipher kind.
    }

-- | Renders a 'Bulk' as its 'bulkName' only.
instance Show Bulk where
    show = bulkName
-- | Equality on name, key size, IV size and block size. 'bulkF' holds
-- functions and cannot be compared.
-- NOTE(review): 'bulkExplicitIV' and 'bulkAuthTagLen' are not compared
-- either; confirm that is intended.
instance Eq Bulk where
    b1 == b2 =
        and
            [ bulkName b1 == bulkName b2
            , bulkKeySize b1 == bulkKeySize b2
            , bulkIVSize b1 == bulkIVSize b2
            , bulkBlockSize b1 == bulkBlockSize b2
            ]

-- | Cipher algorithm
data Cipher = Cipher
    { cipherID :: CipherID
    -- ^ Identifier; the 'Eq' instance compares only this field.
    , cipherName :: String
    -- ^ Name; the 'Show' instance prints only this field.
    , cipherHash :: Hash
    -- ^ Hash whose digest size is counted by 'cipherKeyBlockSize'.
    , cipherBulk :: Bulk
    -- ^ Bulk (symmetric) cipher of this suite.
    , cipherKeyExchange :: CipherKeyExchangeType
    -- ^ Key-exchange method.
    , cipherMinVer :: Maybe Version
    -- ^ Minimum protocol version; 'Nothing' restricts the cipher to versions
    -- below 'TLS13' (see 'cipherAllowedForVersion').
    , cipherPRFHash :: Maybe Hash
    -- ^ Optional hash for the PRF.
    -- NOTE(review): the meaning of 'Nothing' is decided outside this module.
    }

-- | Size of the key block to derive for a cipher: twice the sum of the hash
-- digest size, the bulk IV size and the bulk key size (presumably one MAC key,
-- IV and encryption key for each direction).
--
-- NOTE(review): the digest-sized share is counted even for AEAD bulks, where
-- 'hasMAC' is 'False'; the extra bytes are only over-derivation.
cipherKeyBlockSize :: Cipher -> Int
cipherKeyBlockSize cipher = 2 * (macLen + ivLen + keyLen)
  where
    bulk = cipherBulk cipher
    macLen = hashDigestSize (cipherHash cipher)
    ivLen = bulkIVSize bulk
    keyLen = bulkKeySize bulk

-- | Check if a specific 'Cipher' is allowed to be used
-- with the version specified.
--
-- * A cipher without a minimum version ('cipherMinVer' is 'Nothing') is
--   allowed only for versions below 'TLS13'.
-- * A cipher with minimum version @v@ needs @v <= ver@. In addition, when
--   @ver@ is 'TLS13' or later, @v@ must itself be 'TLS13' or later, so
--   pre-1.3 ciphers are rejected on 1.3 connections.
cipherAllowedForVersion :: Version -> Cipher -> Bool
cipherAllowedForVersion ver cipher =
    case cipherMinVer cipher of
        Nothing -> ver < TLS13
        Just cVer -> cVer <= ver && (ver < TLS13 || cVer >= TLS13)

-- | Renders a 'Cipher' as its 'cipherName' only.
instance Show Cipher where
    show = cipherName

-- | Two ciphers are equal exactly when their 'cipherID's are equal; every
-- other field is ignored.
instance Eq Cipher where
    c1 == c2 = cipherID c1 == cipherID c2