module Network.Wai.Handler.WarpTLS (
TLSSettings
, certFile
, keyFile
, onInsecure
, tlsLogging
, tlsAllowedVersions
, tlsCiphers
, defaultTlsSettings
, tlsSettings
, OnInsecure (..)
, runTLS
, runTLSSocket
, WarpTLSException (..)
) where
import qualified Network.TLS as TLS
import Network.Wai.Handler.Warp
import Network.Wai (Application)
import Network.Socket (Socket, sClose, withSocketsDo)
import qualified Data.ByteString.Lazy as L
import Control.Exception (bracket, finally, handle, onException)
import qualified Network.TLS.Extra as TLSExtra
import qualified Data.ByteString as B
import Data.Streaming.Network (bindPortTCP, acceptSafe, safeRecv)
import Control.Applicative ((<$>))
import qualified Data.IORef as I
import Control.Exception (Exception, throwIO)
import Data.Typeable (Typeable)
import Data.Default.Class (def)
import qualified Crypto.Random.AESCtr
import Network.Wai.Handler.Warp.Buffer (allocateBuffer, bufferSize, freeBuffer)
import Network.Socket.ByteString (sendAll)
import Control.Monad (unless)
import Data.ByteString.Lazy.Internal (defaultChunkSize)
import qualified System.IO as IO
-- | Configuration for serving over TLS with 'runTLS' / 'runTLSSocket'.
--
-- The constructor is not exported, so values are obtained from
-- 'defaultTlsSettings' or 'tlsSettings' and then adjusted with
-- record-update syntax.
data TLSSettings = TLSSettings
    { certFile           :: FilePath
      -- ^ Certificate file, loaded with 'TLS.credentialLoadX509'.
    , keyFile            :: FilePath
      -- ^ Private-key file matching 'certFile'.
    , onInsecure         :: OnInsecure
      -- ^ What to do with a client that does not open with a TLS handshake.
    , tlsLogging         :: TLS.Logging
      -- ^ Logging hooks installed on every TLS context.
    , tlsAllowedVersions :: [TLS.Version]
      -- ^ Protocol versions the server is willing to negotiate.
    , tlsCiphers         :: [TLS.Cipher]
      -- ^ Cipher suites the server is willing to negotiate.
    }
-- | Policy for a connection whose first received bytes do not begin with a
-- TLS handshake record (first byte @0x16@).
data OnInsecure
    = DenyInsecure L.ByteString
      -- ^ Send the given text to the client as a plain-HTTP reply, close the
      -- socket and abort with 'InsecureConnectionDenied'.
    | AllowInsecure
      -- ^ Serve the connection as ordinary, unencrypted HTTP.
-- | 'defaultTlsSettings' with only the certificate and key paths replaced.
tlsSettings :: FilePath -- ^ Certificate file
            -> FilePath -- ^ Private-key file
            -> TLSSettings
tlsSettings certPath keyPath =
    defaultTlsSettings { certFile = certPath, keyFile = keyPath }
-- | Default TLS settings:
--
-- * certificate @certificate.pem@ and key @key.pem@ (relative paths);
-- * plain-HTTP clients are refused with a short explanatory message;
-- * logging hooks are 'def';
-- * TLS 1.0, 1.1 and 1.2 are negotiable;
-- * the cipher suites in 'ciphers' are offered.
--
-- SSL 3.0 is intentionally not in the default version list: RFC 7568
-- deprecates it (POODLE). A deployment that must still accept SSL 3.0-only
-- clients can add 'TLS.SSL3' to 'tlsAllowedVersions' explicitly.
defaultTlsSettings :: TLSSettings
defaultTlsSettings = TLSSettings
    { certFile = "certificate.pem"
    , keyFile = "key.pem"
    , onInsecure = DenyInsecure "This server only accepts secure HTTPS connections."
    , tlsLogging = def
    , tlsAllowedVersions = [TLS.TLS10, TLS.TLS11, TLS.TLS12]
    , tlsCiphers = ciphers
    }
