{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Snap.Internal.Http.Server.Socket
  ( bindSocket
  , bindSocketImpl
  , bindUnixSocket
  , httpAcceptFunc
  , haProxyAcceptFunc
  , sendFileFunc
  , acceptAndInitialize
  ) where

------------------------------------------------------------------------------
import           Control.Exception                 (bracketOnError, finally, throwIO)
import           Control.Monad                     (when)
import           Data.Bits                         (complement, (.&.))
import           Data.ByteString.Char8             (ByteString)
import           Network.Socket                    (Socket, SocketOption (NoDelay, ReuseAddr), accept, close, getSocketName, setSocketOption, socket)
import qualified Network.Socket                    as N
#ifdef HAS_SENDFILE
import           Network.Socket                    (fdSocket)
import           System.Posix.IO                   (OpenMode (..), closeFd, defaultFileFlags, openFd)
import           System.Posix.Types                (Fd (..))
import           System.SendFile                   (sendFile, sendHeaders)
#else
import           Data.ByteString.Builder           (byteString)
import           Data.ByteString.Builder.Extra     (flush)
import           Network.Socket.ByteString         (sendAll)
#endif
#ifdef HAS_UNIX_SOCKETS
import           Control.Exception                 (bracket)
import qualified Control.Exception                 as E (catch)
import           System.FilePath                   (isRelative)
import           System.IO.Error                   (isDoesNotExistError)
import           System.Posix.Files                (accessModes, removeLink, setFileCreationMask)
#endif

------------------------------------------------------------------------------
import qualified System.IO.Streams                 as Streams
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Address (AddressNotSupportedException (..), getAddress, getSockAddr)
import           Snap.Internal.Http.Server.Types   (AcceptFunc (..), SendFileHandler)
import qualified System.IO.Streams.Network.HAProxy as HA


------------------------------------------------------------------------------
-- | Create a listening TCP socket for the given bind address and port.
--
-- This is 'bindSocketImpl' applied to the real 'setSocketOption', bind and
-- 'N.listen' actions from "Network.Socket"; see 'bindSocketImpl' for the
-- steps performed.
bindSocket :: ByteString -> Int -> IO Socket
bindSocket = bindSocketImpl setSocketOption bind N.listen
  where
    -- The bind action is 'N.bind' on network >= 2.7.0 and 'N.bindSocket' on
    -- older versions.
#if MIN_VERSION_network(2,7,0)
    bind = N.bind
#else
    bind = N.bindSocket
#endif
{-# INLINE bindSocket #-}


------------------------------------------------------------------------------
-- | Open a stream socket for the given bind address and port, configure it,
-- bind it and start listening.
--
-- The socket-level actions are passed in as arguments so they can be
-- substituted (e.g. by mocks). The steps are, in order: resolve the address
-- with 'getSockAddr', create the socket, set 'ReuseAddr' and 'NoDelay' to 1,
-- bind, and listen with a backlog of 150.
--
-- If any step after the socket has been created throws, the socket is closed
-- ('bracketOnError') and the exception propagates. On success the open
-- socket is returned.
bindSocketImpl
    :: (Socket -> SocketOption -> Int -> IO ()) -- ^ mock setSocketOption
    -> (Socket -> N.SockAddr -> IO ())          -- ^ bindSocket
    -> (Socket -> Int -> IO ())                 -- ^ listen
    -> ByteString                               -- ^ address to bind to
    -> Int                                      -- ^ port to bind to
    -> IO Socket
bindSocketImpl _setSocketOption _bindSocket _listen bindAddr bindPort = do
    (family, addr) <- getSockAddr bindPort bindAddr
    bracketOnError (socket family N.Stream 0) N.close $ \sock -> do
        _setSocketOption sock ReuseAddr 1
        _setSocketOption sock NoDelay 1
        _bindSocket sock addr
        _listen sock 150
        return $! sock

-- | Create a listening unix-domain stream socket at the given filesystem
-- path.
--
-- The first argument is an optional file mode for the socket file. When it
-- is @Just mode@, the process file-creation mask is temporarily set to the
-- complement of @mode@ (restricted to 'accessModes') around the bind call
-- and restored afterwards; when it is 'Nothing' the mask is left untouched.
--
-- Throws 'AddressNotSupportedException' if the path is relative, or
-- unconditionally when the module is built without @HAS_UNIX_SOCKETS@.
bindUnixSocket :: Maybe Int -> String -> IO Socket
#ifdef HAS_UNIX_SOCKETS
bindUnixSocket mode path = do
   when (isRelative path) $
      throwIO $ AddressNotSupportedException
                $! "Refusing to bind unix socket to non-absolute path: " ++ path

   bracketOnError (socket N.AF_UNIX N.Stream 0) N.close $ \sock -> do
      -- Remove a pre-existing file at the path; a missing file is fine, any
      -- other I/O error is rethrown.
      E.catch (removeLink path) $ \e -> when (not $ isDoesNotExistError e) $ throwIO e
      case mode of
         Nothing -> bind sock (N.SockAddrUnix path)
         -- setFileCreationMask returns the previous mask, which bracket
         -- hands to the release action so the old mask is restored.
         Just mode' -> bracket (setFileCreationMask $ modeToMask mode')
                              setFileCreationMask
                              (const $ bind sock (N.SockAddrUnix path))
      N.listen sock 150
      return $! sock
   where
     -- The bind action is 'N.bind' on network >= 2.7.0 and 'N.bindSocket'
     -- on older versions.
#if MIN_VERSION_network(2,7,0)
     bind = N.bind
#else
     bind = N.bindSocket
#endif
     -- Convert a desired permission mode into the creation mask that
     -- yields it: every access bit that is NOT in the mode.
     modeToMask p = accessModes .&. complement (fromIntegral p)
#else
bindUnixSocket _ path = throwIO (AddressNotSupportedException $ "unix:" ++ path)
#endif

------------------------------------------------------------------------------
-- | Buffer size handed to 'Streams.socketToStreamsWithBufferSize' for every
-- accepted connection (presumably in bytes -- confirm against io-streams).
--
-- TODO(greg): move buffer size configuration into config
bUFSIZ :: Int
bUFSIZ = 4064


------------------------------------------------------------------------------
-- | Accept one connection on a bound socket and pass it to an initializer.
--
-- The 'accept' call is run through the supplied @restore@ function. If the
-- initializer throws, the accepted socket is closed before the exception
-- propagates; if it returns normally the socket is left open
-- ('bracketOnError' only releases on failure).
acceptAndInitialize :: Socket        -- ^ bound socket
                    -> (forall b . IO b -> IO b)
                                     -- ^ wrapper applied to the accept call
                    -> ((Socket, N.SockAddr) -> IO a)
                                     -- ^ initializer for the accepted
                                     -- socket and its peer address
                    -> IO a
acceptAndInitialize boundSocket restore f =
    bracketOnError (restore $ accept boundSocket)
                   (close . fst)
                   f


------------------------------------------------------------------------------
-- | Build an 'AcceptFunc' for a listening socket whose connections start
-- with HAProxy headers.
--
-- For each accepted connection this wraps the socket in io-streams, decodes
-- the proxy headers from the read end with 'HA.decodeHAProxyHeaders', and
-- reports the decoded destination address as the local address and the
-- decoded source address as the remote address.
--
-- The resulting tuple is: sendfile handler, local host, local port, remote
-- host, remote port, read end, write end, and a cleanup action. The cleanup
-- action writes end-of-stream to the write end and then closes the socket,
-- even if that write throws.
haProxyAcceptFunc :: Socket     -- ^ bound socket
                  -> AcceptFunc
haProxyAcceptFunc boundSocket =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, saddr) -> do
        (readEnd, writeEnd)      <- Streams.socketToStreamsWithBufferSize
                                        bUFSIZ sock
        localPInfo               <- HA.socketToProxyInfo sock saddr
        pinfo                    <- HA.decodeHAProxyHeaders localPInfo readEnd
        (localPort, localHost)   <- getAddress $ HA.getDestAddr pinfo
        (remotePort, remoteHost) <- getAddress $ HA.getSourceAddr pinfo
        let cleanup              =  Streams.write Nothing writeEnd
                                        `finally` close sock
        return $! ( sendFileFunc sock
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
-- | Build an 'AcceptFunc' for a plain listening socket.
--
-- For each accepted connection the local address is taken from
-- 'getSocketName' on the accepted socket and the remote address from the
-- peer address returned by accept; the socket is then wrapped in io-streams.
--
-- The resulting tuple is: sendfile handler, local host, local port, remote
-- host, remote port, read end, write end, and a cleanup action. The cleanup
-- action writes end-of-stream to the write end and then closes the socket,
-- even if that write throws.
httpAcceptFunc :: Socket                     -- ^ bound socket
               -> AcceptFunc
httpAcceptFunc boundSocket =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
        localAddr                <- getSocketName sock
        (localPort, localHost)   <- getAddress localAddr
        (remotePort, remoteHost) <- getAddress remoteAddr
        (readEnd, writeEnd)      <- Streams.socketToStreamsWithBufferSize
                                        bUFSIZ sock
        let cleanup              =  Streams.write Nothing writeEnd
                                        `finally` close sock
        return $! ( sendFileFunc sock
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
-- | Build the 'SendFileHandler' for a connected socket.
--
-- The handler receives an output buffer, a builder (sent before the file
-- contents), the file path, a starting offset and a byte count.
--
-- With @HAS_SENDFILE@ the file is opened read-only, the builder is sent with
-- 'sendHeaders', and the requested range is sent with 'sendFile'; the file
-- descriptor is closed afterwards whether or not sending succeeds. The
-- buffer argument is forced but otherwise unused in this variant.
--
-- Without @HAS_SENDFILE@ the file range is streamed through io-streams: the
-- builder followed by the file bytes is fed through a builder stream backed
-- by the given buffer, each chunk is sent with 'sendAll', and a final flush
-- is written.
sendFileFunc :: Socket -> SendFileHandler
#ifdef HAS_SENDFILE
-- NOTE(review): 'bracket' is imported only under HAS_UNIX_SOCKETS in this
-- file's import list; confirm it is in scope when HAS_SENDFILE is defined
-- without HAS_UNIX_SOCKETS.
sendFileFunc sock !_ builder fPath offset nbytes = bracket acquire closeFd go
  where
    acquire   = openFd fPath ReadOnly Nothing defaultFileFlags
    -- network >= 3.0.0 exposes 'fdSocket' as an IO action; older versions
    -- expose it as a pure function.
#if MIN_VERSION_network(3,0,0)
    go fileFd = do sockFd <- Fd `fmap` fdSocket sock
                   sendHeaders builder sockFd
                   sendFile sockFd fileFd offset nbytes
#else
    go fileFd = do let sockFd = Fd $ fdSocket sock
                   sendHeaders builder sockFd
                   sendFile sockFd fileFd offset nbytes
#endif

#else
sendFileFunc sock buffer builder fPath offset nbytes =
    Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $
            \fileInput0 -> do
        fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
                     Streams.map byteString
        input     <- Streams.fromList [builder] >>=
                     flip Streams.appendInputStream fileInput
        output    <- Streams.makeOutputStream sendChunk >>=
                     Streams.unsafeBuilderStream (return buffer)
        Streams.supply input output
        Streams.write (Just flush) output

  where
    -- End-of-stream is a no-op here: nothing is sent on the socket for it.
    sendChunk (Just s) = sendAll sock s
    sendChunk Nothing  = return $! ()
#endif