{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.DNS.StateBinary (
    PState(..)
  , initialState
  , SPut
  , runSPut
  , put8
  , put16
  , put32
  , putInt8
  , putInt16
  , putInt32
  , putByteString
  , putReplicate
  , SGet
  , failSGet
  , fitSGet
  , runSGet
  , runSGetAt
  , runSGetWithLeftovers
  , runSGetWithLeftoversAt
  , get8
  , get16
  , get32
  , getInt8
  , getInt16
  , getInt32
  , getNByteString
  , sGetMany
  , getPosition
  , getInput
  , getAtTime
  , wsPop
  , wsPush
  , wsPosition
  , addPositionW
  , push
  , pop
  , getNBytes
  , getNoctets
  , skipNBytes
  , parseLabel
  , unparseLabel
  ) where

import qualified Control.Exception as E
import Control.Monad.State.Strict (State, StateT)
import qualified Control.Monad.State.Strict as ST
import qualified Data.Attoparsec.ByteString as A
import qualified Data.Attoparsec.Types as T
import qualified Data.ByteString as BS
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as LB
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.Map (Map)
import qualified Data.Map as M
import Data.Semigroup as Sem

import Network.DNS.Imports
import Network.DNS.Types.Internal

----------------------------------------------------------------

-- | An encoder: a 'State' computation over the writer state 'WState' whose
-- result is the 'Builder' for the bytes it emits.
type SPut = State WState Builder

-- | State threaded through 'SPut' encoders.
data WState = WState
    { wsDomain   :: Map Domain Int
      -- ^ Domain names recorded via 'wsPush', mapped to the position that
      -- was supplied with them (retrieved again by 'wsPop').
    , wsPosition :: Int
      -- ^ Running output position, advanced by 'addPositionW' by the size
      -- of every item written.
    }

-- | Encoder state at the start of a message: no recorded domains and an
-- output position of zero.
initialWState :: WState
initialWState = WState { wsDomain = M.empty, wsPosition = 0 }

-- | Sequence two encoders, threading the 'WState' through the first and
-- then the second, and concatenate the builders they produce.
instance Sem.Semigroup SPut where
    p1 <> p2 = do
        b1 <- p1
        b2 <- p2
        return (b1 Sem.<> b2)

-- | The empty encoder: leaves the state untouched and emits nothing.
instance Monoid SPut where
    mempty = pure mempty
#if !(MIN_VERSION_base(4,11,0))
    mappend = (Sem.<>)
#endif

-- | Emit a single byte.
put8 :: Word8 -> SPut
put8 w = fixedSized 1 BB.word8 w

-- | Emit a 16-bit word in big-endian (network) byte order.
put16 :: Word16 -> SPut
put16 w = fixedSized 2 BB.word16BE w

-- | Emit a 32-bit word in big-endian (network) byte order.
put32 :: Word32 -> SPut
put32 w = fixedSized 4 BB.word32BE w

-- | Emit an 'Int' as one byte; the value is narrowed with 'fromIntegral'.
putInt8 :: Int -> SPut
putInt8 n = fixedSized 1 BB.int8 (fromIntegral n)

-- | Emit an 'Int' as a big-endian 16-bit value, narrowed with 'fromIntegral'.
putInt16 :: Int -> SPut
putInt16 n = fixedSized 2 BB.int16BE (fromIntegral n)

-- | Emit an 'Int' as a big-endian 32-bit value, narrowed with 'fromIntegral'.
putInt32 :: Int -> SPut
putInt32 n = fixedSized 4 BB.int32BE (fromIntegral n)

-- | Emit a strict 'ByteString' verbatim, advancing the output position by
-- its length.
putByteString :: ByteString -> SPut
putByteString bs = fixedSized (BS.length bs) BB.byteString bs

-- | Emit @n@ copies of the byte @w@.
--
-- A negative count is treated as zero.  This emits nothing and leaves the
-- output position unchanged, so 'wsPosition' always equals the number of
-- bytes actually written.  (Passing the raw negative count to 'fixedSized'
-- would move the position backwards while 'LB.replicate' emits no bytes.)
putReplicate :: Int -> Word8 -> SPut
putReplicate n w =
    fixedSized count BB.lazyByteString $ LB.replicate (fromIntegral count) w
  where
    count = max 0 n

-- | Advance the encoder's output position by @n@ bytes.
--
-- The new position is forced before it is stored.  'WState' has lazy fields,
-- so without the bang every write would add another @(+)@ thunk to
-- 'wsPosition' until something reads it.  This mirrors the strict update
-- that 'addPosition' performs on the decoding side.
addPositionW :: Int -> State WState ()
addPositionW n = do
    WState doms cur <- ST.get
    let !cur' = cur + n
    ST.put $ WState doms cur'

-- | Account for an item of known size @n@ bytes in the output position and
-- render it with @f@.
fixedSized :: Int -> (a -> Builder) -> a -> SPut
fixedSized n f a = f a <$ addPositionW n

-- | Account for an item whose size is computed from the item itself, then
-- render it.
writeSized :: (a -> Int) -> (a -> Builder) -> a -> SPut
writeSized sizeOf render x = render x <$ addPositionW (sizeOf x)

-- | Look up the position previously recorded for a domain by 'wsPush', if
-- any.  The state is not modified.
wsPop :: Domain -> State WState (Maybe Int)
wsPop dom = ST.gets (M.lookup dom . wsDomain)

-- | Record @pos@ as the position for domain @dom@, replacing any earlier
-- entry for that domain.  The output position is left unchanged.
wsPush :: Domain -> Int -> State WState ()
wsPush dom pos =
    ST.modify' $ \s -> s { wsDomain = M.insert dom pos (wsDomain s) }

----------------------------------------------------------------

-- | A decoder: an attoparsec 'ByteString' parser that also threads the
-- decoder state 'PState'.
type SGet = StateT PState (T.Parser ByteString)

-- | State threaded through 'SGet' decoders.
data PState = PState
    { psDomain   :: IntMap Domain
      -- ^ Domains recorded by 'push', keyed by the caller-supplied 'Int'
      -- (presumably a message offset; confirm against callers).
    , psPosition :: Int
      -- ^ Number of input bytes consumed so far (see 'addPosition').
    , psInput    :: ByteString
      -- ^ The complete input message being decoded.
    , psAtTime   :: Int64
      -- ^ Reference time for DNS clock arithmetic.
    }

----------------------------------------------------------------

-- | Number of input bytes consumed so far.
getPosition :: SGet Int
getPosition = psPosition <$> ST.get

-- | The complete input message being decoded.
getInput :: SGet ByteString
getInput = psInput <$> ST.get

-- | The reference time the decoder was started with.
getAtTime :: SGet Int64
getAtTime = psAtTime <$> ST.get

-- | Advance the decoder's input position by @n@ bytes.  Fails if @n@ is
-- negative, or if the new position would lie beyond the end of the input.
addPosition :: Int -> SGet ()
addPosition n
    | n < 0     = failSGet "internal error: negative position increment"
    | otherwise = do
        st <- ST.get
        -- Force the new offset so no thunk builds up in the lazy field.
        let !newPos = psPosition st + n
        if newPos > BS.length (psInput st)
            then failSGet "malformed or truncated input"
            else ST.put (st { psPosition = newPos })

-- | Record domain @d@ under key @n@, replacing any earlier entry with the
-- same key.  The input position is left unchanged.
push :: Int -> Domain -> SGet ()
push n d = ST.modify' $ \st -> st { psDomain = IM.insert n d (psDomain st) }

-- | Look up the domain recorded under key @n@ by 'push', if any.
pop :: Int -> SGet (Maybe Domain)
pop n = IM.lookup n <$> ST.gets psDomain

----------------------------------------------------------------

-- | Read one byte.
get8 :: SGet Word8
get8 = do
    b <- ST.lift A.anyWord8
    addPosition 1
    return b

-- | Read a 16-bit word in big-endian (network) byte order.
get16 :: SGet Word16
get16 = ST.lift (combine <$> byte <*> byte) <* addPosition 2
  where
    byte :: A.Parser Word16
    byte = fromIntegral <$> A.anyWord8

    -- The first byte read is the most significant.
    combine hi lo = hi * 0x100 + lo

-- | Read a 32-bit word in big-endian (network) byte order.
get32 :: SGet Word32
get32 = ST.lift (combine <$> byte <*> byte <*> byte <*> byte) <* addPosition 4
  where
    byte :: A.Parser Word32
    byte = fromIntegral <$> A.anyWord8

    -- Horner form of a * 2^24 + b * 2^16 + c * 2^8 + d; the first byte
    -- read is the most significant.
    combine a b c d = ((a * 0x100 + b) * 0x100 + c) * 0x100 + d

-- | Read one byte as a non-negative 'Int'.
getInt8 :: SGet Int
getInt8 = fmap fromIntegral get8

-- | Read a big-endian 16-bit word as a non-negative 'Int'.
getInt16 :: SGet Int
getInt16 = fmap fromIntegral get16

-- | Read a big-endian 32-bit word as an 'Int'.
getInt32 :: SGet Int
getInt32 = fmap fromIntegral get32

----------------------------------------------------------------

-- | Fail with the generic error for malformed or truncated input.
overrun :: SGet a
overrun = failSGet message
  where
    message = "malformed or truncated input"

-- | Read @n@ bytes and return each as an 'Int'.  A negative count is
-- reported as an overrun.
getNBytes :: Int -> SGet [Int]
getNBytes n
    | n < 0     = overrun
    | otherwise = do
        bytes <- getNByteString n
        return [fromIntegral b | b <- BS.unpack bytes]

-- | Read @n@ bytes as a list of octets.  A negative count is reported as an
-- overrun.
getNoctets :: Int -> SGet [Word8]
getNoctets n
    | n < 0     = overrun
    | otherwise = fmap BS.unpack (getNByteString n)

-- | Consume and discard @n@ bytes.  A negative count is reported as an
-- overrun.
skipNBytes :: Int -> SGet ()
skipNBytes n
    | n < 0     = overrun
    | otherwise = do
        _ <- ST.lift (A.take n)
        addPosition n

-- | Read exactly @n@ bytes as a 'ByteString'.  A negative count is reported
-- as an overrun.
getNByteString :: Int -> SGet ByteString
getNByteString n
    | n < 0     = overrun
    | otherwise = do
        bytes <- ST.lift (A.take n)
        addPosition n
        return bytes

-- | Run @parser@ and require that it consumes exactly @len@ bytes.  The
-- result is forced to WHNF on success.  A negative @len@ is reported as an
-- overrun.
fitSGet :: Int -> SGet a -> SGet a
fitSGet len parser
    | len < 0   = overrun
    | otherwise = do
        start  <- getPosition
        result <- parser
        end    <- getPosition
        case compare end (start + len) of
            EQ -> return $! result
            GT -> failSGet "element size exceeds declared size"
            LT -> failSGet "element shorter than declared size"

-- | Parse a list of elements that takes up exactly a given number of bytes.
-- If the last element runs past the declared length, parsing fails.  To
-- avoid infinite loops, an element parser that succeeds without moving the
-- buffer offset forward also causes an error.
--
sGetMany :: String -- ^ element type for error messages
         -> Int    -- ^ input buffer length
         -> SGet a -- ^ element parser
         -> SGet [a]
sGetMany elemname len parser
    | len < 0   = overrun
    | otherwise = loop len []
  where
    -- Elements are accumulated in reverse and flipped once at the end.
    loop remaining acc
        | remaining < 0  = failSGet $ elemname ++ " longer than declared size"
        | remaining == 0 = pure (reverse acc)
        | otherwise      = do
            before <- getPosition
            item   <- parser
            after  <- getPosition
            let consumed = after - before
            if consumed <= 0
                then failSGet $ "internal error: in-place success for " ++ elemname
                else loop (remaining - consumed) (item : acc)

----------------------------------------------------------------

-- | Default reference time used by 'runSGet' and 'runSGetWithLeftovers'.
--
-- To get a broad range of correct RRSIG inception and expiration times
-- without over- or underflow, we choose a time half way between midnight PDT
-- 2010-07-15 (the day the root zone was signed) and 2^32 seconds later on
-- 2146-08-21.  Since 'decode' and 'runSGet' are pure, we can't peek at the
-- current time while parsing.  Outside this date range the output is off by
-- some non-zero multiple of 2\^32 seconds.
--
dnsTimeMid :: Int64
dnsTimeMid = rootZoneSigned + halfRange
  where
    rootZoneSigned = 1279177200      -- 2010-07-15 00:00:00 PDT, in POSIX seconds
    halfRange      = 2 ^ (31 :: Int) -- half of the 2^32-second window

-- | Decoder state for input @inp@ with reference time @t@.  The position
-- starts at zero and no domains are recorded.
initialState :: Int64 -> ByteString -> PState
initialState t inp = PState
    { psDomain   = IM.empty
    , psPosition = 0
    , psInput    = inp
    , psAtTime   = t
    }

-- | Fail the decoder with the given message.
--
-- We construct our own error message, without the unhelpful AttoParsec
-- \"Failed reading: \" prefix.  To do this we fail with an empty message and
-- attach @msg@ with attoparsec's '<?>' label combinator.
failSGet :: String -> SGet a
failSGet = ST.lift . (fail "" A.<?>)

-- | Run a decoder over a complete message, using @t@ as the reference time
-- for DNS clock arithmetic.
--
-- If the parser still wants more input once the message is exhausted, the
-- result is a 'DecodeError' with the message \"incomplete input\".
runSGetAt :: Int64 -> SGet a -> ByteString -> Either DNSError (a, PState)
runSGetAt t parser inp =
    toResult $ A.parse (ST.runStateT parser $ initialState t inp) inp
  where
    toResult :: A.Result r -> Either DNSError r
    toResult (A.Done _ r)       = Right r
    toResult (A.Fail _ ctx msg) = Left $ DecodeError $ firstMessage ctx msg
    toResult (A.Partial _)      = Left $ DecodeError "incomplete input"

    -- Prefer the first context label (where 'failSGet' attaches its
    -- message), otherwise attoparsec's own message.  This is a total
    -- equivalent of @head (ctx ++ [msg])@.
    firstMessage :: [String] -> String -> String
    firstMessage (c : _) _        = c
    firstMessage []      fallback = fallback

-- | Run a decoder over a complete message, using 'dnsTimeMid' as the
-- reference time.
runSGet :: SGet a -> ByteString -> Either DNSError (a, PState)
runSGet parser inp = runSGetAt dnsTimeMid parser inp

-- | Run a decoder over a message and also return any input the parser left
-- unconsumed.  If the parser asks for more input, it is told that the input
-- has ended.
runSGetWithLeftoversAt :: Int64      -- ^ Reference time for DNS clock arithmetic
                       -> SGet a     -- ^ Parser
                       -> ByteString -- ^ Encoded message
                       -> Either DNSError ((a, PState), ByteString)
runSGetWithLeftoversAt t parser inp =
    toResult $ A.parse (ST.runStateT parser $ initialState t inp) inp
  where
    toResult :: A.Result r -> Either DNSError (r, ByteString)
    toResult (A.Done i r)     = Right (r, i)
    toResult (A.Partial f)    = toResult $ f BS.empty
    toResult (A.Fail _ ctx e) = Left $ DecodeError $ firstMessage ctx e

    -- Prefer the first context label (where 'failSGet' attaches its
    -- message), otherwise attoparsec's own message.  This is a total
    -- equivalent of @head (ctx ++ [e])@.
    firstMessage :: [String] -> String -> String
    firstMessage (c : _) _        = c
    firstMessage []      fallback = fallback

-- | Like 'runSGetWithLeftoversAt', using 'dnsTimeMid' as the reference time.
runSGetWithLeftovers :: SGet a -> ByteString -> Either DNSError ((a, PState), ByteString)
runSGetWithLeftovers parser inp = runSGetWithLeftoversAt dnsTimeMid parser inp

-- | Run an encoder from 'initialWState' and return the emitted bytes as a
-- strict 'ByteString'.
runSPut :: SPut -> ByteString
runSPut sput = LBS.toStrict $ BB.toLazyByteString $ ST.evalState sput initialWState

----------------------------------------------------------------

-- | Decode a domain name in A-label form to a leading label and a tail with
-- the remaining labels, unescaping backslashed chars and decimal triples along
-- the way. Any U-label conversion belongs at the layer above this code.
--
-- Returns a 'DecodeError' if an escape sequence is malformed.  It also does
-- so if the leading label is empty while a non-empty remainder follows, for
-- example an input that starts with the separator and continues with more
-- text.
--
parseLabel :: Word8 -> ByteString -> Either DNSError (ByteString, ByteString)
parseLabel sep dom
    | BS.any (== bslash) dom = toResult $ A.parse (labelParser sep mempty) dom
    -- Fast path: no escapes, so split at the first separator and drop it.
    -- 'BS.drop' is total, so a missing separator needs no special case.
    | otherwise              = check $ BS.drop 1 <$> BS.break (== sep) dom
  where
    -- Tell a parser still waiting for input that the input is over.
    toResult (A.Partial c)  = toResult (c mempty)
    toResult (A.Done tl hd) = check (hd, tl)
    toResult _              = bottom

    -- An empty leading label is only acceptable when nothing follows it.
    check r@(hd, tl)
        | not (BS.null hd) || BS.null tl = Right r
        | otherwise                      = bottom

    bottom :: Either DNSError b
    bottom = Left $ DecodeError $ "invalid domain: " ++ S8.unpack dom

-- | Parse one presentation-form label, appending the decoded bytes to @acc@.
-- The label ends at the separator @sep@ (which is consumed) or at the end of
-- input.
--
-- Backslash escapes are decoded along the way.  A backslash followed by a
-- digit must begin a three-digit decimal escape @\\DDD@ with a value no
-- greater than 255, which yields that byte.  Any other escaped byte yields
-- itself.
labelParser :: Word8 -> ByteString -> A.Parser ByteString
labelParser sep acc = do
    run <- A.option mempty unescapedRun
    let acc' = acc `mappend` run
    labelEnd sep acc' <|> (escaped >>= \b -> labelParser sep (BS.snoc acc' b))
  where
    -- The longest non-empty run of bytes that are neither the separator nor
    -- a backslash, returned as a slice of the input.
    unescapedRun :: A.Parser ByteString
    unescapedRun = fst <$> A.match (A.skipMany1 (A.satisfy isUnescaped))

    isUnescaped :: Word8 -> Bool
    isUnescaped w = w /= sep && w /= bslash

    -- A backslash escape, decoded to the byte it denotes.
    escaped :: A.Parser Word8
    escaped = do
        A.skip (== bslash)
        next <- A.eitherP digit A.anyWord8
        case next of
            Left d  -> decimalEscape d
            Right w -> pure w

    -- A single ASCII digit, as its numeric value.  Bytes below '0' wrap
    -- around under 'Word8' subtraction and so also fail the @<= 9@ test.
    digit :: A.Parser Word
    digit = fromIntegral <$> A.satisfyWith (\w -> w - zero) (<= 9)

    -- The last two digits of a @\\DDD@ escape; values above 255 fail.
    decimalEscape :: Word -> A.Parser Word8
    decimalEscape d1 = do
        d2 <- digit
        d3 <- digit
        let n = 100 * d1 + 10 * d2 + d3
        if n > 255 then mzero else pure (fromIntegral n)

