{-# language BangPatterns #-}
{-# language BinaryLiterals #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language EmptyCase #-}
module Network.Icmp.Ping.Multihosts
( multihosts
, multirange
) where
import Control.Applicative ((<|>))
import Control.Concurrent (threadWaitReadSTM,threadWaitWrite)
import Control.Concurrent.STM.TVar (readTVar,registerDelay)
import Control.Exception (onException,mask)
import Control.Monad.Trans.Except (ExceptT(..),runExceptT)
import Data.Functor (($>))
import Data.Primitive (PrimArray,MutableByteArray,MutablePrimArray)
import Data.Word (Word64,Word8,Word16,Word32)
import Foreign.C.Error (Errno(..),eAGAIN,eWOULDBLOCK,eACCES)
import Foreign.C.Types (CSize(..))
import GHC.Clock (getMonotonicTimeNSec)
import GHC.Exts (RealWorld)
import GHC.IO (IO(..))
import Net.Types (IPv4(..),IPv4Range)
import Network.Icmp.Common (IcmpException(..))
import Network.Icmp.Marshal (peekIcmpHeaderPayload)
import Network.Icmp.Marshal (peekIcmpHeaderSequenceNumber)
import Network.Icmp.Marshal (sizeOfIcmpHeader,pokeIcmpHeader)
import Network.Icmp.Ping.Debug (debug)
import Posix.Socket (SocketAddressInternet(..))
import System.Endian (toBE32)
import System.Posix.Types (Fd(..))
import Unsafe.Coerce (unsafeCoerce)
import qualified Control.Monad.STM as STM
import qualified Data.Map.Unboxed.Unlifted as MUN
import qualified Data.Primitive as PM
import qualified Data.Set.Unboxed as SU
import qualified Linux.Socket as SCK
import qualified Net.IPv4 as IPv4
import qualified Posix.Socket as SCK
-- | Size in bytes of every echo request that is sent and of every reply that
-- is accepted: the ICMP header followed by a 4-byte payload (the payload
-- carries the target's IPv4 address).
fullPacketSize :: Int
fullPacketSize = headerBytes + payloadBytes
  where
    headerBytes = sizeOfIcmpHeader
    payloadBytes = 4
-- | Block for at most @maxWaitTime@ microseconds until @sock@ is readable.
--
-- Returns 'True' if the socket became readable and 'False' if the delay ran
-- out. When @shouldRead@ is 'False' or @maxWaitTime@ is not positive, it
-- returns 'False' immediately without waiting.
--
-- The readiness registration is released on every path, including when an
-- asynchronous exception interrupts the wait.
waitForRead ::
     Bool
  -> Int
  -> Fd
  -> IO Bool
waitForRead !shouldRead !maxWaitTime !sock
  | not shouldRead || maxWaitTime <= 0 = pure False
  | otherwise = do
      debug ("About to wait for " ++ show maxWaitTime ++ " microseconds")
      (isReadyAction, deregister) <- threadWaitReadSTM sock
      delay <- registerDelay maxWaitTime
      -- Race readiness against the timer; whichever fires first wins.
      isContentReady <-
        STM.atomically
          (     (isReadyAction $> True)
            <|> (readTVar delay >>= STM.check >> pure False)
          )
          `onException` deregister
      deregister
      pure isContentReady
-- | Ping every host in @theHosts@ over a single ICMP datagram socket. Sends
-- and receives are interleaved so that all hosts are probed concurrently.
--
-- Arguments, in order:
--
-- * @pause@: microseconds to wait for an echo reply before the next ping is
--   sent to that host.
-- * @successPause'@: microseconds to wait after a recorded reply before the
--   next ping is sent to that host. Negative values are treated as 0.
-- * @totalPings@: maximum number of echo requests per host.
-- * @cutoff@: once a host has been sent this many pings without a single
--   reply, it is no longer pinged.
-- * @theHosts@: the targets.
--
-- Returns, for every host that answered at least once, the round-trip times
-- (in nanoseconds, monotonic clock) of its recorded replies. Hosts that never
-- answered are omitted. A non-positive @pause@, @totalPings@ or @cutoff@, or
-- an empty host set, yields an empty map without opening a socket.
--
-- Each host owns an array of @totalPings + 4@ elements:
--
-- * slots @0 .. totalPings-1@ hold round-trip times;
-- * the next slot is the attempt count, which is also the sequence number of
--   the last request;
-- * then the success count;
-- * then the time of the last send or recorded reply;
-- * finally the 'pendingReceive' / 'pendingSend' state marker.
multihosts ::
     Int
  -> Int
  -> Int
  -> Int
  -> SU.Set IPv4
  -> IO (Either IcmpException (MUN.Map IPv4 (PrimArray Word64)))
multihosts !pause !successPause' !totalPings !cutoff !theHosts
  | pause <= 0 || totalPings <= 0 || cutoff <= 0 || SU.null theHosts = pure (Right mempty)
  | otherwise = let !successPause = max successPause' 0 in mask $ \restore ->
      SCK.uninterruptibleSocket SCK.internet SCK.datagram SCK.icmp >>= \case
        Left (Errno e) -> pure (Left (IcmpExceptionSocket e))
        Right sock -> do
          !now0 <- getMonotonicTimeNSec
          !buffer <- PM.newByteArray fullPacketSize
          !results <- restore
            ( do
                let nanoPause = intToWord64 pause * 1000
                    nanoSuccessPause = intToWord64 successPause * 1000
                -- Allocate each host's state array and send its first ping.
                eworking <- runExceptT $ MUN.fromSetP
                  (\theHost -> ExceptT $ do
                      m <- PM.newPrimArray (totalPings + 4)
                      PM.setPrimArray m 0 (totalPings + 4) (0 :: Word64)
                      debug ("Sending initial to " ++ show theHost)
                      performSend 0 now0 nanoPause sock totalPings theHost buffer m >>= \case
                        Left err -> pure (Left err)
                        Right _ -> pure (Right m)
                  ) theHosts
                case eworking of
                  Left err -> pure (Left err)
                  Right working -> do
                    -- Poll loop: wait up to currentPause ns for a reply. When
                    -- that time runs out, sweep all hosts; the sweep sends any
                    -- due pings and computes the next deadline.
                    let go :: Word64 -> Word64 -> IO (Either IcmpException ())
                        go !currentPause !nextTime = do
                          -- Nanoseconds left until nextTime, clamped at 0 once
                          -- the deadline has passed. This avoids Word64
                          -- wrap-around.
                          let remaining t = if t >= nextTime then 0 else nextTime - t
                          waitForRead True (word64ToInt (div currentPause 1000)) sock >>= \case
                            True -> do
                              debug "Receiving in poll loop"
                              r <- SCK.uninterruptibleReceiveFromMutableByteArray_ sock buffer 0 (intToCSize fullPacketSize) SCK.dontWait
                              case r of
                                Left err@(Errno e)
                                  -- Readiness is only a hint; nothing to read yet.
                                  | err == eAGAIN || err == eWOULDBLOCK -> do
                                      end <- getMonotonicTimeNSec
                                      go (remaining end) nextTime
                                  | otherwise -> pure (Left (IcmpExceptionReceive e))
                                Right receivedBytes
                                  | receivedBytes /= intToCSize fullPacketSize -> do
                                      end <- getMonotonicTimeNSec
                                      go (remaining end) nextTime
                                  | otherwise -> do
                                      payload' <- peekIcmpHeaderPayload buffer
                                      end <- getMonotonicTimeNSec
                                      case MUN.lookup (IPv4 payload') working of
                                        Nothing -> go (remaining end) nextTime
                                        Just stats -> do
                                          sequenceNumber' <- peekIcmpHeaderSequenceNumber buffer
                                          sequenceNumber <- PM.readPrimArray stats (totalPings + 0)
                                          theState <- PM.readPrimArray stats (totalPings + 3)
                                          -- Accept only the reply currently awaited. Without the
                                          -- state check, a duplicated reply would be recorded
                                          -- again and could overwrite the attempt counter.
                                          if word16ToWord64 sequenceNumber' == sequenceNumber && theState == pendingReceive
                                            then do
                                              sentTime <- PM.readPrimArray stats (totalPings + 2)
                                              successes <- PM.readPrimArray stats (totalPings + 1)
                                              PM.writePrimArray stats (word64ToInt successes) (end - sentTime)
                                              PM.writePrimArray stats (totalPings + 1) (successes + 1)
                                              PM.writePrimArray stats (totalPings + 2) end
                                              PM.writePrimArray stats (totalPings + 3) pendingSend
                                              let possibleNextTime = end + nanoSuccessPause
                                              if possibleNextTime < nextTime
                                                then go nanoSuccessPause possibleNextTime
                                                else go (remaining end) nextTime
                                            else go (remaining end) nextTime
                            False -> do
                              debug "Updating in poll loop"
                              currentTime <- getMonotonicTimeNSec
                              r <- runExceptT $ MUN.foldlMapWithKeyM'
                                (step sock nanoPause nanoSuccessPause totalPings cutoff buffer currentTime)
                                working
                              case r of
                                Left e -> pure (Left e)
                                Right (Time futureTime) -> if futureTime == maxBound
                                  then pure (Right ())
                                  else do
                                    debug ("Waiting for " ++ show (futureTime - currentTime) ++ " nanoseconds before spanning for expirations")
                                    go (futureTime - currentTime) futureTime
                    now1 <- getMonotonicTimeNSec
                    go nanoPause (now1 + nanoPause) >>= \case
                      Left e -> pure (Left e)
                      Right _ -> fmap Right
                        ( MUN.mapMaybeP
                            (\stats -> do
                                successes <- PM.readPrimArray stats (totalPings + 1)
                                if successes == 0
                                  then pure Nothing
                                  else fmap Just (PM.resizeMutablePrimArray stats (word64ToInt successes) >>= PM.unsafeFreezePrimArray)
                            ) working
                        )
            )
            `onException` SCK.uninterruptibleClose sock
          SCK.uninterruptibleClose sock >>= \case
            Left (Errno e) -> pure (Left (IcmpExceptionClose e))
            Right _ -> pure results
-- | A wake-up deadline in nanoseconds, on the same clock as
-- 'getMonotonicTimeNSec'. Combining two deadlines keeps the earlier one.
-- 'mempty' is 'maxBound', which the poll loop treats as "nothing left to
-- wait for".
newtype Time = Time Word64

instance Semigroup Time where
  (<>) (Time x) (Time y) = Time (if x <= y then x else y)

instance Monoid Time where
  mempty = Time (maxBound :: Word64)
-- | Advance one host's ping state during a sweep taken at time @now@ (ns).
--
-- @durations@ is the host's state array. Slots @0 .. totalPings-1@ hold
-- round-trip times. Slot @totalPings+0@ is the attempt count, @+1@ the success
-- count, @+2@ the time of the last send or recorded reply, and @+3@ the state
-- marker ('pendingReceive' or 'pendingSend').
--
-- Sends the next ping when it is due: either the reply timeout (@pause@) has
-- expired, or the post-success delay (@successPause@) has elapsed. Returns the
-- earliest time this host next needs attention, or 'mempty' when it is done.
--
-- A host that has used all its pings, or reached @cutoff@ with no success,
-- still reports the deadline of a reply it is waiting for. This keeps the poll
-- loop listening for that final reply instead of possibly finishing early.
step ::
     Fd
  -> Word64
  -> Word64
  -> Int
  -> Int
  -> MutableByteArray RealWorld
  -> Word64
  -> IPv4
  -> MutablePrimArray RealWorld Word64
  -> ExceptT IcmpException IO Time
step !sock !pause !successPause !totalPings !cutoff !buffer !now !theHost !durations = ExceptT $ do
  attemptedPings <- PM.readPrimArray durations (totalPings + 0)
  successPings <- PM.readPrimArray durations (totalPings + 1)
  lastEventTime <- PM.readPrimArray durations (totalPings + 2)
  theState <- PM.readPrimArray durations (totalPings + 3)
  let awaitingReply = theState == pendingReceive
      replyDeadline = lastEventTime + pause
      -- Used once no further pings will be sent. Keep waiting for an
      -- in-flight reply until its timeout, then report nothing.
      drain = if awaitingReply && replyDeadline >= now
        then Time replyDeadline
        else mempty
      sendNow = performSend attemptedPings now pause sock totalPings theHost buffer durations
  if word64ToInt attemptedPings < totalPings
    then do
      debug ("Detected " ++ show attemptedPings ++ " attempted pings and " ++ show successPings ++ " successes")
      if word64ToInt attemptedPings >= cutoff && successPings == 0
        then pure (Right drain)
        else if awaitingReply
          then if replyDeadline < now
            then sendNow
            else pure (Right (Time replyDeadline))
          else do
            let nextSend = lastEventTime + successPause
            if nextSend < now
              then sendNow
              else pure (Right (Time nextSend))
    else pure (Right drain)
-- | Send echo request number @attemptedPings + 1@ to @theHost@ at time @now@.
--
-- Before sending, it records @now@ as the host's last event time and bumps
-- its attempt count. On a full-length send, the host is marked
-- 'pendingReceive' and the reply deadline @now + pause@ is returned. If the
-- send fails with @EACCES@, the host is treated as exhausted (attempt count
-- set to @totalPings@, state 'pendingSend') and 'mempty' is returned. Any
-- other send error, or a short send, becomes 'Left'.
performSend
  :: Word64
  -> Word64
  -> Word64
  -> Fd
  -> Int
  -> IPv4
  -> MutableByteArray RealWorld
  -> MutablePrimArray RealWorld Word64
  -> IO (Either IcmpException Time)
performSend attemptedPings now pause sock totalPings theHost buffer durations = do
  let seqNo = attemptedPings + 1
  PM.writePrimArray durations (totalPings + 2) now
  PM.writePrimArray durations (totalPings + 0) seqNo
  -- Zero the header, then fill in the sequence number and the 4-byte payload.
  PM.setByteArray buffer 0 sizeOfIcmpHeader (0 :: Word8)
  pokeIcmpHeader buffer (word64ToWord16 seqNo) (getIPv4 theHost)
  outcome <- writeWhenReady
    (SCK.uninterruptibleSendToMutableByteArray sock buffer 0 packetLen SCK.dontWait destination)
    (threadWaitWrite sock)
  case outcome of
    Right sentBytes
      | sentBytes == packetLen -> do
          PM.writePrimArray durations (totalPings + 3) pendingReceive
          pure (Right (Time (now + pause)))
      | otherwise -> pure (Left (IcmpExceptionSendBytes sentBytes))
    Left err@(Errno code)
      | err == eACCES -> do
          PM.writePrimArray durations (totalPings + 0) (intToWord64 totalPings)
          PM.writePrimArray durations (totalPings + 3) pendingSend
          pure (Right mempty)
      | otherwise -> pure (Left (IcmpExceptionSend code))
  where
    packetLen = intToCSize fullPacketSize
    destination = SCK.encodeSocketAddressInternet
      (SocketAddressInternet { port = 0, address = toBE32 (getIPv4 theHost) })
-- | State marker stored in a host's state array: an echo request is
-- outstanding and no reply for it has been recorded yet.
pendingReceive :: Word64
pendingReceive = 0

-- | State marker stored in a host's state array: the host is not waiting for
-- a reply. It is set after a reply is recorded, or after a send fails with
-- EACCES.
pendingSend :: Word64
pendingSend = succ pendingReceive
-- | Narrow to 16 bits, keeping only the low-order bits. Used for the 16-bit
-- ICMP sequence number field.
word64ToWord16 :: Word64 -> Word16
word64ToWord16 w = fromIntegral w

-- | Widen a 16-bit value (e.g. a received sequence number) losslessly.
word16ToWord64 :: Word16 -> Word64
word16ToWord64 w = fromIntegral w

-- | Reinterpret an 'Int' as a 'Word64'. Callers pass non-negative values.
intToWord64 :: Int -> Word64
intToWord64 i = fromIntegral i

-- | Reinterpret a 'Word64' as an 'Int', e.g. for array indexing.
word64ToInt :: Word64 -> Int
word64ToInt w = fromIntegral w

-- | Convert a byte count to the 'CSize' expected by the socket calls.
intToCSize :: Int -> CSize
intToCSize i = fromIntegral i
-- | Run a non-blocking send action. Whenever it fails with @EAGAIN@ or
-- @EWOULDBLOCK@, run @wait@ and try again. @wait@ is expected to block until
-- the socket is writable. Any other result, success or a different errno, is
-- returned unchanged.
--
-- Retrying in a loop, rather than once, matters because readiness can be
-- consumed before the retry (e.g. the send buffer refills). Otherwise a
-- second @EAGAIN@ would surface as a hard send error.
writeWhenReady
  :: IO (Either Errno CSize)
  -> IO ()
  -> IO (Either Errno CSize)
writeWhenReady attempt wait = loop
  where
    loop = do
      result <- attempt
      case result of
        Left e
          | e == eWOULDBLOCK || e == eAGAIN -> wait *> loop
        _ -> pure result
-- | Like 'multihosts', but pings every address of the inclusive range @r@.
-- The remaining arguments are passed through to 'multihosts' unchanged.
multirange ::
     Int
  -> Int
  -> Int
  -> Int
  -> IPv4Range
  -> IO (Either IcmpException (MUN.Map IPv4 (PrimArray Word64)))
multirange !pause !successPause !totalPings !cutoff !r =
  let lowest = getIPv4 (IPv4.lowerInclusive r)
      highest = getIPv4 (IPv4.upperInclusive r)
      targets = coerceIPv4Set (SU.enumFromTo lowest highest)
   in multihosts pause successPause totalPings cutoff targets
-- | Reinterpret a set of raw 'Word32' addresses as a set of 'IPv4'.
--
-- NOTE(review): this relies on 'IPv4' being a newtype over 'Word32' (this
-- module builds it with @IPv4 payload'@ and unwraps it with 'getIPv4'), so
-- both sets share one runtime representation. 'unsafeCoerce' is presumably
-- used because the unboxed set type does not admit a safe 'coerce'. Confirm
-- this still holds if 'IPv4' ever changes representation.
coerceIPv4Set :: SU.Set Word32 -> SU.Set IPv4
coerceIPv4Set rawSet = unsafeCoerce rawSet