{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Context.Internal
    (
    
      ClientParams(..)
    , ServerParams(..)
    , defaultParamsClient
    , SessionID
    , SessionData(..)
    , MaxFragmentEnum(..)
    , Measurement(..)
    
    , Context(..)
    , Hooks(..)
    , Established(..)
    , PendingAction(..)
    , ctxEOF
    , ctxHasSSLv2ClientHello
    , ctxDisableSSLv2ClientHello
    , ctxEstablished
    , withLog
    , ctxWithHooks
    , contextModifyHooks
    , setEOF
    , setEstablished
    , contextFlush
    , contextClose
    , contextSend
    , contextRecv
    , updateMeasure
    , withMeasure
    , withReadLock
    , withWriteLock
    , withStateLock
    , withRWLock
    
    , Information(..)
    , contextGetInformation
    
    , throwCore
    , failOnEitherError
    , usingState
    , usingState_
    , runTxState
    , runRxState
    , usingHState
    , getHState
    , saveHState
    , restoreHState
    , getStateRNG
    , tls13orLater
    , addCertRequest13
    , getCertRequest13
    ) where
import Network.TLS.Backend
import Network.TLS.Extension
import Network.TLS.Cipher
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Compression (Compression)
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Imports
import Network.TLS.Types
import Network.TLS.Util
import qualified Data.ByteString as B
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Control.Exception (throwIO, Exception())
import Data.IORef
import Data.Tuple
-- | Summary of a connection's negotiated parameters, as assembled by
-- 'contextGetInformation'.
data Information = Information
    { infoVersion             :: Version
      -- ^ Protocol version recorded in the TLS state.
    , infoCipher              :: Cipher
      -- ^ Cipher taken from the receive-side record state.
    , infoCompression         :: Compression
      -- ^ Compression taken from the receive-side record state.
    , infoMasterSecret        :: Maybe ByteString
      -- ^ Master secret, when present in the handshake state.
    , infoClientRandom        :: Maybe ClientRandom
      -- ^ Client random, when a handshake state exists.
    , infoServerRandom        :: Maybe ServerRandom
      -- ^ Server random, when present in the handshake state.
    , infoNegotiatedGroup     :: Maybe Group
      -- ^ Negotiated group, when present in the handshake state.
    , infoTLS13HandshakeMode  :: Maybe HandshakeMode13
      -- ^ Handshake mode; only filled in when the version is 'TLS13'.
    , infoIsEarlyDataAccepted :: Bool
      -- ^ 'True' when the handshake's 0-RTT status is 'RTT0Accepted'.
    } deriving (Show, Eq)
-- | Per-connection TLS context: transport backend, configuration, mutable
-- protocol state, callbacks and locks.
data Context = Context
    { ctxConnection              :: Backend
      -- ^ Transport used by 'contextSend', 'contextRecv', 'contextFlush'
      -- and 'contextClose'.
    , ctxSupported               :: Supported
      -- ^ Supported parameters; 'runTxState' takes its fallback version
      -- from 'supportedVersions' of this value.
    , ctxShared                  :: Shared
      -- ^ Shared parameters (not read in this module).
    , ctxState                   :: MVar TLSState
      -- ^ Protocol state, accessed through 'usingState'.
    , ctxMeasurement             :: IORef Measurement
      -- ^ Traffic counters, see 'updateMeasure' and 'withMeasure'.
    , ctxEOF_                    :: IORef Bool
      -- ^ End-of-stream flag, read by 'ctxEOF' and set by 'setEOF'.
    , ctxEstablished_            :: IORef Established
      -- ^ Establishment status, see 'ctxEstablished' and 'setEstablished'.
    , ctxNeedEmptyPacket         :: IORef Bool
      -- ^ NOTE(review): not read in this module; presumably toggles an
      -- empty-record workaround — TODO confirm at the use sites.
    , ctxSSLv2ClientHello        :: IORef Bool
      -- ^ SSLv2 ClientHello flag, read by 'ctxHasSSLv2ClientHello' and
      -- cleared by 'ctxDisableSSLv2ClientHello'.
    , ctxTxState                 :: MVar RecordState
      -- ^ Transmit-side record state, used by 'runTxState'.
    , ctxRxState                 :: MVar RecordState
      -- ^ Receive-side record state, used by 'runRxState'.
    , ctxHandshake               :: MVar (Maybe HandshakeState)
      -- ^ Current handshake state; 'usingHState' throws when it is 'Nothing'.
    , ctxDoHandshake             :: Context -> IO ()
      -- ^ Handshake callback. NOTE(review): semantics defined by whoever
      -- builds the context — TODO confirm.
    , ctxDoHandshakeWith         :: Context -> Handshake -> IO ()
      -- ^ Callback taking a 'Handshake' message. NOTE(review): see the
      -- context's constructor for its exact role.
    , ctxDoRequestCertificate    :: Context -> IO Bool
      -- ^ Certificate-request callback. NOTE(review): meaning of the
      -- returned 'Bool' is not visible here.
    , ctxDoPostHandshakeAuthWith :: Context -> Handshake13 -> IO ()
      -- ^ Callback taking a 'Handshake13' message. NOTE(review): role
      -- defined outside this module.
    , ctxHooks                   :: IORef Hooks
      -- ^ User hooks, see 'ctxWithHooks' and 'contextModifyHooks'.
    , ctxLockWrite               :: MVar ()
      -- ^ Lock taken by 'withWriteLock'.
    , ctxLockRead                :: MVar ()
      -- ^ Lock taken by 'withReadLock'.
    , ctxLockState               :: MVar ()
      -- ^ Lock taken by 'withStateLock'.
    , ctxPendingActions          :: IORef [PendingAction]
      -- ^ Queued 'PendingAction's (not consumed in this module).
    , ctxCertRequests            :: IORef [Handshake13]
      -- ^ Outstanding certificate requests, managed by 'addCertRequest13'
      -- and 'getCertRequest13'.
    , ctxKeyLogger               :: String -> IO ()
      -- ^ NOTE(review): presumably a sink for key-log lines — TODO confirm
      -- the expected line format.
    }
-- | Establishment status of a context, stored in 'ctxEstablished_'.
data Established
    = NotEstablished
    | EarlyDataAllowed Int
      -- ^ NOTE(review): the 'Int' is presumably a remaining 0-RTT data
      -- budget — TODO confirm at the use sites.
    | EarlyDataNotAllowed Int
      -- ^ NOTE(review): the 'Int' is presumably a remaining count of
      -- rejected early data to skip — TODO confirm.
    | Established
    deriving (Eq, Show)
-- | A callback queued in 'ctxPendingActions' that consumes a 'Handshake13'
-- message.
--
-- NOTE(review): the meaning of the 'Bool' flag is not visible in this
-- module — TODO confirm at the use sites.
data PendingAction
    = PendingAction Bool (Handshake13 -> IO ())
      -- ^ Callback receiving only the handshake message.
    | PendingActionHash Bool (ByteString -> Handshake13 -> IO ())
      -- ^ Callback also receiving a 'ByteString'; presumably a transcript
      -- hash — TODO confirm.

-- | Apply a function to the context's 'Measurement', storing the result
-- forced to WHNF.
updateMeasure :: Context -> (Measurement -> Measurement) -> IO ()
updateMeasure ctx = modifyIORef' (ctxMeasurement ctx)
-- | Pass the current 'Measurement' to an action.
withMeasure :: Context -> (Measurement -> IO a) -> IO a
withMeasure ctx f = f =<< readIORef (ctxMeasurement ctx)
-- | Flush the context's backend.
contextFlush :: Context -> IO ()
contextFlush ctx = backendFlush (ctxConnection ctx)
-- | Close the context's backend.
contextClose :: Context -> IO ()
contextClose ctx = backendClose (ctxConnection ctx)
-- | Snapshot of the connection's negotiated parameters.
--
-- Returns 'Nothing' until both a protocol version and a cipher are known.
-- Handshake-derived fields are 'Nothing' (or 'False') when the context has no
-- handshake state. The TLS 1.3 handshake mode is reported only when the
-- version is exactly 'TLS13'.
--
-- Throws (via 'failOnEitherError') if reading the TLS state or the
-- receive-side record state fails.
contextGetInformation :: Context -> IO (Maybe Information)
contextGetInformation ctx = do
    mVersion <- usingState_ ctx $ gets stVersion
    mHState  <- getHState ctx
    (cipher, comp) <- failOnEitherError $ runRxState ctx $
        gets $ \rst -> (stCipher rst, stCompression rst)
    let masterSecret = mHState >>= hstMasterSecret
        clientRandom = hstClientRandom <$> mHState
        serverRandom = mHState >>= hstServerRandom
        negGroup     = mHState >>= hstNegotiatedGroup
        hsMode13
            | mVersion == Just TLS13 = hstTLS13HandshakeMode <$> mHState
            | otherwise              = Nothing
        earlyAccepted =
            maybe False ((== RTT0Accepted) . hstTLS13RTT0Status) mHState
        build v c = Information v c comp masterSecret clientRandom
                                serverRandom negGroup hsMode13 earlyAccepted
    return $ build <$> mVersion <*> cipher
-- | Add the payload length to the sent-bytes measurement, then hand the
-- bytes to the backend.
contextSend :: Context -> ByteString -> IO ()
contextSend c b = do
    updateMeasure c $ addBytesSent (B.length b)
    backendSend (ctxConnection c) b
-- | Request @sz@ bytes from the backend and add the number of bytes actually
-- returned to the received-bytes measurement.
--
-- The count comes from the returned 'ByteString', not from the requested
-- size. If the backend returns fewer bytes (e.g. at end of stream), the
-- counter is not over-counted. If the read throws, nothing is counted.
contextRecv :: Context -> Int -> IO ByteString
contextRecv c sz = do
    bs <- backendRecv (ctxConnection c) sz
    updateMeasure c (addBytesReceived $ B.length bs)
    return bs
-- | Whether end-of-stream has been recorded (see 'setEOF').
ctxEOF :: Context -> IO Bool
ctxEOF = readIORef . ctxEOF_
-- | Current value of the SSLv2 ClientHello flag.
ctxHasSSLv2ClientHello :: Context -> IO Bool
ctxHasSSLv2ClientHello = readIORef . ctxSSLv2ClientHello
-- | Clear the SSLv2 ClientHello flag.
ctxDisableSSLv2ClientHello :: Context -> IO ()
ctxDisableSSLv2ClientHello = flip writeIORef False . ctxSSLv2ClientHello
-- | Record end-of-stream on the context.
setEOF :: Context -> IO ()
setEOF = flip writeIORef True . ctxEOF_
-- | Current establishment status.
ctxEstablished :: Context -> IO Established
ctxEstablished = readIORef . ctxEstablished_
-- | Pass the context's current 'Hooks' to an action.
ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a
ctxWithHooks ctx f = f =<< readIORef (ctxHooks ctx)
-- | Apply a modification to the context's 'Hooks'.
--
-- The new value is forced to WHNF before it is stored, so repeated
-- modifications do not pile up a chain of unevaluated thunks in the 'IORef'.
contextModifyHooks :: Context -> (Hooks -> Hooks) -> IO ()
contextModifyHooks ctx = modifyIORef' (ctxHooks ctx)
-- | Overwrite the establishment status.
setEstablished :: Context -> Established -> IO ()
setEstablished = writeIORef . ctxEstablished_
-- | Run an action with the 'Logging' hooks of the context.
withLog :: Context -> (Logging -> IO ()) -> IO ()
withLog ctx f = ctxWithHooks ctx $ \hooks -> f (hookLogging hooks)
-- | Throw an exception from any 'MonadIO' context (via 'throwIO').
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore e = liftIO (throwIO e)
-- | Run an action and throw its 'TLSError' result, or return the success value.
failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a
failOnEitherError action = action >>= either throwCore return
-- | Run a 'TLSSt' computation on the context's 'TLSState'. The resulting
-- state is forced to WHNF and stored back; the computation's result is returned.
usingState :: Context -> TLSSt a -> IO (Either TLSError a)
usingState ctx f = modifyMVar (ctxState ctx) step
  where
    step st = case runTLSState f st of
        (result, st') -> st' `seq` return (st', result)
-- | Like 'usingState', but throws the 'TLSError' instead of returning it.
usingState_ :: Context -> TLSSt a -> IO a
usingState_ ctx = failOnEitherError . usingState ctx
-- | Run a 'HandshakeM' computation on the current handshake state and store
-- the updated state back in the context.
--
-- Throws @'Error_Misc' "missing handshake"@ when the context holds no
-- handshake state; in that case the stored value is left unchanged.
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) step
  where
    step Nothing   = throwCore $ Error_Misc "missing handshake"
    step (Just st) = case runHandshake st f of
        (result, st') -> return (Just st', result)
-- | Read the current handshake state without modifying it.
getHState :: MonadIO m => Context -> m (Maybe HandshakeState)
getHState = liftIO . readMVar . ctxHandshake
-- | Snapshot the handshake state with 'saveMVar'.
saveHState :: Context -> IO (Saved (Maybe HandshakeState))
saveHState = saveMVar . ctxHandshake
-- | Restore a handshake-state snapshot with 'restoreMVar'.
restoreHState :: Context
              -> Saved (Maybe HandshakeState)
              -> IO (Saved (Maybe HandshakeState))
restoreHState = restoreMVar . ctxHandshake
-- | Run a 'RecordM' computation against the transmit-side record state.
-- The state is replaced only when the computation succeeds.
--
-- The version comes from the TLS state and defaults to the maximum of
-- 'supportedVersions'. For TLS 1.3 and later, the record version is a fixed
-- legacy value instead: 'TLS12' once the HelloRetryRequest flag is set,
-- 'TLS10' otherwise (cf. RFC 8446 section 5.1, @legacy_record_version@).
runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx f = do
    let fallback = maximum $ supportedVersions $ ctxSupported ctx
    ver <- usingState_ ctx (getVersionWithDefault fallback)
    hrr <- usingState_ ctx getTLS13HRR
    let isTLS13 = ver >= TLS13
        recordVer
            | not isTLS13 = ver
            | hrr         = TLS12
            | otherwise   = TLS10
        opts = RecordOptions { recordVersion = recordVer
                             , recordTLS13   = isTLS13
                             }
    modifyMVar (ctxTxState ctx) $ \current ->
        pure $ case runRecordM f opts current of
            Right (a, updated) -> (updated, Right a)
            Left err           -> (current, Left err)
-- | Run a 'RecordM' computation against the receive-side record state.
-- The state is replaced only when the computation succeeds.
--
-- The version stored in the TLS state is used unchanged as the record
-- version.
runRxState :: Context -> RecordM a -> IO (Either TLSError a)
runRxState ctx f = do
    ver <- usingState_ ctx getVersion
    let opts = RecordOptions { recordVersion = ver, recordTLS13 = ver >= TLS13 }
    modifyMVar (ctxRxState ctx) $ \current ->
        pure $ either (\err -> (current, Left err))
                      (\(a, updated) -> (updated, Right a))
                      (runRecordM f opts current)
-- | Produce @n@ bytes with 'genRandom' inside the TLS state, throwing on a
-- state error.
getStateRNG :: Context -> Int -> IO ByteString
getStateRNG ctx = usingState_ ctx . genRandom
-- | Run an action while holding 'ctxLockRead'.
withReadLock :: Context -> IO a -> IO a
withReadLock ctx = withMVar (ctxLockRead ctx) . const
-- | Run an action while holding 'ctxLockWrite'.
withWriteLock :: Context -> IO a -> IO a
withWriteLock ctx = withMVar (ctxLockWrite ctx) . const
-- | Run an action holding both locks. The read lock is taken first, then
-- the write lock.
withRWLock :: Context -> IO a -> IO a
withRWLock ctx = withReadLock ctx . withWriteLock ctx
-- | Run an action while holding 'ctxLockState'.
withStateLock :: Context -> IO a -> IO a
withStateLock ctx = withMVar (ctxLockState ctx) . const
-- | Whether the context's version is TLS 1.3 or later.
--
-- The version is queried with @'getVersionWithDefault' 'TLS10'@. A
-- 'TLSError' while reading the state yields 'False' instead of being thrown.
tls13orLater :: MonadIO m => Context -> m Bool
tls13orLater ctx =
    either (const False) (>= TLS13)
        <$> liftIO (usingState ctx $ getVersionWithDefault TLS10)
-- | Record an outstanding certificate request by prepending it to
-- 'ctxCertRequests', so the most recent request comes first.
addCertRequest13 :: Context -> Handshake13 -> IO ()
addCertRequest13 ctx = modifyIORef (ctxCertRequests ctx) . (:)
-- | Find and remove the stored certificate request whose context equals
-- @context@.
--
-- Returns 'Nothing' and leaves the list untouched when nothing matches. On a
-- match, the first matching entry (the most recently added one) is returned,
-- and every entry with that context is removed. Stored messages that are not
-- 'CertRequest13' never match.
getCertRequest13 :: Context -> CertReqContext -> IO (Maybe Handshake13)
getCertRequest13 ctx context = do
    let ref = ctxCertRequests ctx
    l <- readIORef ref
    let (matched, others) = partition isMatching l
    case matched of
        []          -> return Nothing
        (certReq:_) -> writeIORef ref others >> return (Just certReq)
  where
    -- 'addCertRequest13' accepts any 'Handshake13', so the predicate must be
    -- total instead of assuming every entry is a 'CertRequest13'.
    isMatching (CertRequest13 c _) = context == c
    isMatching _                   = False