-- | Succeed with @acc@ at the end of a label, either by consuming the
-- separator byte @sep@ or at the end of input.
labelEnd :: Word8 -> ByteString -> A.Parser ByteString
labelEnd sep acc = (acc <$ A.satisfy (== sep)) <|> (acc <$ A.endOfInput)

----------------------------------------------------------------

-- | Convert a wire-form label to presentation form by escaping the
-- separator, the special characters and non-printing characters.  A label
-- with no bytes that require escaping is returned as-is, with no copying or
-- re-construction.
--
-- Note: the separator is required to be either \'.\' or \'\@\', but this
-- constraint is the caller's responsibility and is not checked here.
--
unparseLabel :: Word8 -> ByteString -> ByteString
unparseLabel sep label
    | BS.all (isPlain sep) label = label
    | otherwise                  = finish $ A.parse (labelUnparser sep mempty) label
  where
    -- Tell a parser still waiting for input that the input is over, then
    -- take its result.
    finish (A.Partial k) = finish (k mempty)
    finish (A.Done _ r)  = r
    finish _             = E.throw UnknownDNSError -- can't happen

-- | Escape the rest of the input and append the result to @acc@.  Runs of
-- plain bytes are copied as single chunks.  Each other byte is escaped
-- individually, until the end of input.
labelUnparser :: Word8 -> ByteString -> A.Parser ByteString
labelUnparser sep acc = do
    run <- A.option mempty plainRun
    let acc' = acc `mappend` run
    (acc' <$ A.endOfInput)
        <|> (escapeByte >>= \e -> labelUnparser sep (acc' `mappend` e))
  where
    -- A maximal non-empty run of bytes needing no escape, returned as a
    -- slice of the input.
    plainRun :: A.Parser ByteString
    plainRun = fst <$> A.match (A.skipMany1 (A.satisfy (isPlain sep)))

    -- Non-printables (space and below, <DEL> and above) become decimal
    -- trigraphs; printable specials just get a backslash prefix.
    escapeByte :: A.Parser ByteString
    escapeByte = do
        w <- A.anyWord8
        pure $ if w <= 32 || w >= 127
            then decimal w
            else BS.pack [bslash, w]

    decimal :: Word8 -> ByteString
    decimal w =
        let (hundreds, rest) = w `divMod` 100
            (tens, units)    = rest `divMod` 10
        in BS.pack [bslash, zero + hundreds, zero + tens, zero + units]

-- | In the presentation form of DNS labels, these characters are escaped by
-- prepending a backslash, because they have special meaning in zone files.
-- Whitespace and other non-printable or non-ASCII characters are encoded as
-- @\\DDD@ decimal escapes instead.  The separator character is also quoted
-- in each label.  Note that \'\@\' is quoted even when it is not the
-- separator.
escSpecials :: ByteString
escSpecials = "\"$();@\\"

-- | Is the given byte the separator or one of the 'escSpecials'?
isSpecial :: Word8 -> Word8 -> Bool
isSpecial sep w = w == sep || w `BS.elem` escSpecials

-- | Is the given byte a plain byte that requires no escaping?  The tests are
-- ordered to succeed or fail quickly in the most common cases.  The ranges
-- assume the numeric values of 'zero', 'semi', 'atsign' and 'bslash'.
--
-- Note: the separator is assumed to be either \'.\' or \'\@\', so it is not
-- matched by any of the fast-path 'True' ranges.
isPlain :: Word8 -> Word8 -> Bool
isPlain sep w
    | w >= 127                       = False -- <DEL> and non-ASCII
    | w > bslash                     = True  -- ']' .. '_' .. 'a' .. 'z' .. '~'
    | digitOrColon || upperOrBracket = True
    | w <= 32                        = False -- space and non-printables
    | otherwise                      = not (isSpecial sep w)
  where
    digitOrColon   = w >= zero && w < semi    -- '0' .. '9', ':'
    upperOrBracket = w > atsign && w < bslash -- 'A' .. 'Z', '['

-- | Byte values of some ASCII characters used by the label parser and
-- unparser.
zero, semi, atsign, bslash :: Word8
zero   = 0x30 -- '0'  (48)
semi   = 0x3b -- ';'  (59)
atsign = 0x40 -- '@'  (64)
bslash = 0x5c -- '\\' (92)