module Data.Acid.Remote.TLS
(
acidServerTLS
, openRemoteStateTLS
, skipAuthenticationCheck
, skipAuthenticationPerform
, sharedSecretCheck
, sharedSecretPerform
) where
import Control.Concurrent ( forkIO, threadDelay )
import Control.Exception ( Handler(..), IOException, SomeException, catch, catches, handle
, finally, throwIO )
import Control.Monad ( forever, when )
import Data.Acid ( AcidState, IsAcidic )
import Data.Acid.Remote ( CommChannel(..), process, processRemoteState, skipAuthenticationCheck
, skipAuthenticationPerform, sharedSecretCheck, sharedSecretPerform )
import Data.SafeCopy ( SafeCopy )
import GHC.IO.Exception ( IOErrorType(..) )
import OpenSSL ( withOpenSSL )
import OpenSSL.Session ( SomeSSLException, SSL, SSLContext )
import qualified OpenSSL.Session as SSL
import Network ( HostName, PortID(..), Socket, listenOn, sClose, withSocketsDo )
import Network.Socket as Socket ( Family(..), SockAddr(..), SocketType(..), accept, socket, connect )
import Network.BSD ( getHostByName, getProtocolNumber, getServicePortNumber, hostAddress )
import System.Directory ( removeFile )
import System.IO.Error ( ioeGetErrorType, isFullError, isDoesNotExistError )
-- | Emit a diagnostic line on standard output.
--
-- The line is always printed; nothing here gates it behind a debug flag.
debugStrLn :: String -> IO ()
debugStrLn = putStrLn
-- | Create a TLS context from a certificate file and a private-key file,
-- configured with the library's default cipher list.
--
-- The key is loaded before the certificate, then the pair is checked with
-- 'SSL.contextCheckPrivateKey'; a mismatch is reported by calling 'error'
-- (i.e. an 'ErrorCall' exception is thrown when the action runs).
initSSLContext :: FilePath   -- ^ certificate file
               -> FilePath   -- ^ private-key file
               -> IO SSLContext
initSSLContext cert key =
  do context <- SSL.context
     SSL.contextSetPrivateKeyFile context key
     SSL.contextSetCertificateFile context cert
     SSL.contextSetDefaultCiphers context
     pairMatches <- SSL.contextCheckPrivateKey context
     if pairMatches
       then return context
       else error "OpenTLS certificate and key do not match."
-- | Accept one client on a listening socket and run the server side of the
-- TLS handshake on it.
--
-- Returns the client socket, its TLS session and the peer address. If
-- creating the session or the handshake throws anything, the freshly
-- accepted client socket is closed and the exception is rethrown.
acceptTLS :: SSLContext -> Socket -> IO (Socket, SSL, SockAddr)
acceptTLS ctx listener =
  do (client, peer) <- accept listener
     let handshake :: IO (Socket, SSL, SockAddr)
         handshake = do session <- SSL.connection ctx client
                        SSL.accept session
                        return (client, session, peer)
         closeAndRethrow :: SomeException -> IO (Socket, SSL, SockAddr)
         closeAndRethrow err = sClose client >> throwIO err
     handshake `catch` closeAndRethrow
-- | Serve an 'AcidState' to remote clients over TLS.
--
-- Each accepted client goes through the TLS handshake ('acceptTLS'). The
-- connection is then handed to a forked thread, which runs @checkAuth@ and,
-- if that returns 'True', serves requests with 'process'. The connection is
-- shut down and closed afterwards, or when that thread fails; in the failure
-- case the exception is rethrown in the forked thread.
--
-- The TLS context (certificate, key and their consistency check) is built
-- /before/ the listening socket is opened. Otherwise a bad certificate or key
-- would abort after 'listenOn' but before the cleanup handler is installed,
-- leaking the listening socket and, for a 'UnixSocket' port, its socket file.
--
-- On exit the listening socket is closed and, for a 'UnixSocket' port on
-- non-Windows platforms, the socket file is removed.
--
-- NOTE(review): the TLS handshake ('SSL.accept' inside 'acceptTLS') runs on
-- the accept-loop thread before the per-connection thread is forked, so a
-- client that stalls mid-handshake delays accepting further clients.
acidServerTLS :: SafeCopy st =>
                 FilePath                  -- ^ server certificate file
              -> FilePath                  -- ^ server private-key file
              -> (CommChannel -> IO Bool)  -- ^ per-connection authorization check
              -> PortID                    -- ^ where to listen
              -> AcidState st              -- ^ state to serve
              -> IO ()
