module Util.Socket
    ( connectSocket
    , serverSocket
    , acceptSocket
    , closeSock
    , Socket()
    ) where

import qualified Control.Exception     as E
import qualified Data.ByteString       as BS
import qualified Data.ByteString.Char8 as CS
import           Data.Word
import           Network.Socket        hiding (recv, recvFrom, send, sendTo)
import           Util.IOExtra

--------------------------------------------------------------------------------
-- | Open a stream connection to @hostName@:@portNumber@ with the
-- 'NoDelay' socket option enabled, and return the connected socket.
--
-- The address is resolved by 'createSocket' (IPv4 only). If setting the
-- option or connecting throws, the socket is closed and the original
-- exception is re-thrown. 'E.bracketOnError' performs the acquisition with
-- asynchronous exceptions masked, so there is no gap between creating the
-- socket and installing the cleanup in which the descriptor could leak.
connectSocket :: BS.ByteString  -- ^ host name or numeric address
              -> Word16         -- ^ port number
              -> IO Socket
connectSocket hostName portNumber =
    E.bracketOnError (createSocket hostName (Just portNumber)) (closeSock . fst) $
        \(sock, sa) -> do
            -- Disable Nagle's algorithm so small writes go out immediately.
            setSocketOption sock NoDelay 1
            connect sock sa
            return sock

-- | Create a socket bound to an address of @hostName@ and start listening
-- with a backlog of 5.
--
-- No port is passed to 'createSocket'. The returned 'Word16' is the port
-- reported by 'socketPort' after binding (presumably an OS-assigned
-- ephemeral port; confirm). If binding, listening or querying the port
-- throws, the socket is closed and the exception re-thrown.
-- 'E.bracketOnError' masks asynchronous exceptions during acquisition so the
-- socket cannot leak before the cleanup is in place.
serverSocket :: BS.ByteString          -- ^ host name or numeric address to bind
             -> IO (Socket, Word16)    -- ^ listening socket and its bound port
serverSocket hostName =
    E.bracketOnError (createSocket hostName Nothing) (closeSock . fst) $
        \(sock, sa) -> do
            bind sock sa
            listen sock 5
            port <- socketPort sock
            return (sock, fromIntegral port)

-- | Accept one incoming connection on a listening socket and enable
-- 'NoDelay' on it. The peer address is discarded.
--
-- If setting the option throws, the accepted socket is closed and the
-- exception re-thrown. 'E.bracketOnError' masks asynchronous exceptions
-- between 'accept' returning and the cleanup being installed, so the
-- accepted descriptor cannot leak. A blocking 'accept' remains
-- interruptible.
acceptSocket :: Socket      -- ^ listening socket
             -> IO Socket   -- ^ newly accepted connection
acceptSocket sock =
    E.bracketOnError (accept sock) (closeSock . fst) $ \(conn, _peer) -> do
        -- Disable Nagle's algorithm so small writes go out immediately.
        setSocketOption conn NoDelay 1
        return conn

-- | Close the given socket. This is a thin alias for 'close' from
-- "Network.Socket", exported under a name that does not clash with other
-- @close@ functions.
closeSock :: Socket -> IO ()
closeSock s = close s

-- | Resolve @hostName@ (and, optionally, a numeric port) to an IPv4 stream
-- address and open an unconnected socket for it.
--
-- Returns the socket together with the resolved address, ready for
-- 'connect' or 'bind'. Only the first result of 'getAddrInfo' is used. When
-- the port is 'Nothing', no service is passed to the resolver. Resolution
-- failures surface as the 'IOError' thrown by 'getAddrInfo'. An empty result
-- list is reported with a descriptive 'IOError' instead of an opaque
-- pattern-match failure.
createSocket :: MonadIO m
             => BS.ByteString        -- ^ host name or numeric address
             -> Maybe Word16         -- ^ optional port number
             -> m (Socket, SockAddr)
createSocket hostName portNumber = liftIO $ do
    ai   <- resolve
    sock <- socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai)
    return (sock, addrAddress ai)
  where
    host :: String
    host = CS.unpack hostName

    -- IPv4 stream sockets only. The port, if any, is always numeric, so
    -- AI_NUMERICSERV skips a services-database lookup.
    hints :: AddrInfo
    hints = defaultHints
        { addrFlags      = [AI_CANONNAME, AI_NUMERICSERV, AI_ADDRCONFIG]
        , addrFamily     = AF_INET
        , addrSocketType = Stream
        }

    resolve :: IO AddrInfo
    resolve = do
        infos <- getAddrInfo (Just hints) (Just host) (show <$> portNumber)
        case infos of
            (ai : _) -> return ai
            []       -> ioError . userError $
                "createSocket: no address found for " ++ host