-- |
-- TLS connection context: the user-supplied parameters ('TLSParams'), the
-- logging hooks ('TLSLogging'), the per-connection context ('TLSCtx') and the
-- helpers used to perform raw I\/O, keep measurements and run computations
-- against the context's 'TLSState'.
module Network.TLS.Context
	(
	-- * Context configuration
	  TLSParams(..)
	, TLSLogging(..)
	, SessionData(..)
	, Measurement(..)
	, TLSCertificateUsage(..)
	, TLSCertificateRejectReason(..)
	, defaultLogging
	, defaultParams
	-- * Context object and its accessors
	, TLSCtx
	, ctxParams
	, ctxConnection
	, ctxEOF
	, ctxEstablished
	, ctxLogging
	, setEOF
	, setEstablished
	, connectionFlush
	, connectionSend
	, connectionRecv
	, updateMeasure
	, withMeasure
	-- * Creating a context
	, newCtxWith
	, newCtx
	-- * Running state computations inside a context
	, throwCore
	, usingState
	, usingState_
	, getStateRNG
	) where
import Network.TLS.Struct
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.State
import Network.TLS.Measurement
import Data.Maybe
import Data.Certificate.X509
import Data.List (intercalate)
import qualified Data.ByteString as B
import Control.Concurrent.MVar
import Control.Monad.State
import Control.Exception (throwIO, Exception())
import Data.IORef
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush)
import Prelude hiding (catch)
-- | Record of logging callbacks carried inside 'TLSParams' (field 'pLogging').
-- The visible part of this module only stores and exposes these callbacks
-- (see 'ctxLogging'); the notes below are inferred from the field names and
-- types and should be confirmed at the call sites.
data TLSLogging = TLSLogging
	{ loggingPacketSent :: String -> IO () -- ^ presumably given a textual rendering of each packet sent — TODO confirm
	, loggingPacketRecv :: String -> IO () -- ^ presumably given a textual rendering of each packet received — TODO confirm
	, loggingIOSent     :: B.ByteString -> IO () -- ^ presumably given the raw bytes written to the connection — TODO confirm
	, loggingIORecv     :: Header -> B.ByteString -> IO () -- ^ presumably given a record 'Header' and the raw bytes read — TODO confirm
	}
-- | User-supplied configuration of a TLS context. It is stored in a 'TLSCtx'
-- as 'ctxParams', and 'defaultParams' provides a baseline value.
--
-- In the visible part of this module only 'pLogging' (through 'ctxLogging')
-- and the 'Show' instance read these fields, so the per-field notes are
-- inferred from names and types; confirm them against the consuming code.
data TLSParams = TLSParams
	{ pConnectVersion    :: Version             -- ^ presumably the protocol version proposed when connecting — TODO confirm
	, pAllowedVersions   :: [Version]           -- ^ presumably the protocol versions accepted during negotiation — TODO confirm
	, pCiphers           :: [Cipher]            -- ^ ciphers to use; empty in 'defaultParams', so callers presumably must set it
	, pCompressions      :: [Compression]       -- ^ compression methods; 'defaultParams' lists only 'nullCompression'
	, pWantClientCert    :: Bool                -- ^ name suggests: ask the peer client for a certificate
	                                            --   (NOTE(review): confirm which side consults this flag)
	, pUseSecureRenegotiation :: Bool           -- ^ name suggests: enable secure renegotiation; 'True' in 'defaultParams'
	, pUseSession             :: Bool           -- ^ name suggests: enable session support; 'True' in 'defaultParams'
	, pCertificates      :: [(X509, Maybe PrivateKey)] -- ^ certificates, each optionally paired with a private key
	, pLogging           :: TLSLogging          -- ^ logging callbacks, exposed through 'ctxLogging'
	, onHandshake        :: Measurement -> IO Bool -- ^ callback given the current 'Measurement'; meaning of the 'Bool' result is not visible here — TODO confirm
	, onCertificatesRecv :: [X509] -> IO TLSCertificateUsage -- ^ callback deciding whether a list of certificates is accepted or rejected
	, onSessionResumption :: SessionID -> IO (Maybe SessionData) -- ^ callback mapping a 'SessionID' to stored 'SessionData', if any
	, onSessionEstablished :: SessionID -> SessionData -> IO ()  -- ^ callback given a 'SessionID' with its 'SessionData'; presumably fired once a session is set up
	, onSessionInvalidated :: SessionID -> IO ()                 -- ^ callback given a 'SessionID'; presumably fired when that session becomes invalid
	, sessionResumeWith   :: Maybe (SessionID, SessionData) -- ^ optional session to resume; 'Nothing' in 'defaultParams'
	}
-- | Logging hooks that log nothing: every callback discards its argument(s)
-- and returns @()@ immediately.
defaultLogging :: TLSLogging
defaultLogging = TLSLogging
    { loggingPacketSent = discard
    , loggingPacketRecv = discard
    , loggingIOSent     = discard
    , loggingIORecv     = \_ -> discard
    }
  where
    -- Callback that ignores its argument and performs no effect.
    discard :: a -> IO ()
    discard = const (return ())
