{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Snap.Internal.Http.Server.Socket
  ( bindSocket
  , bindSocketImpl
  , bindUnixSocket
  , httpAcceptFunc
  , haProxyAcceptFunc
  , sendFileFunc
  , acceptAndInitialize
  ) where

------------------------------------------------------------------------------
import           Control.Exception                 (bracketOnError, finally, throwIO)
import           Control.Monad                     (when)
import           Data.Bits                         (complement, (.&.))
import           Data.ByteString.Char8             (ByteString)
import           Network.Socket                    (Socket, SocketOption (NoDelay, ReuseAddr), accept, close, getSocketName, setSocketOption, socket)
import qualified Network.Socket                    as N
#ifdef HAS_SENDFILE
import           Network.Socket                    (fdSocket)
import           System.Posix.IO                   (OpenMode (..), closeFd, defaultFileFlags, openFd)
import           System.Posix.Types                (Fd (..))
import           System.SendFile                   (sendFile, sendHeaders)
#else
import           Data.ByteString.Builder           (byteString)
import           Data.ByteString.Builder.Extra     (flush)
import           Network.Socket.ByteString         (sendAll)
#endif
#ifdef HAS_UNIX_SOCKETS
import           Control.Exception                 (bracket)
import qualified Control.Exception                 as E (catch)
import           System.FilePath                   (isRelative)
import           System.IO.Error                   (isDoesNotExistError)
import           System.Posix.Files                (accessModes, removeLink, setFileCreationMask)
#endif

------------------------------------------------------------------------------
import qualified System.IO.Streams                 as Streams
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Address (AddressNotSupportedException (..), getAddress, getSockAddr)
import           Snap.Internal.Http.Server.Types   (AcceptFunc (..), SendFileHandler)
import qualified System.IO.Streams.Network.HAProxy as HA


------------------------------------------------------------------------------
-- | Create a listening TCP socket for the given address and port.
--
-- This is 'bindSocketImpl' specialised to the real "Network.Socket"
-- primitives ('setSocketOption', bind, 'N.listen').
bindSocket :: ByteString -> Int -> IO Socket
bindSocket = bindSocketImpl setSocketOption bind N.listen
  where
    bind :: Socket -> N.SockAddr -> IO ()
#if MIN_VERSION_network(2,7,0)
    bind = N.bind
#else
    -- network < 2.7 exposes the same operation under its older name.
    bind = N.bindSocket
#endif
{-# INLINE bindSocket #-}


------------------------------------------------------------------------------
-- | Implementation of 'bindSocket', parameterised over the socket primitives
-- so they can be replaced (e.g. by mocks).
--
-- Resolves the address with 'getSockAddr', opens a stream socket of the
-- resulting family, enables 'ReuseAddr' and 'NoDelay', binds it and starts
-- listening. If any step after the socket is opened throws, the socket is
-- closed before the exception propagates.
bindSocketImpl
    :: (Socket -> SocketOption -> Int -> IO ()) -- ^ setSocketOption (or a mock)
    -> (Socket -> N.SockAddr -> IO ())          -- ^ bind (or a mock)
    -> (Socket -> Int -> IO ())                 -- ^ listen (or a mock)
    -> ByteString                               -- ^ address to bind to
    -> Int                                      -- ^ port to bind to
    -> IO Socket
bindSocketImpl _setSocketOption _bindSocket _listen bindAddr bindPort = do
    (family, addr) <- getSockAddr bindPort bindAddr
    bracketOnError (socket family N.Stream 0) N.close $ \sock -> do
        _setSocketOption sock ReuseAddr 1
        _setSocketOption sock NoDelay 1
        _bindSocket sock addr
        _listen sock listenBacklog
        return $! sock
  where
    -- Backlog length handed to the listen primitive.
    listenBacklog :: Int
    listenBacklog = 150

-- | Create a listening Unix-domain socket at the given filesystem path.
--
-- A relative path is rejected with 'AddressNotSupportedException'. Any
-- existing file at the path is removed first; a missing file is not an
-- error, but any other removal failure is rethrown. When a mode is given,
-- the file-creation mask is narrowed for the duration of the bind so the
-- socket file gets at most those access bits, then restored.
--
-- When built without @HAS_UNIX_SOCKETS@ this always throws
-- 'AddressNotSupportedException'.
bindUnixSocket :: Maybe Int -> String -> IO Socket
#ifdef HAS_UNIX_SOCKETS
bindUnixSocket mode path = do
    when (isRelative path) $
        throwIO $ AddressNotSupportedException
                $! "Refusing to bind unix socket to non-absolute path: " ++ path

    bracketOnError (socket N.AF_UNIX N.Stream 0) N.close $ \sock -> do
        -- Clear whatever is at the path before binding; only a
        -- "does not exist" failure is tolerated.
        E.catch (removeLink path) $ \e ->
            when (not $ isDoesNotExistError e) $ throwIO e
        case mode of
            Nothing    -> bind sock addr
            -- NOTE: the file-creation mask is process-wide, so files created
            -- by other threads during this bind also see the narrowed mask.
            Just mode' -> bracket (setFileCreationMask $ modeToMask mode')
                                  setFileCreationMask
                                  (const $ bind sock addr)
        N.listen sock 150
        return $! sock
  where
    addr = N.SockAddrUnix path

    bind :: Socket -> N.SockAddr -> IO ()
#if MIN_VERSION_network(2,7,0)
    bind = N.bind
#else
    bind = N.bindSocket
#endif

    -- Mask that clears every access bit not present in the requested mode.
    modeToMask p = accessModes .&. complement (fromIntegral p)
#else
bindUnixSocket _ path = throwIO (AddressNotSupportedException $ "unix:" ++ path)
#endif

------------------------------------------------------------------------------
-- TODO(greg): move buffer size configuration into config
-- | Buffer size passed to 'Streams.socketToStreamsWithBufferSize' when an
-- accepted socket is wrapped in input/output streams.
bUFSIZ :: Int
bUFSIZ = 4064


------------------------------------------------------------------------------
-- | Accept one connection on a listening socket and pass it to a callback.
--
-- The 'accept' call is wrapped in @restore@ (presumably the unmasking
-- function obtained from 'Control.Exception.mask' — confirm at call sites).
-- If the callback throws, the accepted socket is closed before the exception
-- propagates. On success the socket stays open, and the callback's result
-- is responsible for closing it.
acceptAndInitialize :: Socket                          -- ^ bound socket
                    -> (forall b . IO b -> IO b)       -- ^ restore function
                    -> ((Socket, N.SockAddr) -> IO a)  -- ^ per-connection setup
                    -> IO a
acceptAndInitialize boundSocket restore f =
    bracketOnError (restore (accept boundSocket))
                   (close . fst)
                   f


------------------------------------------------------------------------------
-- | An 'AcceptFunc' for connections that arrive through a proxy sending
-- HAProxy proxy-protocol headers.
--
-- After accepting, the proxy header is decoded from the connection's input
-- stream. The reported local and remote addresses are the header's
-- destination and source addresses, not the socket's own.
--
-- Returns (sendfile handler, local host, local port, remote host,
-- remote port, input stream, output stream, cleanup). The cleanup action
-- writes end-of-stream to the output and then always closes the socket.
haProxyAcceptFunc :: Socket     -- ^ bound socket
                  -> AcceptFunc
haProxyAcceptFunc boundSocket =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, saddr) -> do
        (readEnd, writeEnd) <- Streams.socketToStreamsWithBufferSize bUFSIZ sock
        -- Start from the socket's own address info, then decode the proxy
        -- header from the front of the input stream.
        localPInfo <- HA.socketToProxyInfo sock saddr
        pinfo      <- HA.decodeHAProxyHeaders localPInfo readEnd
        (localPort, localHost)   <- getAddress $ HA.getDestAddr pinfo
        (remotePort, remoteHost) <- getAddress $ HA.getSourceAddr pinfo
        let cleanup :: IO ()
            cleanup = Streams.write Nothing writeEnd `finally` close sock
        return $! ( sendFileFunc sock
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
-- | An 'AcceptFunc' for plain HTTP connections.
--
-- The local address comes from 'getSocketName' on the accepted socket. The
-- remote address is the peer address returned by the accept.
--
-- Returns (sendfile handler, local host, local port, remote host,
-- remote port, input stream, output stream, cleanup). The cleanup action
-- writes end-of-stream to the output and then always closes the socket.
httpAcceptFunc :: Socket                     -- ^ bound socket
               -> AcceptFunc
httpAcceptFunc boundSocket =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
        localAddr                <- getSocketName sock
        (localPort, localHost)   <- getAddress localAddr
        (remotePort, remoteHost) <- getAddress remoteAddr
        (readEnd, writeEnd) <- Streams.socketToStreamsWithBufferSize bUFSIZ sock
        let cleanup :: IO ()
            cleanup = Streams.write Nothing writeEnd `finally` close sock
        return $! ( sendFileFunc sock
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
-- | Build a 'SendFileHandler' that writes to the given socket.
--
-- The handler sends the headers 'Builder' followed by @nbytes@ bytes of the
-- file at @fPath@, starting at byte @offset@.
--
-- * With @HAS_SENDFILE@, headers and file body go straight to the socket's
--   file descriptor via "System.SendFile". The buffer argument is only
--   forced, never used.
-- * Otherwise the file is streamed through an io-streams builder pipeline
--   that renders into the supplied buffer, and is flushed at the end.
sendFileFunc :: Socket -> SendFileHandler
#ifdef HAS_SENDFILE
-- NOTE(review): 'bracket' is imported only under HAS_UNIX_SOCKETS, so this
-- branch assumes that flag is defined whenever HAS_SENDFILE is.
sendFileFunc sock !_ builder fPath offset nbytes = bracket acquire closeFd go
  where
    acquire :: IO Fd
#if MIN_VERSION_unix(2,8,0)
    -- unix >= 2.8: 'openFd' no longer takes a 'Maybe FileMode' argument.
    acquire = openFd fPath ReadOnly defaultFileFlags
#else
    acquire = openFd fPath ReadOnly Nothing defaultFileFlags
#endif

    go :: Fd -> IO ()
    go fileFd = do
        sockFd <- socketFd
        sendHeaders builder sockFd
        sendFile sockFd fileFd offset nbytes

    -- Raw descriptor of the socket. Obtaining it differs across network
    -- versions (an IO action from 3.0 on, a pure accessor before).
    socketFd :: IO Fd
#if MIN_VERSION_network(3,0,0)
    socketFd = Fd `fmap` fdSocket sock
#else
    socketFd = return (Fd (fdSocket sock))
#endif

#else
sendFileFunc sock buffer builder fPath offset nbytes =
    Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $
        \fileInput0 -> do
            -- Limit the file to the requested byte count and lift each chunk
            -- into a Builder so it can follow the headers.
            fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
                         Streams.map byteString
            input     <- Streams.fromList [builder] >>=
                         flip Streams.appendInputStream fileInput
            -- Render builders into the supplied buffer and send each
            -- resulting chunk on the socket.
            output    <- Streams.makeOutputStream sendChunk >>=
                         Streams.unsafeBuilderStream (return buffer)
            Streams.supply input output
            Streams.write (Just flush) output
  where
    sendChunk (Just s) = sendAll sock s
    sendChunk Nothing  = return $! ()
#endif