{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Keter.SSL
    ( SslConfig (..)
    , setDir
    , runTCPServerSsl
    ) where

import Keter.Prelude ((++))
import Prelude hiding ((++), FilePath, readFile)
import Data.Yaml (FromJSON (parseJSON), (.:), (.:?), (.!=), Value (Object))
import Control.Applicative ((<$>), (<*>))
import Control.Monad (mzero, forever)
import Data.String (fromString)
import Filesystem.Path.CurrentOS ((</>), FilePath)
import Filesystem (readFile)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.Certificate.KeyRSA as KeyRSA
import qualified Data.PEM as PEM
import qualified Network.TLS as TLS
import qualified Data.Certificate.X509 as X509
import Data.Conduit.Network (HostPreference, Application, bindPort, sinkSocket)
import Data.Conduit (($$), yield)
import qualified Data.Conduit.List as CL
import Data.Either (rights)
import Keter.PortManager (Port)
import Network.Socket (sClose, accept)
import Network.Socket.ByteString (recv)
import Control.Exception (bracket, finally)
import Control.Concurrent (forkIO)
import Control.Monad.Trans.Class (lift)
import qualified Network.TLS.Extra as TLSExtra
import Crypto.Random

-- | Settings for the TLS listener started by 'runTCPServerSsl'.
data SslConfig = SslConfig
    { sslHost :: HostPreference
      -- ^ Interface to bind. The YAML parser defaults this to @"*"@.
    , sslPort :: Port
      -- ^ TCP port to listen on. The YAML parser defaults this to 443.
    , sslCertificate :: FilePath
      -- ^ PEM file containing the server certificate.
    , sslKey :: FilePath
      -- ^ PEM file containing the server's RSA private key
      -- (an @RSA PRIVATE KEY@ section).
    }

-- | Make the certificate and key paths of an 'SslConfig' relative to the
-- given directory. Each path is prefixed with @dir@ using the path-append
-- operator. The host and port are left unchanged.
setDir :: FilePath -> SslConfig -> SslConfig
setDir dir ssl = ssl
    { sslCertificate = dir </> sslCertificate ssl
    , sslKey = dir </> sslKey ssl
    }

-- | Parse an 'SslConfig' from a YAML/JSON object.
--
-- Required keys: @certificate@ and @key@ (both strings).
-- Optional keys: @host@ (default @"*"@) and @port@ (default 443).
-- Any non-object value fails the parse.
instance FromJSON SslConfig where
    parseJSON (Object obj) = do
        mhost <- obj .:? "host"
        port  <- obj .:? "port" .!= 443
        cert  <- obj .: "certificate"
        key   <- obj .: "key"
        return SslConfig
            { sslHost        = maybe "*" fromString mhost
            , sslPort        = port
            , sslCertificate = fromString cert
            , sslKey         = fromString key
            }
    parseJSON _ = mzero

-- | Serve a conduit 'Application' over TLS on the configured host and port.
--
-- The certificate ('sslCertificate') and RSA key ('sslKey') are loaded once,
-- before binding. The listening socket is acquired with 'bracket', so it is
-- closed if the accept loop ends with an exception.
--
-- Each accepted connection is handled on its own thread ('forkIO'). That
-- thread performs the TLS handshake and then runs the application with a
-- source of decrypted input and a sink that encrypts its output. The client
-- socket is closed when the thread finishes, whether normally or by an
-- exception (including a failed handshake).
runTCPServerSsl :: SslConfig -> Application IO -> IO ()
runTCPServerSsl SslConfig{..} app = do
    cert <- readCertificate sslCertificate
    key <- readPrivateKey sslKey
    bracket
        (bindPort sslPort sslHost)
        sClose
        (forever . serve cert key)
  where
    -- Accept one client and handle it on a fresh thread. The whole handler
    -- (context creation, handshake and application) sits under 'finally',
    -- so the client socket is released on every exit path.
    serve cert key lsocket = do
        (socket, _addr) <- accept lsocket -- FIXME exception safety
        _ <- forkIO $ handle socket `finally` sClose socket
        return ()
      where
        handle socket = do
            gen <- newGenIO
            ctx <- TLS.serverWith
                params
                (gen :: SystemRandom)
                socket
                (return ()) -- flush: sends below go straight to the socket
                (\bs -> yield bs $$ sinkSocket socket)
                (recvExact socket)

            TLS.handshake ctx

            -- Decrypted input exposed as a source; output encrypted by the sink.
            let src = lift (TLS.recvData ctx) >>= yield >> src
                sink = CL.mapM_ $ TLS.sendData ctx . L.fromChunks . return

            app src sink

        params = TLS.defaultParams
            { TLS.pWantClientCert = False
              -- NOTE(review): SSL 3.0 is deprecated (RFC 7568); consider
              -- dropping TLS.SSL3 once legacy clients need not be served.
            , TLS.pAllowedVersions = [TLS.SSL3,TLS.TLS10,TLS.TLS11,TLS.TLS12]
            , TLS.pCiphers         = ciphers
            , TLS.pCertificates    = [(cert, Just key)]
            }

    -- Read exactly @n0@ bytes from the socket, or fewer only when the peer
    -- closes the connection. The TLS layer asks for whole record headers and
    -- bodies, but a single 'recv' may return a short read.
    recvExact sock n0 = S.concat <$> go n0
      where
        go n
            | n <= 0 = return []
            | otherwise = do
                chunk <- recv sock n
                if S.null chunk
                    then return []
                    else (chunk :) <$> go (n - S.length chunk)

-- | Cipher suites offered by the TLS listener in 'runTCPServerSsl'.
--
-- Taken from the stunnel example in tls-extra.
--
-- NOTE(review): RFC 7465 prohibits RC4 cipher suites in TLS; consider
-- removing the two RC4 entries once no supported client depends on them.
ciphers :: [TLS.Cipher]
ciphers =
    [ TLSExtra.cipher_AES128_SHA1
    , TLSExtra.cipher_AES256_SHA1
    , TLSExtra.cipher_RC4_128_MD5
    , TLSExtra.cipher_RC4_128_SHA1
    ]

-- | Load the first usable certificate from a PEM file.
--
-- Only PEM sections named @CERTIFICATE@ or @TRUSTED CERTIFICATE@ are
-- considered. Sections that fail to decode as X.509 are skipped.
-- Calls 'error' if the file is not valid PEM or if no section decodes.
readCertificate :: FilePath -> IO X509.X509
readCertificate filepath = do
    contents <- readFile filepath
    let decoded = rights $ parseCerts $ PEM.pemParseBS contents
    case decoded of
        x:_ -> return x
        []  -> error "no valid certificate found"
  where
    isCertPem pem = PEM.pemName pem `elem` ["CERTIFICATE", "TRUSTED CERTIFICATE"]
    decodePem = X509.decodeCertificate . L.fromChunks . (:[]) . PEM.pemContent
    parseCerts (Left err)   = error $ "cannot parse PEM file: " ++ err
    parseCerts (Right pems) = [decodePem pem | pem <- pems, isCertPem pem]

-- | Load the first usable RSA private key from a PEM file.
--
-- Only PEM sections named @RSA PRIVATE KEY@ are considered; other sections
-- in the file are ignored. Sections that fail to decode are skipped.
-- Calls 'error' if the file is not valid PEM or if no key decodes.
readPrivateKey :: FilePath -> IO TLS.PrivateKey
readPrivateKey filepath = do
    contents <- readFile filepath
    case rights (parseKey (PEM.pemParseBS contents)) of
        key:_ -> return key
        []    -> error "no valid RSA key found"
  where
    isRsaKey = (== "RSA PRIVATE KEY") . PEM.pemName
    decodeKey = fmap (TLS.PrivRSA . snd) . KeyRSA.decodePrivate . L.fromChunks . (:[]) . PEM.pemContent
    parseKey (Left err)   = error $ "Cannot parse PEM file: " ++ err
    parseKey (Right pems) = map decodeKey (filter isRsaKey pems)