-- | Baseline parameters, meant to be overridden with record-update syntax.
--
-- Things to be aware of when relying on the defaults:
--
-- * 'pCiphers' is empty and 'pCertificates' is empty.
--
-- * 'onCertificatesRecv' answers 'CertificateUsageAccept' for /any/ list of
--   certificates, i.e. no validation is performed unless the caller installs
--   its own callback.
--
-- * All session callbacks are no-ops and no session is resumed.
defaultParams :: TLSParams
defaultParams = TLSParams
    { pConnectVersion         = TLS10
    , pAllowedVersions        = [TLS10, TLS11, TLS12]
    , pCiphers                = []
    , pCompressions           = [nullCompression]
    , pWantClientCert         = False
    , pUseSecureRenegotiation = True
    , pUseSession             = True
    , pCertificates           = []
    , pLogging                = defaultLogging
    , onHandshake             = const (return True)
    , onCertificatesRecv      = const (return CertificateUsageAccept)
    , onSessionResumption     = const (return Nothing)
    , onSessionEstablished    = \_ -> const (return ())
    , onSessionInvalidated    = const (return ())
    , sessionResumeWith       = Nothing
    }
-- | Debug rendering of the showable subset of 'TLSParams': the versions,
-- ciphers, compressions, the client-certificate flag and the /number/ of
-- configured certificates. Callbacks and session data are not shown.
instance Show TLSParams where
    show p = "TLSParams { " ++ intercalate "," (map render fields) ++ " }"
      where
        -- One @key=value@ entry of the output.
        render (key, val) = key ++ "=" ++ val
        fields =
            [ ("connectVersion", show (pConnectVersion p))
            , ("allowedVersions", show (pAllowedVersions p))
            , ("ciphers", show (pCiphers p))
            , ("compressions", show (pCompressions p))
            , ("want-client-cert", show (pWantClientCert p))
            , ("certificates", show (length (pCertificates p)))
            ]
-- | Reason carried by 'CertificateUsageReject' when a certificate is refused.
-- 'CertificateRejectOther' holds a free-form description for cases not
-- covered by the dedicated constructors.
data TLSCertificateRejectReason =
	  CertificateRejectExpired
	| CertificateRejectRevoked
	| CertificateRejectUnknownCA
	| CertificateRejectOther String
	deriving (Show,Eq)
-- | Verdict returned by the 'onCertificatesRecv' callback of 'TLSParams'.
data TLSCertificateUsage =
	  CertificateUsageAccept                            -- ^ the certificates are accepted
	| CertificateUsageReject TLSCertificateRejectReason -- ^ the certificates are refused, with the reason why
	deriving (Show,Eq)
-- | Per-connection context, parameterised over the type of the underlying
-- connection object. Values are built by 'newCtxWith' (or 'newCtx' for a
-- 'Handle'); all mutable parts live behind an 'MVar' or 'IORef's.
data TLSCtx a = TLSCtx
	{ ctxConnection      :: a             -- ^ connection object given to 'newCtxWith' (the 'Handle' when built by 'newCtx')
	, ctxParams          :: TLSParams -- ^ parameters supplied when the context was created
	, ctxState           :: MVar TLSState -- ^ TLS state; read and replaced through 'usingState'
	, ctxMeasurement     :: IORef Measurement -- ^ counters modified by 'updateMeasure' and read by 'withMeasure'
	, ctxEOF_            :: IORef Bool    -- ^ created as 'False'; 'setEOF' sets it to 'True', 'ctxEOF' reads it
	, ctxEstablished_    :: IORef Bool    -- ^ created as 'False'; written by 'setEstablished', read by 'ctxEstablished'
	, ctxConnectionFlush :: IO () -- ^ flush action run by 'connectionFlush'
	, ctxConnectionSend  :: Bytes -> IO () -- ^ raw send function used by 'connectionSend'
	, ctxConnectionRecv  :: Int -> IO Bytes -- ^ raw receive function used by 'connectionRecv'
	}
-- | Apply a function to the context's 'Measurement' and store the result.
-- The new value is forced (to weak head normal form) before being written.
-- This is a plain read followed by a write, not an atomic modification.
updateMeasure :: MonadIO m => TLSCtx c -> (Measurement -> Measurement) -> m ()
updateMeasure ctx f = liftIO (readIORef ref >>= \old -> writeIORef ref $! f old)
  where
    ref = ctxMeasurement ctx
-- | Read the context's current 'Measurement' and run the given 'IO' action
-- on it, returning the action's result.
withMeasure :: MonadIO m => TLSCtx c -> (Measurement -> IO a) -> m a
withMeasure ctx f = liftIO $ do
    current <- readIORef (ctxMeasurement ctx)
    f current