acidServerTLS sslCert sslKey checkAuth port acidState
  = withSocketsDo $
    do withOpenSSL $ return ()
       -- Set up TLS first: it can fail, and nothing would close an
       -- already-open listening socket at this point.
       debugStrLn $ "acidServerTLS: initSSLContext"
       ctx <- initSSLContext sslCert sslKey
       debugStrLn $ "acidServerTLS: listenOn " ++ show port
       tlsSocket <- listenOn port
       let worker :: (Socket, SSL, SockAddr) -> IO ()
           worker (sck, ssl, _peer) =
             do let channel :: CommChannel
                    channel = CommChannel
                      { ccPut     = SSL.write ssl
                      , ccGetSome = SSL.read ssl
                      , ccClose   = shutdownClose sck ssl
                      }
                    serve :: IO ()
                    serve = do authorized <- checkAuth channel
                               when authorized $
                                 ignoreSome $ process channel acidState
                               ccClose channel
                -- Any failure in the connection thread still releases the
                -- connection before the exception is rethrown.
                _ <- forkIO $ serve `catch` \(e :: SomeException) ->
                       do shutdownClose sck ssl
                          throwIO e
                return ()

           -- Accept forever; the exceptions 'ignoreSome' tolerates restart
           -- the accept loop instead of stopping the server.
           loop :: IO ()
           loop = do ignoreSome $ forever $ worker =<< acceptTLS ctx tlsSocket
                     loop
       loop `finally` (cleanup tlsSocket `catch` ignoreException)
  where
    -- | Close the listening socket and, for Unix-domain sockets, remove the
    -- socket file (a missing file is not an error).
    cleanup :: Socket -> IO ()
    cleanup listening
      = do debugStrLn "acidServerTLS: cleanup."
           sClose listening
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
           case port of
             (UnixSocket path) ->
               removeFile path `catch` (\e -> if isDoesNotExistError e then return () else throwIO e)
             _ -> return ()
#endif

    ignoreConnectionAbruptlyTerminated :: SSL.ConnectionAbruptlyTerminated -> IO ()
    ignoreConnectionAbruptlyTerminated _ = return ()

    ignoreSSLException :: SSL.SomeSSLException -> IO ()
    ignoreSSLException _ = return ()

    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()

    -- | Best-effort teardown of one client connection: a one-way TLS
    -- shutdown, then close the socket, ignoring failures of either step.
    shutdownClose :: Socket -> SSL -> IO ()
    shutdownClose sck ssl =
      do debugStrLn "acidServerTLS: shutdownClose."
         SSL.shutdown ssl SSL.Unidirectional `catch` ignoreException
         sClose sck `catch` ignoreException

    -- | Run an action, swallowing TLS exceptions and the I/O errors
    -- "full", "does not exist" and "resource vanished"; any other
    -- exception is rethrown.
    ignoreSome :: IO () -> IO ()
    ignoreSome op =
      op `catches` [ Handler ignoreSSLException
                   , Handler ignoreConnectionAbruptlyTerminated
                   , Handler $ \(e :: IOException) ->
                       if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                         then return ()
                         else throwIO e
                   ]

    isResourceVanishedError :: IOException -> Bool
    isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

    isResourceVanishedType :: IOErrorType -> Bool
    isResourceVanishedType ResourceVanished = True
    isResourceVanishedType _ = False
