-- | Internal socket utilities implementing missing
--   features of 'System.Socket' which are yet to be
--   upstreamed.
module Network.Gopher.Util.Socket
  ( gracefulClose
  ) where

import Control.Concurrent (threadDelay)
import Control.Concurrent.MVar (withMVar)
import Control.Exception.Base (finally, throwIO)
import Control.Monad (void, when)
import Data.Functor ((<&>))
import Foreign.C.Error (Errno (..), getErrno)
import Foreign.C.Types (CInt (..))

import Control.Concurrent.Async (race)
import System.Socket (receive, msgNoSignal, SocketException (..), close, Family ())
import System.Socket.Protocol.TCP (TCP ())
import System.Socket.Type.Stream (Stream ())
import System.Socket.Unsafe (Socket (..))

-- Until https://github.com/lpeterse/haskell-socket/pull/67 gets
-- merged, we have to implement shutdown ourselves.
--
-- Raw binding to the C library's @shutdown@: the first argument is the
-- socket's file descriptor, the second the @how@ selector. Callers in
-- this module treat any non-zero result as failure and then read
-- @errno@ to build a 'SocketException'.
-- NOTE(review): imported as @unsafe@, presumably because shutdown does
-- not block; confirm if this binding is ever reused elsewhere.
foreign import ccall unsafe "shutdown"
  c_shutdown :: CInt -> CInt -> IO CInt

-- | Which direction(s) of a connection 'shutdown' should close.
--
--   Constructors are declared in the order of the conventional
--   @SHUT_RD@ (0), @SHUT_WR@ (1) and @SHUT_RDWR@ (2) values, so the
--   derived 'Enum' instance agrees with that numbering.
data ShutdownHow
  -- | Disallow Reading (calls to 'receive' are empty).
  = ShutdownRead
  -- | Disallow Writing (calls to 'send' throw).
  | ShutdownWrite
  -- | Disallow both.
  | ShutdownReadWrite
  deriving (Show, Eq, Ord, Enum)

-- | Shutdown a stream connection (partially).
--   Shutting down the write side will send TCP FIN and
--   prompt a client to close the connection.
--
--   The file descriptor is used while holding the socket's 'MVar'.
--   If the underlying @shutdown@ call returns a non-zero result, a
--   'SocketException' carrying the current @errno@ is thrown.
--
--   Not exposed to prevent future name clash.
shutdown :: Socket a Stream TCP -> ShutdownHow -> IO ()
shutdown (Socket mvar) how = withMVar mvar $ \fd -> do
  res <- c_shutdown (fromIntegral fd) (howToC how)
  when (res /= 0) $
    throwIO =<< (getErrno <&> \(Errno errno) -> SocketException errno)
  where
    -- Explicit @how@ encoding instead of 'fromEnum', so that reordering
    -- the constructors of 'ShutdownHow' cannot change what is passed
    -- to C. The values are the same 0/1/2 the 'fromEnum' encoding
    -- produced (SHUT_RD / SHUT_WR / SHUT_RDWR on Linux and BSD-derived
    -- systems; NOTE(review): POSIX does not fix these numbers).
    howToC :: ShutdownHow -> CInt
    howToC ShutdownRead = 0
    howToC ShutdownWrite = 1
    howToC ShutdownReadWrite = 2

-- | Shutdown connection and give client a bit
--   of time to clean up on its end before closing
--   the connection to avoid a broken pipe on the
--   other side.
--
--   The socket is always closed, even if 'shutdown' or the final
--   'receive' throws (e.g. because the client already went away);
--   such an exception is re-thrown after the socket has been closed.
gracefulClose :: Family f => Socket f Stream TCP -> IO ()
gracefulClose sock = awaitPeer `finally` close sock
  where
    awaitPeer = do
      -- send TCP FIN
      shutdown sock ShutdownWrite
      -- wait for some kind of read from the
      -- client (either mempty, meaning TCP FIN, or
      -- something else which would mean protocol
      -- violation). Give up after 1s.
      void $
        race
          (void $ receive sock 16 msgNoSignal)
          (threadDelay 1000000)