-- | Run the flush action stored in the context.
connectionFlush :: TLSCtx c -> IO ()
connectionFlush = ctxConnectionFlush
-- | Send bytes through the context's raw send function. The length of the
-- buffer is added to the sent-bytes measurement first, then the send runs.
connectionSend :: TLSCtx c -> Bytes -> IO ()
connectionSend c b = do
    updateMeasure c (addBytesSent (B.length b))
    ctxConnectionSend c b
-- | Receive up to @sz@ bytes through the context's raw receive function and
-- return them.
--
-- The received-bytes measurement is increased by the number of bytes that
-- were /actually/ returned, and only after the receive succeeded. Counting the
-- requested size up front would over-report whenever the receive function
-- returns fewer bytes than asked for (a short read, e.g. at end of input) or
-- throws, and would be inconsistent with 'connectionSend', which counts the
-- real buffer length.
connectionRecv :: TLSCtx c -> Int -> IO Bytes
connectionRecv c sz = do
    received <- ctxConnectionRecv c sz
    updateMeasure c (addBytesReceived (B.length received))
    return received
-- | Read the context's end-of-file flag.
ctxEOF :: MonadIO m => TLSCtx a -> m Bool
ctxEOF = liftIO . readIORef . ctxEOF_
-- | Set the context's end-of-file flag to 'True'. This module offers no way
-- to clear it again.
setEOF :: MonadIO m => TLSCtx c -> m ()
setEOF ctx = liftIO (eofFlag `writeIORef` True)
  where
    eofFlag = ctxEOF_ ctx
-- | Read the context's \"established\" flag.
ctxEstablished :: MonadIO m => TLSCtx a -> m Bool
ctxEstablished = liftIO . readIORef . ctxEstablished_
-- | Overwrite the context's \"established\" flag with the given value.
setEstablished :: MonadIO m => TLSCtx c -> Bool -> m ()
setEstablished ctx = liftIO . writeIORef (ctxEstablished_ ctx)
-- | Logging callbacks of a context, taken from its parameters.
ctxLogging :: TLSCtx a -> TLSLogging
ctxLogging ctx = pLogging (ctxParams ctx)
-- | Build a context around an arbitrary connection object.
--
-- Arguments, in order: the connection object, the flush action, the raw send
-- function, the raw receive function, the parameters and the initial TLS
-- state. The context starts with fresh measurements and with both the EOF
-- and the \"established\" flags set to 'False'.
newCtxWith :: c -> IO () -> (Bytes -> IO ()) -> (Int -> IO Bytes) -> TLSParams -> TLSState -> IO (TLSCtx c)
newCtxWith c flushF sendF recvF params st = do
    -- The four allocations are independent of each other.
    measureRef     <- newIORef newMeasurement
    establishedRef <- newIORef False
    eofRef         <- newIORef False
    stateVar       <- newMVar st
    return TLSCtx
        { ctxConnection      = c
        , ctxParams          = params
        , ctxState           = stateVar
        , ctxMeasurement     = measureRef
        , ctxEOF_            = eofRef
        , ctxEstablished_    = establishedRef
        , ctxConnectionFlush = flushF
        , ctxConnectionSend  = sendF
        , ctxConnectionRecv  = recvF
        }
-- | Build a context on top of a 'Handle'. The handle is switched to
-- 'NoBuffering', and flushing, sending and receiving are mapped to 'hFlush',
-- 'B.hPut' and 'B.hGet' on that handle.
newCtx :: Handle -> TLSParams -> TLSState -> IO (TLSCtx Handle)
newCtx handle params st =
    hSetBuffering handle NoBuffering
        >> newCtxWith handle flush send recv params st
  where
    flush = hFlush handle
    send  = B.hPut handle
    recv  = B.hGet handle
-- | Raise an exception from any 'MonadIO' monad, using 'throwIO' underneath.
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore e = liftIO (throwIO e)
-- | Run a 'TLSSt' computation against the context's state.
--
-- The state is taken from the 'MVar' with 'modifyMVar', the computation is run
-- with 'runTLSState', and the resulting state is forced and stored back. The
-- computation's outcome (error or value) is returned as an 'Either'; the new
-- state is stored in both cases.
usingState :: MonadIO m => TLSCtx c -> TLSSt a -> m (Either TLSError a)
usingState ctx f = liftIO (modifyMVar (ctxState ctx) step)
  where
    -- Run the computation on the current state and hand back (new state, outcome).
    step st = case runTLSState f st of
        (outcome, st') -> st' `seq` return (st', outcome)
-- | Like 'usingState', but a 'Left' error is thrown as an exception via
-- 'throwCore' instead of being returned.
usingState_ :: MonadIO m => TLSCtx c -> TLSSt a -> m a
usingState_ ctx f = usingState ctx f >>= either throwCore return
-- | Run 'genTLSRandom' with the given 'Int' argument against the context's
-- state and return the resulting bytes; errors are thrown as by 'usingState_'.
getStateRNG :: MonadIO m => TLSCtx c -> Int -> m Bytes
getStateRNG ctx = usingState_ ctx . genTLSRandom