{-# LANGUAGE CPP, ForeignFunctionInterface #-}
module Network.SSH.Client.LibSSH2.WaitSocket
( threadWaitRead
, threadWaitWrite
) where
import Network.Socket(Socket)
-- withFdSocket first appeared in network 3.1; network 3.0.x only offers
-- fdSocket (returning IO CInt), and older releases a pure fdSocket.
#if MIN_VERSION_network(3,1,0)
import Network.Socket(withFdSocket)
#else
import Network.Socket(fdSocket)
#endif
import System.Posix.Types(Fd(Fd))
#ifdef mingw32_HOST_OS
import Control.Concurrent(forkIO,newEmptyMVar,putMVar,takeMVar)
import Control.Exception(IOException,throwIO,try)
import Control.Exception.Base(mask_)
import Foreign.C.Error(throwErrnoIfMinus1_)
import Foreign.C.Types(CInt(CInt))
import System.IO(hWaitForInput,stdin)
#else
import qualified GHC.Conc (threadWaitRead, threadWaitWrite)
#endif
-- | Block the calling thread until the socket is ready for reading.
--
-- The raw descriptor is pulled out of the 'Socket', wrapped in 'Fd', and
-- handed to 'threadWaitRead_', which defers to 'GHC.Conc.threadWaitRead' on
-- non-Windows hosts and to a Windows-specific fallback otherwise.
--
-- The way the descriptor is obtained depends on the network version:
-- withFdSocket exists from network 3.1, network 3.0.x has an IO-returning
-- fdSocket, and older releases expose fdSocket as a pure accessor.
threadWaitRead :: Socket -> IO ()
#if MIN_VERSION_network(3,1,0)
threadWaitRead sock = withFdSocket sock (threadWaitRead_ . Fd)
#elif MIN_VERSION_network(3,0,0)
threadWaitRead sock = fdSocket sock >>= threadWaitRead_ . Fd
#else
threadWaitRead = threadWaitRead_ . Fd . fdSocket
#endif
-- | Block the calling thread until the socket is ready for writing.
--
-- The raw descriptor is pulled out of the 'Socket', wrapped in 'Fd', and
-- handed to 'threadWaitWrite_', which defers to 'GHC.Conc.threadWaitWrite'
-- on non-Windows hosts and to a Windows-specific fallback otherwise.
--
-- The way the descriptor is obtained depends on the network version:
-- withFdSocket exists from network 3.1, network 3.0.x has an IO-returning
-- fdSocket, and older releases expose fdSocket as a pure accessor.
threadWaitWrite :: Socket -> IO ()
#if MIN_VERSION_network(3,1,0)
threadWaitWrite sock = withFdSocket sock (threadWaitWrite_ . Fd)
#elif MIN_VERSION_network(3,0,0)
threadWaitWrite sock = fdSocket sock >>= threadWaitWrite_ . Fd
#else
threadWaitWrite = threadWaitWrite_ . Fd . fdSocket
#endif
-- | Block the calling thread until the file descriptor is readable.
--
-- On non-Windows hosts this is simply 'GHC.Conc.threadWaitRead'.
--
-- On Windows the threaded RTS is required: the wait ('waitFd' with
-- direction 0) runs on a helper thread via 'withThread'.  Without the
-- threaded RTS only descriptor 0 is supported, by waiting indefinitely on
-- 'stdin' with 'hWaitForInput'; any other descriptor raises an error.
threadWaitRead_ :: Fd -> IO ()
threadWaitRead_ fd
#ifdef mingw32_HOST_OS
  | threaded  = withThread (waitFd fd 0)
  | otherwise = case fd of
      0 -> do
        _ <- hWaitForInput stdin (-1)
        return ()
      _ ->
        error "threadWaitRead requires -threaded on Windows, or use System.IO.hWaitForInput"
#else
  = GHC.Conc.threadWaitRead fd
#endif
-- | Block the calling thread until the file descriptor is writable.
--
-- On non-Windows hosts this is simply 'GHC.Conc.threadWaitWrite'.
--
-- On Windows the threaded RTS is required: the wait ('waitFd' with
-- direction 1) runs on a helper thread via 'withThread'.  Without the
-- threaded RTS this raises an error for every descriptor.
threadWaitWrite_ :: Fd -> IO ()
threadWaitWrite_ fd
#ifdef mingw32_HOST_OS
  | threaded  = withThread (waitFd fd 1)
  | otherwise = error "threadWaitWrite requires -threaded on Windows"
#else
  = GHC.Conc.threadWaitWrite fd
#endif
#ifdef mingw32_HOST_OS
-- | True when the running RTS supports bound OS threads, which is the case
-- when the program was linked with -threaded.  Read from the RTS C function
-- rtsSupportsBoundThreads.
foreign import ccall unsafe "rtsSupportsBoundThreads"
  threaded :: Bool
-- | Run an action on a freshly forked thread and wait for its result.
--
-- Used by the Windows threaded path to run 'waitFd'.  The helper thread is
-- forked with asynchronous exceptions masked.  An 'IOException' raised by
-- the action is re-thrown in the calling thread.
--
-- NOTE(review): any other exception kills the helper without filling the
-- MVar, so the caller would stay blocked on 'takeMVar'.  Consider catching
-- SomeException if the action can raise anything besides IOException.
withThread :: IO a -> IO a
withThread io = do
  resultVar <- newEmptyMVar
  _ <- mask_ $ forkIO $ try io >>= putMVar resultVar
  outcome <- takeMVar resultVar
  either (\e -> throwIO (e :: IOException)) return outcome
-- | Block, inside a safe foreign call, until the descriptor is ready.
--
-- The second argument selects the direction.  Callers in this module pass 0
-- to wait for readability and 1 to wait for writability.  A result of -1
-- from 'fdReady' is turned into an IOError built from errno.
--
-- NOTE(review): the trailing 1 passed to fdReady is presumably the
-- is-socket flag of the C function; confirm against the targeted base.
waitFd :: Fd -> CInt -> IO ()
waitFd (Fd cfd) write =
  throwErrnoIfMinus1_ "fdReady" $ fdReady cfd write iNFINITE 1
  where
    -- Timeout in milliseconds.  0xFFFFFFFF truncated to a 32-bit CInt is
    -- -1, so spell it as -1 to avoid an overflowed literal.
    -- NOTE(review): presumably fdReady treats a negative timeout as an
    -- unbounded wait; confirm for the targeted base version.
    iNFINITE :: CInt
    iNFINITE = -1
-- | Binding to the fdReady C function provided by the GHC base library
-- (cbits).  Imported as a safe call because it may block.
--
-- Arguments as used by 'waitFd': descriptor, direction (0 read, 1 write),
-- timeout in milliseconds, and a flag passed as 1; returns -1 on failure.
--
-- NOTE(review): newer base releases (reportedly base 4.12 / GHC 8.6 onward)
-- declare the C function as
--   int fdReady(int fd, bool write, int64_t msecs, bool isSock)
-- in which case the CInt timeout below does not match the C ABI.  Verify
-- against the GHC version this is built with.
foreign import ccall safe "fdReady"
  fdReady :: CInt     -- descriptor
          -> CInt     -- direction
          -> CInt     -- timeout (milliseconds)
          -> CInt     -- flag (passed as 1 by waitFd)
          -> IO CInt
#endif