{-# LANGUAGE OverloadedStrings #-}

module Network.DomainAuth.SPF.Resolver (
    resolveSPF,
) where

import Control.Monad
import qualified Data.ByteString.Char8 as BS
import Data.IP
import Data.Maybe
import Network.DNS
import Network.DomainAuth.SPF.Parser
import Network.DomainAuth.SPF.Types

----------------------------------------------------------------

-- | Fetch and parse the SPF policy published for a domain.
--
-- Looks up the TXT records of @dom@, selects the single SPF record among
-- them and parses it. It then drops the IP-range terms of the address
-- family that does not match @ip@. Every remaining term is turned into a
-- deferred 'SpfSeq' action via 'toSpfSeq'; those actions (and any nested
-- DNS lookups they perform) only run when the caller runs them.
--
-- Failures are signalled with 'fail' in 'IO':
--
-- * @"TempError"@ – the TXT lookup failed with a DNS error other than
--   NXDOMAIN;
-- * @"None"@ – the domain does not exist or publishes no SPF record
--   (RFC 7208, section 4.3 / 4.5);
-- * @"PermError"@ – more than one SPF record is published (RFC 7208,
--   section 4.5) or the record does not parse.
--
-- NOTE(review): include/redirect terms are resolved recursively without a
-- DNS-lookup limit (RFC 7208, section 4.6.4), so a looping policy is not
-- detected here.
resolveSPF :: Resolver -> Domain -> IP -> IO [IO SpfSeq]
resolveSPF resolver dom ip = do
    txts <- lookupTXT resolver dom >>= txtRecords
    case filter isSPFRecord txts of
        [] -> fail "None"
        [rr] -> case parseSPF rr of
            Nothing -> fail "PermError"
            Just spfs ->
                return $ map (toSpfSeq resolver dom ip) (filterSPFWithIP ip spfs)
        -- RFC 7208 4.5: multiple SPF records is a permanent error.
        _ -> fail "PermError"
  where
    -- NXDOMAIN (RCODE 3) means the domain publishes nothing, which ends
    -- up as "None" above (RFC 7208, section 4.3). Any other DNS error is
    -- a "TempError"; checkDNS always fails for a Left, so the [] is never
    -- actually returned on that path.
    txtRecords (Left NameError) = return []
    txtRecords (Right rc) = return rc
    txtRecords jrc = [] <$ checkDNS jrc "TempError"

    -- An SPF record starts with the version tag "v=spf1" followed by a
    -- space or the end of the record, so e.g. "v=spf10" is not taken for
    -- one (RFC 7208, section 4.5).
    isSPFRecord txt =
        let (tag, rest) = BS.splitAt (BS.length spfTag) txt
         in tag == spfTag && maybe True ((== ' ') . fst) (BS.uncons rest)

    spfTag :: BS.ByteString
    spfTag = "v=spf1"

----------------------------------------------------------------

-- | Keep only the SPF terms that can apply to a client of the given
-- address family.
--
-- For an IPv4 client every 'SPF_IPv6Range' term is dropped ('exceptIPv4').
-- For an IPv6 client every 'SPF_IPv4Range' term is dropped ('exceptIPv6').
-- All other terms are kept in their original order.
filterSPFWithIP :: IP -> [SPF] -> [SPF]
filterSPFWithIP (IPv4 _) spfs = [spf | spf <- spfs, exceptIPv4 spf]
filterSPFWithIP (IPv6 _) spfs = [spf | spf <- spfs, exceptIPv6 spf]

-- | 'False' only for an 'SPF_IPv6Range' term; 'True' for every other term.
-- Used to discard IPv6 ranges when checking an IPv4 client.
exceptIPv4 :: SPF -> Bool
exceptIPv4 spf = case spf of
    SPF_IPv6Range {} -> False
    _ -> True

-- | 'False' only for an 'SPF_IPv4Range' term; 'True' for every other term.
-- Used to discard IPv4 ranges when checking an IPv6 client.
exceptIPv6 :: SPF -> Bool
exceptIPv6 spf = case spf of
    SPF_IPv4Range {} -> False
    _ -> True

----------------------------------------------------------------

-- | Turn one parsed SPF term into an action that yields its 'SpfSeq'.
--
-- * 'SPF_IPv4Range', 'SPF_IPv6Range' and 'SPF_All' are returned directly
--   as the corresponding 'SpfSeq' constructor.
-- * 'SPF_Include' wraps the terms of the referenced domain in
--   'SS_IF_Pass'. 'SPF_Redirect' wraps them in 'SS_SpfSeq'. Both obtain
--   those terms from 'resolveSPF'.
-- * 'SPF_MX' and 'SPF_Address' look up the domain carried by the term,
--   falling back to the current domain @dom@ when it carries none. The
--   client address family picks the lookup (A vs. AAAA) and the prefix
--   length (first vs. second component of the pair). The answer is
--   converted by 'doit4' / 'doit6'.
toSpfSeq :: Resolver -> Domain -> IP -> SPF -> IO SpfSeq
toSpfSeq _ _ _ (SPF_IPv4Range q ipr) = return (SS_IPv4Range q ipr)
toSpfSeq _ _ _ (SPF_IPv6Range q ipr) = return (SS_IPv6Range q ipr)
toSpfSeq _ _ _ (SPF_All q) = return (SS_All q)
toSpfSeq r _ ip (SPF_Include q target) = SS_IF_Pass q <$> resolveSPF r target ip
toSpfSeq r _ ip (SPF_Redirect target) = SS_SpfSeq <$> resolveSPF r target ip
toSpfSeq r dom ip (SPF_MX q mtarget (l4, l6)) =
    case ip of
        IPv4 _ -> lookupAviaMX r target >>= doit4 q l4
        IPv6 _ -> lookupAAAAviaMX r target >>= doit6 q l6
  where
    target = fromMaybe dom mtarget
toSpfSeq r dom ip (SPF_Address q mtarget (l4, l6)) =
    case ip of
        IPv4 _ -> lookupA r target >>= doit4 q l4
        IPv6 _ -> lookupAAAA r target >>= doit6 q l6
  where
    target = fromMaybe dom mtarget

-- | Convert the result of an IPv4 address lookup into an 'SpfSeq'.
--
-- Each returned address becomes a range of prefix length @l4@ and is
-- tagged with qualifier @q@ in 'SS_IPv4Ranges'.
--
-- A "Name Error" (NXDOMAIN, RCODE 3) answer is treated as zero answer
-- records, so the mechanism matches nothing (RFC 7208, section 5). Any
-- other DNS error aborts evaluation with @fail "TempError"@.
doit4 :: Qualifier -> Int -> Either DNSError [IPv4] -> IO SpfSeq
doit4 q l4 (Left NameError) = doit4 q l4 (Right [])
doit4 _ _ (Left _) = fail "TempError"
doit4 q l4 (Right is) = return $ SS_IPv4Ranges q $ map (`makeAddrRange` l4) is

-- | Convert the result of an IPv6 address lookup into an 'SpfSeq'.
--
-- Each returned address becomes a range of prefix length @l6@ and is
-- tagged with qualifier @q@ in 'SS_IPv6Ranges'.
--
-- A "Name Error" (NXDOMAIN, RCODE 3) answer is treated as zero answer
-- records, so the mechanism matches nothing (RFC 7208, section 5). Any
-- other DNS error aborts evaluation with @fail "TempError"@.
doit6 :: Qualifier -> Int -> Either DNSError [IPv6] -> IO SpfSeq
doit6 q l6 (Left NameError) = doit6 q l6 (Right [])
doit6 _ _ (Left _) = fail "TempError"
doit6 q l6 (Right is) = return $ SS_IPv6Ranges q $ map (`makeAddrRange` l6) is

----------------------------------------------------------------

-- | Succeed silently when the DNS lookup result is a 'Right'. Otherwise
-- abort with @fail estr@, which in 'IO' throws a user error carrying
-- @estr@.
checkDNS :: Either DNSError a -> String -> IO ()
checkDNS res estr = either (const (fail estr)) (const (return ())) res