{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wall #-}
module Net.IPv4.Range
(
range
, fromBounds
, normalize
, contains
, member
, lowerInclusive
, upperInclusive
, toList
, toGenerator
, private24
, private20
, private16
, encode
, decode
, builder
, parser
, print
, IPv4Range(..)
) where
import Prelude hiding (print)
import Net.IPv4 (IPv4(..))
import Data.Bits ((.&.),(.|.),shiftR,complement,shift)
import Control.Monad
import Data.Text (Text)
import Data.Word (Word8,Word32,Word64)
import Data.Hashable (Hashable)
import Data.Aeson (FromJSON(..),ToJSON(..))
import GHC.Generics (Generic)
import Data.Monoid ((<>))
import qualified Net.IPv4 as IPv4
import qualified Data.Bits as Bits
import qualified Data.Text.IO as Text
import qualified Data.Attoparsec.Text as AT
import qualified Data.Text.Lazy.Builder as TBuilder
import qualified Data.Text.Lazy.Builder.Int as TBI
import qualified Data.Vector.Generic as GVector
import qualified Data.Vector.Generic.Mutable as MGVector
import qualified Data.Vector.Unboxed.Mutable as MUVector
import qualified Data.Vector.Unboxed as UVector
import qualified Data.Aeson as Aeson
import qualified Data.Text.Lazy as LText
-- | Smart constructor for 'IPv4Range'. The address and prefix length are
--   wrapped in the raw constructor and then passed through 'normalize'.
range :: IPv4 -> Word8 -> IPv4Range
range addr = normalize . IPv4Range addr
-- | Build a normalized range from two addresses, given in either order.
--   The prefix length is the number of leading bits the two addresses have
--   in common (see 'maskFromBounds'); the numerically smaller address is
--   used as the base before normalization.
fromBounds :: IPv4 -> IPv4 -> IPv4Range
fromBounds (IPv4 a) (IPv4 b) = normalize (IPv4Range (IPv4 lower) prefixLen)
  where
    lower = min a b
    upper = max a b
    prefixLen = maskFromBounds lower upper
-- | Number of leading bits shared by two 32-bit words: the count of leading
--   zeros in their exclusive-or. Identical inputs give 32.
maskFromBounds :: Word32 -> Word32 -> Word8
maskFromBounds lo hi = fromIntegral sharedBits
  where
    differing = lo `Bits.xor` hi
    sharedBits = Bits.countLeadingZeros differing
-- | Test whether an address lies inside a range.
--
--   The netmask and the masked subnet word are computed from the range
--   alone, so partially applying 'contains' to a range lets the resulting
--   predicate reuse them for every address it is given. The range does not
--   need to be normalized: its base is masked before comparison.
contains :: IPv4Range -> IPv4 -> Bool
contains (IPv4Range (IPv4 subnetWord) prefixLen) = inRange
  where
    netmask = mask prefixLen
    network = subnetWord .&. netmask
    inRange (IPv4 candidate) = candidate .&. netmask == network
-- | Netmask for a prefix length: the complement of @0xffffffff@ shifted
--   right by that many bits, i.e. the top @prefixLen@ bits set.
mask :: Word8 -> Word32
mask prefixLen = complement hostBits
  where
    hostBits = 0xffffffff `shiftR` fromIntegral prefixLen
member :: IPv4 -> IPv4Range -> Bool
member = flip contains
-- | Smallest address in the range: the base address with every bit outside
--   the prefix cleared.
lowerInclusive :: IPv4Range -> IPv4
lowerInclusive (IPv4Range (IPv4 w) len) = IPv4 network
  where
    network = w .&. mask len
-- | Largest address in the range: the prefix bits of the base address with
--   every remaining (host) bit set.
upperInclusive :: IPv4Range -> IPv4
upperInclusive (IPv4Range (IPv4 w) len) = IPv4 (network .|. hostBits)
  where
    -- Ones in every position that is NOT part of the prefix.
    hostBits = 0xffffffff `shiftR` fromIntegral len
    network = w .&. complement hostBits
-- | Number of addresses covered by a prefix length: @2 ^ (32 - len)@.
--   Lengths above 32 are treated like 32 and give a count of 1. The result
--   is a 'Word64' because a length of 0 gives @2 ^ 32@.
countAddrs :: Word8 -> Word64
countAddrs prefixLen = 1 `shift` hostBitCount
  where
    hostBitCount
      | prefixLen > 32 = 0
      | otherwise = 32 - fromIntegral prefixLen
-- | @wordSuccessors n start@ lazily lists @n@ consecutive addresses
--   beginning at @start@, each one greater than the previous. Both the
--   remaining count and the current address word are forced at every step.
wordSuccessors :: Word64 -> IPv4 -> [IPv4]
wordSuccessors !remaining (IPv4 !w)
  | remaining > 0 = IPv4 w : wordSuccessors (remaining - 1) (IPv4 (w + 1))
  | otherwise = []
-- | 'MonadPlus' counterpart of 'wordSuccessors': combines @n@ consecutive
--   addresses, beginning at the given one, with 'mplus', ending in 'mzero'.
wordSuccessorsM :: MonadPlus m => Word64 -> IPv4 -> m IPv4
wordSuccessorsM !remaining (IPv4 !w)
  | remaining > 0 =
      return (IPv4 w) `mplus` wordSuccessorsM (remaining - 1) (IPv4 (w + 1))
  | otherwise = mzero
-- | All addresses in the range, in ascending order, as a lazy list.
--
--   Enumeration starts at the masked base address (the same value
--   'lowerInclusive' returns) and produces @2 ^ (32 - len)@ addresses, so
--   every element satisfies 'contains' for the same range. This holds even
--   when the range was built with the raw 'IPv4Range' constructor and its
--   base still has host bits set. For an already normalized range the
--   masking is a no-op.
toList :: IPv4Range -> [IPv4]
toList (IPv4Range (IPv4 w) len) =
  wordSuccessors (countAddrs len) (IPv4 (w .&. mask len))
-- | All addresses in the range, in ascending order, combined in an
--   arbitrary 'MonadPlus'.
--
--   As with 'toList', enumeration starts at the masked base address rather
--   than the raw one, so a range whose base still has host bits set only
--   produces addresses that 'contains' accepts. For an already normalized
--   range the masking is a no-op.
toGenerator :: MonadPlus m => IPv4Range -> m IPv4
toGenerator (IPv4Range (IPv4 w) len) =
  wordSuccessorsM (countAddrs len) (IPv4 (w .&. mask len))
-- | The private block @10.0.0.0\/8@ (RFC 1918 calls it the \"24-bit block\").
private24 :: IPv4Range
private24 = IPv4Range base 8
  where
    base = IPv4.fromOctets 10 0 0 0
-- | The private block @172.16.0.0\/12@ (RFC 1918 calls it the \"20-bit block\").
private20 :: IPv4Range
private20 = IPv4Range base 12
  where
    base = IPv4.fromOctets 172 16 0 0
-- | The private block @192.168.0.0\/16@ (RFC 1918 calls it the \"16-bit block\").
private16 :: IPv4Range
private16 = IPv4Range base 16
  where
    base = IPv4.fromOctets 192 168 0 0
-- | Bring a range into canonical form: the prefix length is clamped to at
--   most 32 and every bit of the base address outside the (clamped) prefix
--   is cleared.
normalize :: IPv4Range -> IPv4Range
normalize (IPv4Range (IPv4 w) len) = IPv4Range (IPv4 network) clamped
  where
    clamped = min len 32
    network = w .&. mask clamped
-- | Render a range as strict 'Text' in @address\/length@ form by running
--   the range builder and converting the lazy result to strict text.
encode :: IPv4Range -> Text
encode = LText.toStrict . TBuilder.toLazyText . rangeToDotDecimalBuilder
-- | Parse a range from 'Text'. The whole input must be consumed by
--   'parser'; any parse failure or trailing input gives 'Nothing'.
decode :: Text -> Maybe IPv4Range
decode txt =
  case AT.parseOnly (parser <* AT.endOfInput) txt of
    Left _ -> Nothing
    Right rng -> Just rng
-- | Lazy text 'TBuilder.Builder' for a range: the address, a @\'/\'@, then
--   the prefix length in decimal.
builder :: IPv4Range -> TBuilder.Builder
builder (IPv4Range addr len) =
  IPv4.builder addr <> TBuilder.singleton '/' <> TBI.decimal len
-- | Attoparsec parser for a range written as @address\/length@. The result
--   is normalized.
--
--   The prefix length is read as an unbounded 'Integer' and only narrowed
--   to 'Word8' after the range check. Reading it directly as a 'Word8'
--   would let the value wrap modulo 256 before the check, so that inputs
--   such as @\/256@ or @\/288@ would be accepted as @\/0@ and @\/32@.
--   Fails when the length is greater than 32.
parser :: AT.Parser IPv4Range
parser = do
  ip <- IPv4.parser
  _ <- AT.char '/'
  len <- (AT.decimal :: AT.Parser Integer) >>= limitSize
  return (normalize (IPv4Range ip len))
  where
    -- Reject out-of-range lengths while the value is still exact.
    limitSize :: Integer -> AT.Parser Word8
    limitSize i
      | i > 32 = fail "An IP range length must be between 0 and 32"
      | otherwise = return (fromIntegral i)
-- | Write the 'encode'd range to standard output, followed by a newline.
--   Shadows the Prelude's @print@, which this module hides.
print :: IPv4Range -> IO ()
print rng = Text.putStrLn (encode rng)
-- | Keep a 'Right' value as 'Just'; discard any 'Left' as 'Nothing'.
rightToMaybe :: Either a b -> Maybe b
rightToMaybe (Right b) = Just b
rightToMaybe (Left _) = Nothing
-- | A base address paired with a prefix length. The constructor performs no
--   validation; 'range' and 'normalize' produce the canonical form.
data IPv4Range = IPv4Range
  { ipv4RangeBase :: {-# UNPACK #-} !IPv4
    -- ^ Base address of the range.
  , ipv4RangeLength :: {-# UNPACK #-} !Word8
    -- ^ Prefix length in bits.
  } deriving (Eq, Ord, Show, Read, Generic)
-- No methods are given, so the class defaults apply ('IPv4Range' derives
-- 'Generic').
instance Hashable IPv4Range
-- | A range serialises to a JSON string holding its 'encode'd text.
instance ToJSON IPv4Range where
  toJSON rng = Aeson.String (encode rng)
-- | A range is read from a JSON string via 'decode'.
--
--   'Aeson.withText' is used so that a non-string JSON value fails with a
--   message naming the expected type rather than an anonymous 'mzero'
--   failure. A string that does not decode still fails with the message
--   below.
instance FromJSON IPv4Range where
  parseJSON = Aeson.withText "IPv4Range" $ \t ->
    maybe (fail "Could not decode IPv4 range") return (decode t)
-- | Mutable unboxed vector of ranges, stored as two parallel mutable
--   vectors: one of base addresses and one of prefix lengths.
data instance MUVector.MVector s IPv4Range
  = MV_IPv4Range
      !(MUVector.MVector s IPv4)
      !(MUVector.MVector s Word8)
-- | Immutable unboxed vector of ranges, stored as two parallel vectors: one
--   of base addresses and one of prefix lengths.
data instance UVector.Vector IPv4Range
  = V_IPv4Range
      !(UVector.Vector IPv4)
      !(UVector.Vector Word8)
-- Declares 'IPv4Range' as unboxable; no methods are given here. The
-- generic mutable and immutable vector instances are defined separately in
-- this module.
instance UVector.Unbox IPv4Range
-- | Generic mutable-vector operations for ranges. Every operation is
--   forwarded to the two underlying vectors (addresses first, then prefix
--   lengths); the length is taken from the address vector.
instance MGVector.MVector MUVector.MVector IPv4Range where
  {-# INLINE basicLength #-}
  basicLength (MV_IPv4Range addrs _) = MGVector.basicLength addrs
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeSlice start count (MV_IPv4Range addrs lens) =
    MV_IPv4Range
      (MGVector.basicUnsafeSlice start count addrs)
      (MGVector.basicUnsafeSlice start count lens)
  {-# INLINE basicOverlaps #-}
  basicOverlaps (MV_IPv4Range addrsA lensA) (MV_IPv4Range addrsB lensB) =
    MGVector.basicOverlaps addrsA addrsB || MGVector.basicOverlaps lensA lensB
  {-# INLINE basicUnsafeNew #-}
  basicUnsafeNew count =
    MV_IPv4Range
      <$> MGVector.basicUnsafeNew count
      <*> MGVector.basicUnsafeNew count
  {-# INLINE basicInitialize #-}
  basicInitialize (MV_IPv4Range addrs lens) =
    MGVector.basicInitialize addrs >> MGVector.basicInitialize lens
  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeReplicate count (IPv4Range addr len) =
    MV_IPv4Range
      <$> MGVector.basicUnsafeReplicate count addr
      <*> MGVector.basicUnsafeReplicate count len
  {-# INLINE basicUnsafeRead #-}
  basicUnsafeRead (MV_IPv4Range addrs lens) ix =
    IPv4Range
      <$> MGVector.basicUnsafeRead addrs ix
      <*> MGVector.basicUnsafeRead lens ix
  {-# INLINE basicUnsafeWrite #-}
  basicUnsafeWrite (MV_IPv4Range addrs lens) ix (IPv4Range addr len) =
    MGVector.basicUnsafeWrite addrs ix addr
      >> MGVector.basicUnsafeWrite lens ix len
  {-# INLINE basicClear #-}
  basicClear (MV_IPv4Range addrs lens) =
    MGVector.basicClear addrs >> MGVector.basicClear lens
  {-# INLINE basicSet #-}
  basicSet (MV_IPv4Range addrs lens) (IPv4Range addr len) =
    MGVector.basicSet addrs addr >> MGVector.basicSet lens len
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy (MV_IPv4Range addrsA lensA) (MV_IPv4Range addrsB lensB) =
    MGVector.basicUnsafeCopy addrsA addrsB
      >> MGVector.basicUnsafeCopy lensA lensB
  {-# INLINE basicUnsafeMove #-}
  basicUnsafeMove (MV_IPv4Range addrsA lensA) (MV_IPv4Range addrsB lensB) =
    MGVector.basicUnsafeMove addrsA addrsB
      >> MGVector.basicUnsafeMove lensA lensB
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow (MV_IPv4Range addrs lens) extra =
    MV_IPv4Range
      <$> MGVector.basicUnsafeGrow addrs extra
      <*> MGVector.basicUnsafeGrow lens extra
-- | Generic immutable-vector operations for ranges. Every operation is
--   forwarded to the two underlying vectors (addresses first, then prefix
--   lengths); the length is taken from the address vector.
instance GVector.Vector UVector.Vector IPv4Range where
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeFreeze (MV_IPv4Range addrs lens) =
    V_IPv4Range
      <$> GVector.basicUnsafeFreeze addrs
      <*> GVector.basicUnsafeFreeze lens
  {-# INLINE basicUnsafeThaw #-}
  basicUnsafeThaw (V_IPv4Range addrs lens) =
    MV_IPv4Range
      <$> GVector.basicUnsafeThaw addrs
      <*> GVector.basicUnsafeThaw lens
  {-# INLINE basicLength #-}
  basicLength (V_IPv4Range addrs _) = GVector.basicLength addrs
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeSlice start count (V_IPv4Range addrs lens) =
    V_IPv4Range
      (GVector.basicUnsafeSlice start count addrs)
      (GVector.basicUnsafeSlice start count lens)
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeIndexM (V_IPv4Range addrs lens) ix =
    IPv4Range
      <$> GVector.basicUnsafeIndexM addrs ix
      <*> GVector.basicUnsafeIndexM lens ix
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy (MV_IPv4Range addrsA lensA) (V_IPv4Range addrsB lensB) =
    GVector.basicUnsafeCopy addrsA addrsB
      >> GVector.basicUnsafeCopy lensA lensB
  {-# INLINE elemseq #-}
  -- The @undefined@ arguments are never inspected by this code; they only
  -- carry the vector type used to select the element's 'GVector.elemseq'.
  elemseq _ (IPv4Range addr len) =
    GVector.elemseq (undefined :: UVector.Vector a) addr
      . GVector.elemseq (undefined :: UVector.Vector b) len
-- | Lift a binary operation on addresses to ranges. Both operands are
--   normalized first, the operation is applied to their base addresses, and
--   the result takes the longer of the two prefix lengths before being
--   normalized again by 'range'.
rangeBitwise :: (IPv4 -> IPv4 -> IPv4) -> IPv4Range -> IPv4Range -> IPv4Range
rangeBitwise fun l r = range ip len
  where
    ip = ipv4RangeBase (normalize l) `fun` ipv4RangeBase (normalize r)
    -- 'max' is total, unlike 'maximum' on a list; 'range' clamps the
    -- result to 32.
    len = max (ipv4RangeLength l) (ipv4RangeLength r)
-- | Apply a function to the base address of a range, keep the prefix
--   length, and re-normalize the result via 'range'.
rangeRebase :: (IPv4 -> IPv4) -> IPv4Range -> IPv4Range
rangeRebase fun (IPv4Range base len) = range (fun base) len
-- | Bitwise operations on ranges, defined through the base address.
--   Binary operations go through 'rangeBitwise'; unary ones go through
--   'rangeRebase', so their results are normalized.
instance Bits.Bits IPv4Range where
  l .&. r = rangeBitwise (.&.) l r
  l .|. r = rangeBitwise (.|.) l r
  xor l r = rangeBitwise Bits.xor l r
  complement r = rangeRebase Bits.complement r
  shift r i = rangeRebase (`Bits.shift` i) r
  rotate r i = rangeRebase (`Bits.rotate` i) r
  -- NOTE(review): both size queries report the range's prefix length (see
  -- the 'Bits.FiniteBits' instance), so the answer depends on the value
  -- rather than on the type alone — confirm this is intended.
  bitSize r = Bits.finiteBitSize r
  bitSizeMaybe r = Just (Bits.finiteBitSize r)
  isSigned r = Bits.isSigned (ipv4RangeBase r)
  testBit r i = Bits.testBit (ipv4RangeBase r) i
  -- NOTE(review): built with the raw constructor (not normalized), with a
  -- prefix length of @i + 1@ — confirm this is intended.
  bit i = IPv4Range (Bits.bit i) (fromIntegral (i + 1))
  popCount r = Bits.popCount (ipv4RangeBase (normalize r))
-- | The reported bit size of a range is its prefix length.
instance Bits.FiniteBits IPv4Range where
  finiteBitSize r = fromIntegral (ipv4RangeLength r)
-- | Strict 'Text' rendering of a range: run the range builder, then
--   convert the resulting lazy text to strict text.
rangeToDotDecimalText :: IPv4Range -> Text
rangeToDotDecimalText rng =
  LText.toStrict (TBuilder.toLazyText (rangeToDotDecimalBuilder rng))
-- | Builder for the textual form of a range: the address as rendered by
--   'IPv4.builder', a slash, then the prefix length in decimal.
rangeToDotDecimalBuilder :: IPv4Range -> TBuilder.Builder
rangeToDotDecimalBuilder (IPv4Range addr len) =
  mconcat
    [ IPv4.builder addr
    , TBuilder.singleton '/'
    , TBI.decimal len
    ]