{-# LANGUAGE BangPatterns, ScopedTypeVariables #-}
{- |
Timeout-aware network I/O over an OpenSSL TLS session ('SSL').

Borrowed from snap-server; check there periodically for updates.
-}
module Happstack.Server.Internal.TimeoutSocketTLS where

import           Control.Exception             (SomeException, catch)
import qualified Data.ByteString.Char8         as B
import qualified Data.ByteString.Lazy.Char8    as L
import qualified Data.ByteString.Lazy.Internal as L
import qualified Data.ByteString               as S
import qualified Happstack.Server.Internal.TimeoutManager as TM
import           Happstack.Server.Internal.TimeoutIO (TimeoutIO(..))
import           Network.Socket                (Socket, close)
import           Network.Socket.SendFile (ByteCount, Offset)
import           OpenSSL.Session               (SSL)
import qualified OpenSSL.Session               as SSL
import           Prelude                       hiding (catch)
import           System.IO (IOMode(ReadMode), SeekMode(AbsoluteSeek), hSeek, withBinaryFile)
import           System.IO.Unsafe (unsafeInterleaveIO)

-- | Write a lazy 'L.ByteString' to the TLS connection one strict chunk at a
-- time. After each chunk has been written, 'TM.tickle' is called on the
-- timeout handle (presumably so an actively progressing transfer does not
-- time out — confirm against "Happstack.Server.Internal.TimeoutManager").
sPutLazyTickle :: TM.Handle      -- ^ timeout handle to tickle
               -> SSL            -- ^ connected TLS session
               -> L.ByteString   -- ^ data to send
               -> IO ()
sPutLazyTickle thandle ssl cs =
    L.foldrChunks writeChunk (return ()) cs
    where
      -- Write one chunk, tickle the timeout, then continue with the rest.
      writeChunk c rest = SSL.write ssl c >> TM.tickle thandle >> rest
{-# INLINE sPutLazyTickle #-}

-- | Write a strict 'B.ByteString' to the TLS connection, then call
-- 'TM.tickle' on the timeout handle once the write has returned.
sPutTickle :: TM.Handle      -- ^ timeout handle to tickle
           -> SSL            -- ^ connected TLS session
           -> B.ByteString   -- ^ data to send
           -> IO ()
sPutTickle thandle ssl cs =
    do SSL.write ssl cs
       TM.tickle thandle
{-# INLINE sPutTickle #-}

-- | Lazily read everything from the TLS connection as a lazy 'L.ByteString'.
--
-- Each strict chunk (at most 65536 bytes) is fetched with 'SSL.read' only
-- when the consumer forces that part of the result ('unsafeInterleaveIO'),
-- and 'TM.tickle' is called after every read. The stream ends at the first
-- empty chunk returned by 'SSL.read'.
--
-- Because reads are deferred, an exception thrown by 'SSL.read' surfaces
-- when the result is forced, not when 'sGetContents' returns.
sGetContents :: TM.Handle        -- ^ timeout handle, tickled after each read
             -> SSL              -- ^ Connected TLS session
             -> IO L.ByteString  -- ^ Data received
sGetContents handle ssl =
    L.fromChunks <$> loop
    where
      -- Maximum number of bytes requested from 'SSL.read' per chunk.
      chunkSize :: Int
      chunkSize = 65536

      -- Produce the remaining chunks on demand; an empty read ends the list.
      loop :: IO [S.ByteString]
      loop = unsafeInterleaveIO $ do
        s <- SSL.read ssl chunkSize
        TM.tickle handle
        if S.null s
          then return []
          else (s :) <$> loop

-- | Bundle the timeout-tickling TLS I/O operations for one connection into a
-- 'TimeoutIO' record. Every read and write operation in the record calls
-- 'TM.tickle' on the supplied timeout handle, and 'toSecure' is 'True'.
timeoutSocketIO :: TM.Handle   -- ^ timeout handle for this connection
                -> Socket      -- ^ underlying socket, closed on shutdown
                -> SSL         -- ^ TLS session running over the socket
                -> TimeoutIO
timeoutSocketIO handle socket ssl =
    TimeoutIO { toHandle      = handle
              , toShutdown    = shutdownTLS
              , toPutLazy     = sPutLazyTickle handle ssl
              , toPut         = sPutTickle     handle ssl
              , toGet         = sGet
              , toGetContents = sGetContents   handle ssl
              , toSendFile    = sendFileTickle handle ssl
              , toSecure      = True
              }
    where
      -- Best-effort close: attempt a unidirectional TLS shutdown, then close
      -- the socket. Each step's exceptions are ignored independently, so a
      -- failed TLS shutdown (e.g. peer already gone) still closes the socket.
      shutdownTLS :: IO ()
      shutdownTLS = do
        SSL.shutdown ssl SSL.Unidirectional `catch` ignoreException
        close socket `catch` ignoreException

      -- Deliberately swallows every exception during shutdown.
      -- NOTE(review): 'SomeException' also matches asynchronous exceptions
      -- (e.g. a thread being killed); confirm that is acceptable here.
      ignoreException :: SomeException -> IO ()
      ignoreException _ = return ()

      -- Read one chunk of at most 65536 bytes and tickle the timeout; an
      -- empty read (end of stream) is reported as 'Nothing'. This fills the
      -- record's 'toGet' field, which would otherwise be a missing-field
      -- error when forced.
      sGet :: IO (Maybe B.ByteString)
      sGet = do
        s <- SSL.read ssl 65536
        TM.tickle handle
        return (if S.null s then Nothing else Just s)

-- | Send up to @count@ bytes of the file at @fp@, starting at byte @offset@,
-- over the TLS connection, tickling the timeout handle after each chunk.
-- If fewer than @count@ bytes remain after @offset@, only those are sent.
--
-- The file is read with lazy I/O ('L.hGetContents'), but 'sPutLazyTickle'
-- consumes the data completely before 'withBinaryFile' closes the handle,
-- so no read is deferred past the handle's lifetime.
sendFileTickle :: TM.Handle   -- ^ timeout handle to tickle
               -> SSL         -- ^ connected TLS session
               -> FilePath    -- ^ file to send
               -> Offset      -- ^ byte offset to start reading from
               -> ByteCount   -- ^ maximum number of bytes to send
               -> IO ()
sendFileTickle thandle ssl fp offset count =
    withBinaryFile fp ReadMode $ \h -> do
      hSeek h AbsoluteSeek offset
      contents <- L.hGetContents h
      sPutLazyTickle thandle ssl (L.take (fromIntegral count) contents)