module Main where

import Control.Concurrent.ParallelIO.Global ( stopGlobalPool )
import Control.Monad (when)
import qualified Data.ByteString.Char8 as BS (intercalate, pack, unpack)
import Data.List ((\\), intercalate)
import Data.Maybe (catMaybes, isNothing)
import Data.String.Utils (splitWs)
import System.Exit (ExitCode(..), exitWith)
import System.IO (stderr, hPutStrLn)
import Text.Read (readMaybe)

import Cidr (
  Cidr(..),
  combine_all,
  enumerate,
  max_octet1,
  max_octet2,
  max_octet3,
  max_octet4,
  min_octet1,
  min_octet2,
  min_octet3,
  min_octet4 )
import CommandLine (Args(..), get_args)
import DNS (Domain, PTRResult, lookup_ptrs)
import ExitCodes ( exit_invalid_cidr )
import Octet ()


-- | A regular expression that matches a single non-address character,
--   i.e. anything other than a period or a decimal digit.
non_addr_char :: String
non_addr_char = "[^\\.0-9]"


-- | Add a 'non_addr_char' on either side of the given String. This
--   prevents (for example) the regex '127.0.0.1' from matching
--   '127.0.0.100'.
--
--   Note that the result requires an actual character on each side of
--   the address, so it will not match an address that sits at the very
--   start or very end of the searched text/line.
add_barriers :: String -> String
add_barriers x = non_addr_char ++ x ++ non_addr_char


-- | The magic happens here. We take a CIDR as an argument, and return
--   the equivalent regular expression. We do this as follows:
--
--   1. Compute the minimum possible value of each octet.
--   2. Compute the maximum possible value of each octet.
--   3. Generate a regex matching every value between those min and
--      max values.
--   4. Join the regexes from step 3 with regexes matching periods.
--   5. Stick an address boundary on either side of the result if
--      use_barriers is True.
--
cidr_to_regex :: Bool -> Cidr.Cidr -> String
cidr_to_regex use_barriers cidr =
  wrap (intercalate "\\." (zipWith numeric_range mins maxes))
  where
    wrap = if use_barriers then add_barriers else id

    -- Steps 1 and 2: the per-octet bounds, in octet order 1..4.
    mins = [ fromEnum (min_octet1 cidr),
             fromEnum (min_octet2 cidr),
             fromEnum (min_octet3 cidr),
             fromEnum (min_octet4 cidr) ]

    maxes = [ fromEnum (max_octet1 cidr),
              fromEnum (max_octet2 cidr),
              fromEnum (max_octet3 cidr),
              fromEnum (max_octet4 cidr) ]


-- | Take a list of Strings, and return a regular expression matching
--   any of them, e.g. @alternate ["a","b"] == "(a|b)"@. An empty list
--   yields @"()"@.
alternate :: [String] -> String
alternate terms = "(" ++ (intercalate "|" terms) ++ ")"


-- | Take two Ints as parameters (in either order), and return a regex
--   matching any integer between them (inclusive).
--
--   IMPORTANT: we match from max to min so that if e.g. the last
--   octet is '255', we want '255' to match before '2' in the regex
--   (255|254|...|3|2|1) which does not happen if we use
--   (1|2|3|...|254|255).
--
numeric_range :: Int -> Int -> String
numeric_range x y =
  -- Descending enumeration; equals reverse [lower..upper], and is
  -- just [upper] when both bounds coincide.
  alternate (map show [upper, upper - 1 .. lower])
  where
    lower = min x y
    upper = max x y


-- | Read whitespace-separated CIDRs from stdin, reject the whole run if
--   any of them fails to parse, and otherwise act according to the
--   selected mode in the command-line 'Args'.
main :: IO ()
main = do
  args <- get_args

  -- This reads stdin.
  input <- getContents

  let cidr_strings = splitWs input
  let cidrs = map readMaybe cidr_strings

  when (any isNothing cidrs) $ do
    hPutStrLn stderr "ERROR: not valid CIDR notation:"

    -- Output the bad lines, safely: keep only the strings whose parse
    -- failed, in input order.
    let bad_strings = [ s | (s, Nothing) <- zip cidr_strings cidrs ]
    mapM_ (\s -> hPutStrLn stderr (" * " ++ s)) bad_strings

    exitWith (ExitFailure exit_invalid_cidr)

  -- Filter out only the valid ones.
  let valid_cidrs = catMaybes cidrs

  -- Every mode works from the combined list; compute it once instead of
  -- once (or twice) per branch. Laziness means it is only evaluated
  -- when the chosen branch actually needs it.
  let combined_cidrs = combine_all valid_cidrs

  case args of
    Regexed{} -> do
      -- NOTE(review): with no input CIDRs this prints "()", which
      -- matches the empty string.
      let regexes = map (cidr_to_regex (barriers args)) combined_cidrs
      putStrLn $ alternate regexes

    Reduced{} ->
      mapM_ print combined_cidrs

    Duped{} ->
      -- Inputs that do not survive combination unchanged.
      mapM_ print (valid_cidrs \\ combined_cidrs)

    Diffed{} -> do
      -- Removed entries first ("-"), then added entries ("+").
      let deletions = valid_cidrs \\ combined_cidrs
      let additions = combined_cidrs \\ valid_cidrs
      mapM_ (putStrLn . ('-' :) . show) deletions
      mapM_ (putStrLn . ('+' :) . show) additions

    Listed{} ->
      mapM_ print (concatMap enumerate combined_cidrs)

    Reversed{} -> do
      let addrs = concatMap enumerate combined_cidrs
      let addr_bytestrings = map (BS.pack . show) addrs
      ptrs <- lookup_ptrs addr_bytestrings
      mapM_ (putStrLn . show_pair) (zip addr_bytestrings ptrs)
      -- Shut down the parallel-io global pool once all output is done.
      stopGlobalPool
  where
    -- | Render one address together with its PTR lookup result, either
    --   the space-separated names or the error.
    show_pair :: (Domain, PTRResult) -> String
    show_pair (s, eds) = (BS.unpack s) ++ ": " ++ results
      where
        space = BS.pack " "
        results =
          case eds of
            Left err -> "ERROR (" ++ (show err) ++ ")"
            Right ds -> BS.unpack $ BS.intercalate space ds