module Network.DomainAuth.SPF.Eval (
    evalSPF,
    Limit (..),
    defaultLimit,
) where

import Data.IORef
import Data.IP
import Data.Maybe
import Network.DomainAuth.SPF.Types
import Network.DomainAuth.Types

-- | Limits applied while evaluating an SPF record.
data Limit = Limit
    { limit :: Int
    -- ^ How many \"redirect\"/\"include\" terms may be followed.
    --   'DAPermError' is returned once more than this many have been
    --   followed (the counter is compared with @>@, not @>=@).
    , ipv4_masklen :: Int
    -- ^ Minimum accepted IPv4 prefix length. An IPv4 range term whose
    --   mask is shorter than this (i.e. a wider network) is not ignored:
    --   it makes the evaluation return 'DAPermError'.
    , ipv6_masklen :: Int
    -- ^ Minimum accepted IPv6 prefix length. An IPv6 range term whose
    --   mask is shorter than this makes the evaluation return
    --   'DAPermError'.
    , reject_plus_all :: Bool
    -- ^ When 'True', an \"+all\" term yields 'DAPermError' instead of
    --   a pass.
    }
    deriving (Eq, Show)

-- | Default value for 'Limit'.
--
-- Follows at most ten \"include\"/\"redirect\" terms, treats IPv4 ranges
-- wider than /16 and IPv6 ranges wider than /48 as errors, and rejects
-- \"+all\".
--
-- >>> defaultLimit
-- Limit {limit = 10, ipv4_masklen = 16, ipv6_masklen = 48, reject_plus_all = True}
defaultLimit :: Limit
defaultLimit = Limit{limit = 10, ipv4_masklen = 16, ipv6_masklen = 48, reject_plus_all = True}

----------------------------------------------------------------

-- | Evaluate an SPF record for the client address @ip@.
--
-- The record is a list of terms, each an 'IO' action that is run only
-- when that term is reached. Terms are tried in order and the first one
-- that decides the outcome wins. If none decides, the result is
-- 'DANeutral'. A single include/redirect counter, starting at zero, is
-- shared by the whole evaluation, nested records included.
evalSPF :: Limit -> IP -> [IO SpfSeq] -> IO DAResult
evalSPF lim ip ss = do
    ref <- newIORef (0 :: Int)
    -- 'evalspf' always yields a 'Just' (its base case is @Just DANeutral@).
    -- The fallback only replaces the partial 'fromJust' and is never used.
    fromMaybe DANeutral <$> evalspf ref lim ip ss

----------------------------------------------------------------

-- | Walk the terms of one SPF record in order.
--
-- Before each term, the shared include/redirect counter is compared with
-- 'limit'. Once it exceeds the limit, the walk stops with 'DAPermError'.
-- The first term whose evaluation produces a result ends the walk. If
-- every term passes through, the result is 'DANeutral'. The returned
-- value is therefore always a 'Just'.
evalspf :: IORef Int -> Limit -> IP -> [IO SpfSeq] -> IO (Maybe DAResult)
evalspf counter lim addr = go
  where
    go [] = return (Just DANeutral) -- default result
    go (term : rest) = do
        used <- readIORef counter
        if used > limit lim
            then return (Just DAPermError) -- reached the limit
            else eval counter lim addr term >>= maybe (go rest) (return . Just)

----------------------------------------------------------------
-- | Evaluate a single SPF term.
--
-- Returns @Just r@ when the term decides the result, or 'Nothing' when it
-- does not match and evaluation should move on to the next term.
--
-- The two branches that descend into a nested record ('SS_IF_Pass' and
-- 'SS_SpfSeq', i.e. include/redirect) bump the shared counter in @ref@.
-- 'evalspf' checks that counter against 'limit' before each term. The
-- nested record is fetched before the counter is bumped. So the
-- include/redirect that crosses the limit is fetched, but its body is
-- not evaluated.
eval :: IORef Int -> Limit -> IP -> IO SpfSeq -> IO (Maybe DAResult)
eval ref lim ip is = do
    cnt <- readIORef ref
    s <- is
    case s of
        SS_All q
            | q == Q_Pass && reject_plus_all lim -> result DAPermError
            | otherwise -> ret q
        SS_IPv4Range q ipr
            | nastyMask4 lim ipr -> result DAPermError
            | matches4 ipr -> ret q
            | otherwise -> continue
        SS_IPv4Ranges q iprs
            | any (nastyMask4 lim) iprs -> result DAPermError
            | any matches4 iprs -> ret q
            | otherwise -> continue
        SS_IPv6Range q ipr
            | nastyMask6 lim ipr -> result DAPermError
            | matches6 ipr -> ret q
            | otherwise -> continue
        SS_IPv6Ranges q iprs
            | any (nastyMask6 lim) iprs -> result DAPermError
            | any matches6 iprs -> ret q
            | otherwise -> continue
        SS_IF_Pass q ss -> do
            writeIORef ref (cnt + 1)
            r <- evalspf ref lim ip ss
            case r of
                Just DAPass -> ret q
                -- A permanent error inside the nested record (lookup limit
                -- exceeded, too-wide mask, rejected "+all") must abort the
                -- whole evaluation rather than count as a non-match
                -- (RFC 7208, section 5.2). Previously it was swallowed, so
                -- an include in last position produced 'DANeutral'.
                Just DAPermError -> result DAPermError
                _ -> continue
        SS_SpfSeq ss -> do
            writeIORef ref (cnt + 1)
            evalspf ref lim ip ss
  where
    -- Maps a qualifier to the result of the same name.
    -- NOTE(review): relies on 'Qualifier' and 'DAResult' declaring the
    -- corresponding constructors in the same order; confirm in the Types
    -- modules.
    ret = return . Just . toEnum . fromEnum
    result = return . Just
    continue = return Nothing
    nastyMask4 st ipr = mlen ipr < ipv4_masklen st
    nastyMask6 st ipr = mlen ipr < ipv6_masklen st
    -- Family-aware matching. The 'ipv4'/'ipv6' record selectors are
    -- partial and would crash when an IPv4 range is checked against an
    -- IPv6 client (or vice versa). A range of the other family simply
    -- does not match.
    matches4 ipr = case ip of
        IPv4 a -> a `isMatchedTo` ipr
        IPv6 _ -> False
    matches6 ipr = case ip of
        IPv6 a -> a `isMatchedTo` ipr
        IPv4 _ -> False