{-# LANGUAGE CPP #-}

-----------------------------------------------------------------------------
-- |
-- This module allows you to set per-connection keep alive parameters on Windows and Linux environments.
-- For more information on keep alive signals see https://en.wikipedia.org/wiki/Keepalive.
-- See also https://tldp.org/HOWTO/html_single/TCP-Keepalive-HOWTO/ for a Linux-specific implementation.
--
-- The module is meant to be used in conjunction with the "network" package. However, in order to ensure adaptability, all functions require
-- a socket file descriptor instead of an implementation-dependent socket type. For the network package, such a descriptor can be obtained
-- with the 'withFdSocket' function:
--
-- > -- sock is a Socket type
-- > withFdSocket sock $ \fd -> do
-- >    before <- getKeepAliveOnOff fd
-- >    print before -- False
-- >    -- set keep alive on, idle 60 seconds, interval 2 seconds
-- >    rlt <- setKeepAlive fd $ KeepAlive True 60 2
-- >    case rlt of
-- >        Left err -> print err
-- >        Right () -> return ()
-- >    after <- getKeepAliveOnOff fd
-- >    print after -- True
--
-- Please note that only the invoking process can manipulate sockets based on their file descriptors.

module Network.Socket.KeepAlive
    ( KeepAlive (..)
    , KeepAliveError (..)
    , setKeepAlive
    , getKeepAliveOnOff
    ) where

import           Foreign    (Word32)
import           Foreign.C  (CInt)
import           LibForeign (getKeepAliveOnOff_, setKeepAlive_)


-- | The main data structure defining keep alive parameters.
data KeepAlive = KeepAlive
    { kaOnOff :: Bool
    -- ^ Turns keep alive probes on or off
    , kaIdle  :: Word32
    -- ^ The interval in seconds between the last data packet sent and the
    -- first keep alive probe
    , kaIntvl :: Word32
    -- ^ The interval in seconds between subsequent keep alive probes
    }
    deriving (Eq, Ord, Show)

-- | Errors reported by 'setKeepAlive'.
--
-- Constructors starting with @WSA@ are Windows specific; the remaining named
-- constructors carry POSIX @errno@ names. A status code without a dedicated
-- constructor is returned verbatim in 'OTHER_KEEPALIVE_ERROR'.
--
-- The constructor order is significant for the derived 'Ord' instance.
data KeepAliveError
    -- Windows Sockets error codes
    = WSA_IO_PENDING
    | WSA_OPERATION_ABORTED
    | WSAEFAULT
    | WSAEINPROGRESS
    | WSAEINTR
    | WSAEINVAL
    | WSAENETDOWN
    | WSAENOPROTOOPT
    | WSAENOTSOCK
    | WSAEOPNOTSUPP
    -- POSIX errno codes
    | EBADF
    | EDOM
    | EINVAL
    | EISCONN
    | ENOPROTOOPT
    | ENOTSOCK
    | ENOMEM
    | ENOBUFS
    -- Fallback
    | OTHER_KEEPALIVE_ERROR CInt
    -- ^ Any other non-zero status code, kept as-is
    deriving (Eq, Ord, Show)

-- | Set keep alive parameters for the given socket.
--
-- The parameters are forwarded to @setKeepAlive_@. A returned status of @0@
-- yields @'Right' ()@; any other status is decoded into a 'KeepAliveError'.
setKeepAlive ::
    CInt
    -- ^ Socket file descriptor
    -> KeepAlive
    -- ^ Keep alive parameters
    -> IO (Either KeepAliveError ())
setKeepAlive fd (KeepAlive onoff idle intvl) =
    toResult <$> setKeepAlive_ fd (cFromBool onoff) idle intvl
  where
    toResult :: CInt -> Either KeepAliveError ()
    toResult 0    = Right ()
    toResult code = Left (keepAliveErrorFromCode code)

-- | Decode a non-zero status code returned by @setKeepAlive_@.
--
-- The @WSA*@ values are Windows Sockets error codes; the remaining named
-- values follow Linux @errno@ numbering. Unknown codes are preserved in
-- 'OTHER_KEEPALIVE_ERROR'.
--
-- NOTE(review): several POSIX numbers (e.g. EISCONN, ENOTSOCK, ENOBUFS,
-- ENOPROTOOPT) differ on BSD/macOS; confirm those platforms are out of scope.
keepAliveErrorFromCode :: CInt -> KeepAliveError
keepAliveErrorFromCode code = case code of
    997   -> WSA_IO_PENDING
    995   -> WSA_OPERATION_ABORTED
    10014 -> WSAEFAULT
    10036 -> WSAEINPROGRESS
    10004 -> WSAEINTR
    10022 -> WSAEINVAL
    10050 -> WSAENETDOWN
    10042 -> WSAENOPROTOOPT
    10038 -> WSAENOTSOCK
    10045 -> WSAEOPNOTSUPP
    9     -> EBADF
    33    -> EDOM
    22    -> EINVAL
    106   -> EISCONN
    92    -> ENOPROTOOPT
    88    -> ENOTSOCK
    12    -> ENOMEM
    105   -> ENOBUFS
    other -> OTHER_KEEPALIVE_ERROR other

-- | Returns 'True' if keep alive is active for the specified socket.
--
-- The status read by @getKeepAliveOnOff_@ is interpreted C-style: @0@ means
-- off, any other value means on.
getKeepAliveOnOff ::
    CInt
    -- ^ Socket file descriptor
    -> IO Bool
getKeepAliveOnOff fd = cToBool <$> getKeepAliveOnOff_ fd

-- | Interpret a C-style boolean: @0@ is 'False', anything else is 'True'.
cToBool :: CInt -> Bool
cToBool = (/= 0)

-- | Encode a 'Bool' as a C-style flag: 'True' is @1@, 'False' is @0@.
cFromBool :: Bool -> Word32
cFromBool flag = if flag then 1 else 0