-- | Open an 'AcidState' served remotely over TLS (see 'acidServerTLS').
--
-- The connection is (re)established by @reconnect@: it connects, runs
-- @performAuthorization@ on the new channel, and on an 'IOError' waits one
-- second and tries again. Exceptions that are not 'IOError's (for example
-- TLS errors rethrown by 'connectToTLS') are not retried and propagate.
--
-- Closing a channel shuts the TLS session down (one-way) and then always
-- closes the underlying socket. A failed connection attempt, or a failed
-- authorization, releases its socket before the next retry.
openRemoteStateTLS :: IsAcidic st =>
                      (CommChannel -> IO ())  -- ^ client side of the authorization exchange
                   -> HostName
                   -> PortID
                   -> IO (AcidState st)
openRemoteStateTLS performAuthorization host port
  = do withOpenSSL $ return ()
       processRemoteState reconnect
  where
    sslCommChannel :: Socket -> SSL -> CommChannel
    sslCommChannel sock ssl =
      CommChannel { ccGetSome = SSL.read ssl
                  , ccPut     = SSL.write ssl
                    -- The socket must be closed even if the TLS shutdown
                    -- fails (e.g. the peer is already gone).
                  , ccClose   = SSL.shutdown ssl SSL.Unidirectional `finally` sClose sock
                  }

    reconnect :: IO CommChannel
    reconnect
      = (do (sock, ssl) <- connectToTLS host port
            let cc = sslCommChannel sock ssl
            -- Don't abandon an open connection when authorization fails;
            -- the retry below opens a fresh one.
            performAuthorization cc `catch` closeAndRethrow cc
            return cc
        )
        `catch`
        ((\_ -> threadDelay 1000000 >> reconnect) :: IOError -> IO CommChannel)

    closeAndRethrow :: CommChannel -> SomeException -> IO ()
    closeAndRethrow cc e =
      do ccClose cc `catch` ignoreAny
         throwIO e

    ignoreAny :: SomeException -> IO ()
    ignoreAny _ = return ()

-- | Run a connection attempt on @sock@. If it fails for any reason, close
-- @sock@ and rethrow. TLS exceptions are also printed before rethrowing.
closeOnFailure :: Socket -> IO a -> IO a
closeOnFailure sock attempt =
  attempt `catches`
    [ Handler $ \(e :: SomeSSLException) ->
        do print e
           sClose sock
           throwIO e
    , Handler $ \(e :: SomeException) ->
        do sClose sock
           throwIO e
    ]

-- | Open a client TLS connection to @hostName@ / the given port.
--
-- Returns the underlying socket together with the TLS session so the caller
-- can close both. 'Service' names are resolved to a port number first. On
-- any failure after the socket is created, the socket is closed and the
-- exception rethrown.
--
-- NOTE(review): a fresh 'SSL.context' is created per connection and no
-- verification settings are applied here, so the server certificate does not
-- appear to be checked — confirm this is intended.
connectToTLS :: HostName
             -> PortID
             -> IO (Socket, SSL)
connectToTLS hostName (Service serv)
  = do port <- getServicePortNumber serv
       connectToTLS hostName (PortNumber port)
connectToTLS hostName (PortNumber port)
  = do proto <- getProtocolNumber "tcp"
       sock <- socket AF_INET Stream proto
       closeOnFailure sock $
         do he <- getHostByName hostName
            Socket.connect sock (SockAddrInet port (hostAddress he))
            ctx <- SSL.context
            ssl <- SSL.connection ctx sock
            SSL.connect ssl
            return (sock, ssl)
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
connectToTLS _hostName p@(UnixSocket path)
  = do debugStrLn $ "connectToTLS: " ++ show p
       sock <- socket AF_UNIX Stream 0
       closeOnFailure sock $
         do debugStrLn $ "connectToTLS: connect."
            Socket.connect sock (SockAddrUnix path)
            ctx <- SSL.context
            ssl <- SSL.connection ctx sock
            debugStrLn $ "connectToTLS: connect ssl."
            SSL.connect ssl
            debugStrLn $ "connectToTLS: done."
            return (sock, ssl)
#endif