-- | Serve an 'Application' over TLS on an already-listening 'Socket'.
--
-- The certificate and key named in the 'TLSSettings' are loaded once, before
-- any connection is accepted; a load failure is raised with 'error'.
--
-- Each accepted connection is classified by its first received byte:
-- @0x16@ (a TLS handshake record) selects the TLS path, anything else
-- (including an empty first read) is handled according to 'onInsecure'.
--
-- If building a connection fails (TLS handshake error, denied insecure
-- client, buffer allocation failure, ...) the accepted socket is closed
-- exactly once before the exception propagates.
-- NOTE(review): this assumes Warp itself does not close the socket when
-- the connection maker throws (it never receives a 'Connection' to close) —
-- confirm against 'runSettingsConnectionMakerSecure'.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket TLSSettings {..} set sock app = do
    credential <- either error id <$> TLS.credentialLoadX509 certFile keyFile
    let params = def
            { TLS.serverWantClientCert = False
            , TLS.serverSupported = def
                { TLS.supportedVersions = tlsAllowedVersions
                , TLS.supportedCiphers = tlsCiphers
                }
            , TLS.serverShared = def
                { TLS.sharedCredentials = TLS.Credentials [credential]
                }
            }
    runSettingsConnectionMakerSecure set (getter params) app
  where
    -- Accept one client and return, together with its address, the action
    -- that turns the raw socket into a Warp 'Connection'. The Bool in the
    -- action's result is True for the TLS path, False for plain HTTP.
    getter params = do
        (s, sa) <- acceptSafe sock
        let mkConn :: IO (Connection, Bool)
            mkConn = do
                -- The first chunk is read eagerly to decide TLS vs. plain HTTP;
                -- it is kept in cachedRef so neither path loses those bytes.
                firstBS <- safeRecv s 4096
                cachedRef <- I.newIORef firstBS
                -- Return up to 'size' bytes, serving cached bytes first and
                -- reading more from the socket as needed. Returns fewer than
                -- 'size' bytes only when the peer stops sending.
                let getNext size = do
                        cached <- I.readIORef cachedRef
                        loop cached
                      where
                        loop bs | B.length bs >= size = do
                            let (x, y) = B.splitAt size bs
                            I.writeIORef cachedRef y
                            return x
                        loop bs1 = do
                            bs2 <- safeRecv s 4096
                            if B.null bs2
                                then do
                                    I.writeIORef cachedRef B.empty
                                    return bs1
                                else loop $ B.append bs1 bs2
                if not (B.null firstBS) && B.head firstBS == 0x16
                    then do
                        gen <- Crypto.Random.AESCtr.makeSystem
                        ctx <- TLS.contextNew
                            TLS.Backend
                                { TLS.backendFlush = return ()
                                , TLS.backendClose = sClose s
                                , TLS.backendSend = sendAll s
                                , TLS.backendRecv = getNext
                                }
                            params
                            gen
                        TLS.contextHookSetLogging ctx tlsLogging
                        TLS.handshake ctx
                        readBuf <- allocateBuffer bufferSize
                        writeBuf <- allocateBuffer bufferSize
                        let conn = Connection
                                { connSendMany = TLS.sendData ctx . L.fromChunks
                                , connSendAll = TLS.sendData ctx . L.fromChunks . return
                                , connSendFile = \fp offset len _th headers -> do
                                    TLS.sendData ctx $ L.fromChunks headers
                                    IO.withBinaryFile fp IO.ReadMode $ \h -> do
                                        IO.hSeek h IO.AbsoluteSeek offset
                                        -- Send exactly 'len' bytes starting at
                                        -- 'offset', or stop early at EOF.
                                        let loop remaining | remaining <= 0 = return ()
                                            loop remaining = do
                                                bs <- B.hGetSome h defaultChunkSize
                                                unless (B.null bs) $ do
                                                    let x = B.take remaining bs
                                                    TLS.sendData ctx $ L.fromChunks [x]
                                                    loop $ remaining - B.length x
                                        loop $ fromIntegral len
                                , connClose =
                                    freeBuffer readBuf `finally`
                                    freeBuffer writeBuf `finally`
                                    TLS.bye ctx `finally`
                                    TLS.contextClose ctx
                                , connRecv =
                                    -- EOF from the TLS layer is reported to Warp
                                    -- as an empty chunk; empty reads are retried.
                                    let onEOF TLS.Error_EOF = return B.empty
                                        onEOF e = throwIO e
                                        go = do
                                            x <- TLS.recvData ctx
                                            if B.null x
                                                then go
                                                else return x
                                    in handle onEOF go
                                , connSendFileOverride = NotOverride
                                , connReadBuffer = readBuf
                                , connWriteBuffer = writeBuf
                                , connBufferSize = bufferSize
                                }
                        return (conn, True)
                    else
                        case onInsecure of
                            AllowInsecure -> do
                                conn' <- socketConnection s
                                -- Reads go through getNext so the bytes
                                -- already peeked are delivered first.
                                return (conn' { connRecv = getNext 4096 }, False)
                            DenyInsecure lbs -> do
                                -- 426 tells the client that a protocol switch
                                -- (to TLS) is required; RFC 7231 section 6.5.15
                                -- requires the Upgrade header in this response.
                                sendAll s "HTTP/1.1 426 Upgrade Required\r\n\
                                          \Upgrade: TLS/1.0, HTTP/1.1\r\n\
                                          \Connection: Upgrade\r\n\
                                          \Content-Type: text/plain\r\n\r\n"
                                mapM_ (sendAll s) $ L.toChunks lbs
                                -- The socket is not closed here: the throw below
                                -- triggers the onException handler installed on
                                -- mkConn, which closes it. Closing it here as
                                -- well would close the descriptor twice.
                                throwIO InsecureConnectionDenied
        -- Any exception while building the connection closes the socket.
        return (mkConn `onException` sClose s, sa)
-- | Exceptions raised while setting up connections in this module.
data WarpTLSException
    = InsecureConnectionDenied
      -- ^ A non-TLS client connected while 'onInsecure' was 'DenyInsecure'.
    deriving (Show, Typeable)

instance Exception WarpTLSException
-- | Bind a TCP listening socket on the host and port taken from 'Settings',
-- serve the 'Application' over TLS on it with 'runTLSSocket', and close the
-- listening socket when serving ends, whether normally or by exception.
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app = withSocketsDo $ bracket acquire sClose serve
  where
    acquire = bindPortTCP (getPort set) (getHost set)
    serve listener = runTLSSocket tset set listener app
-- | Cipher suites offered by 'defaultTlsSettings'.
--
-- The RC4 suites (@RC4_128_MD5@, @RC4_128_SHA1@) are intentionally absent:
-- RFC 7465 prohibits RC4 in TLS. A deployment that must still serve clients
-- without AES support can add them to 'tlsCiphers' itself.
ciphers :: [TLS.Cipher]
ciphers =
    [ TLSExtra.cipher_AES128_SHA1
    , TLSExtra.cipher_AES256_SHA1
    ]