{-# LANGUAGE CPP, ForeignFunctionInterface #-}

module Network.Wai.Handler.Warp.SendFile (
    sendFile
  , readSendFile
  , packHeader -- for testing
#ifndef WINDOWS
  , positionRead
#endif
  ) where

import qualified Data.ByteString as BS
import Network.Socket (Socket)
import Network.Socket.BufferPool

#ifdef WINDOWS
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Ptr (plusPtr)
import qualified System.IO as IO
#else
import qualified UnliftIO
import Foreign.C.Error (throwErrno)
import Foreign.C.Types
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Network.Sendfile
import Network.Wai.Handler.Warp.FdCache (openFile, closeFile)
import System.Posix.Types
#endif

import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | Function to send a file based on sendfile() for Linux\/Mac\/FreeBSD.
--   This makes use of the file descriptor cache.
--   For other OSes, this is identical to 'readSendFile'.
--
--   In the sendfile() build the buffer, its size and the send callback are
--   ignored: the header chunks and the requested byte range (offset and
--   length) are handed to the sendfile-based primitives together. If the
--   'FileId' carries a file descriptor it is used directly; otherwise the
--   file is addressed by its path.
--
-- Since: 3.1.0
sendFile :: Socket -> Buffer -> BufSize -> (ByteString -> IO ()) -> SendFile
#ifdef SENDFILEFD
sendFile s _ _ _ fid off len act hdr = case mfid of
    -- settingsFdCacheDuration is 0
    Nothing -> sendfileWithHeader   s path (PartOfFile off len) act hdr
    Just fd -> sendfileFdWithHeader s fd   (PartOfFile off len) act hdr
  where
    mfid = fileIdFd fid
    path = fileIdPath fid
#else
sendFile _ = readSendFile
#endif

----------------------------------------------------------------

-- | Copy the header chunks into the buffer, starting at byte offset @n@.
--
--   Whenever a chunk does not fit strictly inside the remaining room, the
--   part that fits is copied, the whole buffer (@siz@ bytes) is flushed
--   through the send callback, the hook action is run, and packing resumes
--   at offset 0 with the rest of that chunk.
--
--   The result is the number of bytes that are sitting in the buffer and
--   have not been flushed yet.
packHeader :: Buffer -> BufSize -> (ByteString -> IO ())
           -> IO () -> [ByteString]
           -> Int
           -> IO Int
packHeader _   _   _    _    [] n = return n
packHeader buf siz send hook (bs:bss) n
  | len < room = do
      -- The chunk fits with room to spare: append it and keep going.
      let dst = buf `plusPtr` n
      void $ copy dst bs
      packHeader buf siz send hook bss (n + len)
  | otherwise  = do
      -- The chunk fills (or overflows) the buffer: copy what fits, flush
      -- the full buffer, then continue with the remainder from offset 0.
      let dst = buf `plusPtr` n
          (bs1, bs2) = BS.splitAt room bs
      void $ copy dst bs1
      bufferIO buf siz send
      hook
      packHeader buf siz send hook (bs2:bss) 0
  where
    len = BS.length bs
    room = siz - n

-- | Minimum of an 'Int' and an 'Integer', returned as an 'Int'.
--   The 'Integer' is only narrowed when it is not greater than the 'Int'.
mini :: Int -> Integer -> Int
mini i n
  | fromIntegral i < n = i
  | otherwise          = fromIntegral n

-- | Function to send a file based on pread()\/send() for Unix.
--   This makes use of the file descriptor cache.
--   For Windows, this is emulated by 'Handle'.
--
-- Since: 3.1.0
#ifdef WINDOWS
readSendFile :: Buffer -> BufSize -> (ByteString -> IO ()) -> SendFile
readSendFile buf siz send fid off0 len0 hook headers = do
    -- Pack the headers first; hdrLen bytes of them are still pending in buf.
    hdrLen <- packHeader buf siz send hook headers 0
    IO.withBinaryFile (fileIdPath fid) IO.ReadMode $ \hdl -> do
        IO.hSeek hdl IO.AbsoluteSeek off0
        -- The first file chunk is read right behind the pending header bytes
        -- so that both go out in a single send.
        got <- IO.hGetBufSome hdl (buf `plusPtr` hdrLen) (mini (siz - hdrLen) len0)
        bufferIO buf (hdrLen + got) send
        hook
        fptr <- newForeignPtr_ buf
        pump hdl fptr (len0 - fromIntegral got)
  where
    -- Keep reading buffer-sized chunks until the requested length has been
    -- sent or a read returns no bytes.
    pump hdl fptr remaining
      | remaining > 0 = do
        got <- IO.hGetBufSome hdl buf (mini siz remaining)
        when (got /= 0) $ do
            send (PS fptr 0 got)
            hook
            pump hdl fptr (remaining - fromIntegral got)
      | otherwise = return ()
#else
-- POSIX implementation: positional reads into the buffer followed by sends.
-- If the 'FileId' carries a file descriptor it is used as is and left open;
-- otherwise the file is opened by path and closed again when done, even if
-- an exception is raised.
readSendFile :: Buffer -> BufSize -> (ByteString -> IO ()) -> SendFile
readSendFile buf siz send fid off0 len0 hook headers =
  UnliftIO.bracket setup teardown $ \fd -> do
    -- Pack the headers first; hn bytes of them are still pending in buf.
    hn <- packHeader buf siz send hook headers 0
    let room = siz - hn
        buf' = buf `plusPtr` hn
    -- The first file chunk is read right behind the pending header bytes so
    -- that both go out in a single send.
    n <- positionRead fd buf' (mini room len0) off0
    bufferIO buf (hn + n) send
    hook
    let n' = fromIntegral n
    loop fd (len0 - n') (off0 + n')
  where
    path :: FilePath
    path = fileIdPath fid

    setup :: IO Fd
    setup = case fileIdFd fid of
       Just fd -> return fd
       Nothing -> openFile path

    -- Only close descriptors that 'setup' opened itself.
    teardown :: Fd -> IO ()
    teardown fd = case fileIdFd fid of
       Just _  -> return ()
       Nothing -> closeFile fd

    loop :: Fd -> Integer -> Integer -> IO ()
    loop fd len off
      | len <= 0  = return ()
      | otherwise = do
          n <- positionRead fd buf (mini siz len) off
          -- A read of 0 bytes means end of file. Stop here, like the
          -- Windows implementation does: without this check a file that is
          -- shorter than the requested range would make this loop spin
          -- forever, since neither 'len' nor 'off' would change.
          when (n > 0) $ do
              bufferIO buf n send
              let n' = fromIntegral n
              hook
              loop fd (len - n') (off + n')

-- | Read up to @siz@ bytes from the file descriptor at absolute offset
--   @off@ into the buffer, using @pread@.
--
--   Returns the number of bytes read. If @pread@ reports a negative
--   result, an 'IOError' built from @errno@ is thrown via 'throwErrno'.
positionRead :: Fd -> Buffer -> BufSize -> Integer -> IO Int
positionRead fd buf siz off = do
    bytes <- fromIntegral <$> c_pread fd (castPtr buf) (fromIntegral siz) (fromIntegral off)
    when (bytes < 0) $ throwErrno "positionRead"
    return bytes

foreign import ccall unsafe "pread"
  c_pread :: Fd -> Ptr CChar -> ByteCount -> FileOffset -> IO CSsize
#endif