{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Conduit where

import qualified Data.ByteString as S
import qualified Data.IORef as I
import Data.Word8 (_0, _9, _A, _F, _a, _cr, _f, _lf)
import UnliftIO (assert, throwIO)

import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | A @Source@ that may only be read up to a fixed number of bytes.
data ISource
    = ISource
        !Source
        -- ^ underlying source the bytes come from
        !(I.IORef Int)
        -- ^ number of bytes still to be read in

-- | Pair a @Source@ with a mutable counter holding how many bytes may still
-- be read from it through 'readISource'.
mkISource :: Source -> Int -> IO ISource
mkISource src cnt =
    I.newIORef cnt >>= \ref -> return $! ISource src ref

-- | Read the next piece of an 'ISource', never handing out more than the
-- byte budget still recorded in its counter.
--
-- Returns 'S.empty' once the budget has reached zero. When a chunk read from
-- the underlying @Source@ is larger than the remaining budget, only the
-- budget's worth is returned and the surplus is pushed back with
-- 'leftoverSource', so it stays available to later readers of that @Source@.
-- If the underlying @Source@ yields an empty chunk while bytes are still
-- owed, 'ConnectionClosedByPeer' is thrown.
readISource :: ISource -> IO ByteString
readISource (ISource src ref) = I.readIORef ref >>= step
  where
    step :: Int -> IO ByteString
    step 0 = return S.empty
    step remaining = do
        chunk <- readSource src

        -- An empty chunk means the stream ended before the promised number
        -- of bytes arrived.
        when (S.null chunk) $ throwIO ConnectionClosedByPeer

        let taken = min remaining (S.length chunk) -- bytes handed downstream
            left = remaining - taken -- budget after this call
        I.writeIORef ref left

        if left <= 0
            then do
                -- The budget ends inside this chunk: return our share and
                -- give the rest back to the underlying source.
                let (mine, surplus) = S.splitAt taken chunk
                leftoverSource src surplus
                assert (left == 0) (return mine)
            else
                -- The whole chunk fits within the budget.
                return chunk

----------------------------------------------------------------

-- | A @Source@ carrying a chunked-transfer-encoded body, together with the
-- decoder state that 'readCSource' updates between calls.
data CSource
    = CSource
        !Source
        -- ^ raw input the chunks are read from
        !(I.IORef ChunkState)
        -- ^ current decoder position

-- | Where the chunked-body decoder ('readCSource') currently is.
data ChunkState
    = -- | A chunk-size line is expected next (the initial state).
      NeedLen
    | -- | A chunk's data has been fully delivered; its line terminator must
      -- be consumed before the next chunk-size line is read.
      NeedLenNewline
    | -- | Inside a chunk with this many data bytes still unread.
      HaveLen Word
    | -- | The body is finished; further reads return an empty string.
      DoneChunking
    deriving (Show)

-- | Start decoding a chunked body from the given @Source@; the decoder
-- begins in the 'NeedLen' state.
mkCSource :: Source -> IO CSource
mkCSource src =
    I.newIORef NeedLen >>= \ref -> return $! CSource src ref

-- | Read the next piece of decoded body data from a chunked-transfer-encoded
-- stream.
--
-- The decoder position is kept in the 'ChunkState' 'I.IORef', so successive
-- calls resume where the previous one stopped. Bytes read past the current
-- chunk are returned to the underlying @Source@ with 'leftoverSource'.
-- 'S.empty' is returned once the zero-length terminating chunk has been seen,
-- or once the stream ends; the state is then 'DoneChunking'.
--
-- Chunk sizes too large for a 'Word' saturate at 'maxBound' instead of
-- wrapping around. The decoder then treats the rest of the stream as body,
-- rather than reinterpreting body bytes as chunk framing.
readCSource :: CSource -> IO ByteString
readCSource (CSource src ref) = do
    mlen <- I.readIORef ref
    go mlen
  where
    -- Serve data for a chunk with @len@ bytes still unread, given the next
    -- piece of raw input @bs@.
    withLen :: Word -> ByteString -> IO ByteString
    withLen 0 bs = do
        -- Zero-length chunk: this is the end of the body.
        leftoverSource src bs
        dropCRLF
        yield' S.empty DoneChunking
    withLen len bs
        | S.null bs = do
            -- FIXME should this throw an exception if len > 0?
            I.writeIORef ref DoneChunking
            return S.empty
        | otherwise =
            -- Compare in 'Word'. The input length is a non-negative 'Int' and
            -- always fits, whereas converting a large @len@ to 'Int' could
            -- produce a negative number and wrongly pick the GT branch.
            case fromIntegral (S.length bs) `compare` len of
                EQ -> yield' bs NeedLenNewline
                LT -> yield' bs $ HaveLen $ len - fromIntegral (S.length bs)
                GT -> do
                    -- Here len < S.length bs, so len fits in an Int.
                    let (x, y) = S.splitAt (fromIntegral len) bs
                    leftoverSource src y
                    yield' x NeedLenNewline

    -- Record the new decoder state, then return the given value.
    yield' :: b -> ChunkState -> IO b
    yield' bs mlen = do
        I.writeIORef ref mlen
        return bs

    -- Drop a line terminator at the front of the stream (CRLF, a bare LF or
    -- a bare CR); any other input is pushed back untouched.
    dropCRLF :: IO ()
    dropCRLF = do
        bs <- readSource src
        case S.uncons bs of
            Nothing -> return ()
            Just (w8, bs')
                | w8 == _cr -> dropLF bs'
                | w8 == _lf -> leftoverSource src bs'
                | otherwise -> leftoverSource src bs

    -- After a CR, drop a following LF if there is one. When the current
    -- piece is exhausted, the next piece is fetched with readSource'.
    dropLF :: ByteString -> IO ()
    dropLF bs =
        case S.uncons bs of
            Nothing -> do
                bs2 <- readSource' src
                unless (S.null bs2) $ dropLF bs2
            Just (w8, bs') ->
                leftoverSource src $ if w8 == _lf then bs' else bs

    -- Dispatch on the stored decoder state.
    go :: ChunkState -> IO ByteString
    go NeedLen = getLen
    go NeedLenNewline = dropCRLF >> getLen
    go (HaveLen 0) = do
        -- Drop the final CRLF
        dropCRLF
        I.writeIORef ref DoneChunking
        return S.empty
    go (HaveLen len) = do
        bs <- readSource src
        withLen len bs
    go DoneChunking = return S.empty

    -- Get the length from the source, and then pass off control to withLen.
    -- If the size line has no LF, one more piece is read (readSource') and
    -- appended before searching again.
    getLen :: IO ByteString
    getLen = do
        bs <- readSource src
        if S.null bs
            then do
                I.writeIORef ref $ assert False $ HaveLen 0
                return S.empty
            else do
                (x, y) <-
                    case S.break (== _lf) bs of
                        (x, y)
                            | S.null y -> do
                                bs2 <- readSource' src
                                return $
                                    if S.null bs2
                                        then (x, y)
                                        else S.break (== _lf) $ bs `S.append` bs2
                            | otherwise -> return (x, y)
                -- Only the leading hex digits count; anything after them
                -- (e.g. a trailing CR) is ignored.
                let w = S.foldl' accumHex 0 $ S.takeWhile isHexDigit x

                let y' = S.drop 1 y
                y'' <-
                    if S.null y'
                        then readSource src
                        else return y'
                withLen w y''

    -- Append one hex digit to the chunk size. If the result would exceed
    -- 'maxBound', return 'maxBound' instead of silently wrapping around.
    accumHex :: Word -> Word8 -> Word
    accumHex acc c
        | acc > (maxBound - d) `div` 16 = maxBound
        | otherwise = acc * 16 + d
      where
        d = fromIntegral (hexToWord c)

    -- Value of a single hex digit (assumes isHexDigit already holds).
    hexToWord :: Word8 -> Word8
    hexToWord w
        | w <= _9 = w - _0
        | w <= _F = w - 55
        | otherwise = w - 87

-- | Is the byte an ASCII hexadecimal digit (@0-9@, @A-F@ or @a-f@)?
isHexDigit :: Word8 -> Bool
isHexDigit w = within _0 _9 || within _A _F || within _a _f
  where
    -- Inclusive range check against the byte under test.
    within lo hi = lo <= w && w <= hi