module Network.RPC.Curryer.StreamlyAdditions where
import Control.Monad.IO.Class
import Network.Socket (Socket, PortNumber, SocketOption, SockAddr(..), maxListenQueue, Family(..), SocketType(..), defaultProtocol, tupleToHostAddress, withSocketsDo, socket, setSocketOption, bind, getSocketName)
import qualified Network.Socket as Net
import Control.Exception (onException)
import Control.Monad.Catch (finally)
import Control.Concurrent.MVar
import Data.Word
import qualified Streamly.Internal.Data.Unfold as UF
import Streamly.Network.Socket hiding (accept)
import qualified Streamly.Internal.Data.Stream.StreamD.Type as D
import Streamly.Internal.Data.Unfold.Type (Unfold(..))

-- | An 'Unfold' whose seed is an IPv4 address (as a 4-tuple of octets) and a
-- port. It opens a TCP listening socket bound to that address and yields one
-- connected 'Socket' for every accepted client connection.
--
-- The listening socket uses address family 'AF_INET', socket type 'Stream',
-- the default protocol, the supplied socket options, and a listen queue
-- length of 'maxListenQueue'.
--
-- If an 'MVar' is supplied, the listener's actual bound address (as reported
-- by 'getSocketName') is put into it once the socket is listening. This lets
-- a caller learn the real port, presumably when port 0 was requested.
acceptOnAddrWith
    :: MonadIO m
    => [(SocketOption, Int)]   -- ^ options applied to the listening socket
    -> Maybe (MVar SockAddr)   -- ^ optional slot that receives the bound address
    -> Unfold m ((Word8, Word8, Word8, Word8), PortNumber) Socket
acceptOnAddrWith opts mSockLock = UF.lmap toListenSpec (accept mSockLock)
  where
    -- Translate the (address, port) seed into the
    -- (listen queue length, socket spec, bind address) triple 'accept' expects.
    toListenSpec
        :: ((Word8, Word8, Word8, Word8), PortNumber)
        -> (Int, SockSpec, SockAddr)
    toListenSpec (addr, port) =
        ( maxListenQueue
        , SockSpec
            { sockFamily = AF_INET
            , sockType   = Stream
            , sockProto  = defaultProtocol -- TCP
            , sockOpts   = opts
            }
        , SockAddrInet port (tupleToHostAddress addr)
        )

-- | Like 'listenTuples', but discards each peer address and yields only the
-- connected 'Socket' for every accepted connection.
--
-- This module defines its own 'accept' in place of the one hidden from
-- "Streamly.Network.Socket". This version takes an optional 'MVar' that
-- receives the listener's bound address.
accept
    :: MonadIO m
    => Maybe (MVar SockAddr)
    -> Unfold m (Int, SockSpec, SockAddr) Socket
accept mSockLock = UF.map fst (listenTuples mSockLock)

-- | Create a socket described by the 'SockSpec' and apply its socket options.
-- Then bind it to the given address and start listening with the given
-- queue length.
--
-- If applying options, binding or listening throws, the new socket is closed
-- before the exception propagates, so the descriptor is not leaked.
initListener :: Int -> SockSpec -> SockAddr -> IO Socket
initListener listenQLen sockSpec addr =
  withSocketsDo $ do
    sock <- socket (sockFamily sockSpec) (sockType sockSpec) (sockProto sockSpec)
    -- Release the fresh socket if any configuration step fails.
    configure sock `onException` Net.close sock
    pure sock
  where
    -- Apply every (option, value) pair, then bind and listen.
    configure :: Socket -> IO ()
    configure sock = do
      mapM_ (uncurry (setSocketOption sock)) (sockOpts sockSpec)
      bind sock addr
      Net.listen sock listenQLen

-- | An 'Unfold' whose seed is (listen queue length, socket spec, bind address).
--
-- On injection it creates a listening socket via 'initListener'. If an 'MVar'
-- is supplied, the listener's bound address (from 'getSocketName') is put
-- into it. Note that 'putMVar' blocks while that 'MVar' is full. If reporting
-- the address fails, the listener is closed and the exception propagates.
--
-- Each step blocks in 'Net.accept' and yields the connected socket paired
-- with the peer address. The unfold never stops on its own. If 'Net.accept'
-- throws, the listening socket is closed and the exception propagates.
--
-- NOTE(review): if the consumer stops pulling early, nothing here closes the
-- listening socket. Confirm callers handle that.
listenTuples :: MonadIO m
    => Maybe (MVar SockAddr)
    -> Unfold m (Int, SockSpec, SockAddr) (Socket, SockAddr)
listenTuples mSockLock = Unfold step inject
  where
    inject :: MonadIO n => (Int, SockSpec, SockAddr) -> n Socket
    inject (listenQLen, spec, addr) = liftIO $ do
      sock <- initListener listenQLen spec addr
      -- Previously the listener leaked if getSocketName/putMVar threw
      -- (including an async exception while putMVar was blocked).
      publishAddr sock `onException` Net.close sock
      pure sock

    -- Hand the bound address to the caller's MVar, if one was supplied.
    publishAddr :: Socket -> IO ()
    publishAddr sock = do
      sockAddr <- getSocketName sock
      maybe (pure ()) (`putMVar` sockAddr) mSockLock

    -- Wait for the next client; the listener itself is the unfold state.
    step listener = do
      conn <- liftIO (Net.accept listener `onException` Net.close listener)
      pure (D.Yield conn listener)

-- | Run a handler on a connected socket, then close the socket. The socket
-- is closed whether the handler returns normally or throws.
handleWithM :: (Socket -> IO ()) -> Socket -> IO ()
handleWithM f sk = f sk `finally` Net.close sk