{-# LANGUAGE CPP, ForeignFunctionInterface #-}

module Network.Wai.Handler.Warp.SendFile (
    sendFile
  , readSendFile
  , packHeader -- for testing
#ifndef WINDOWS
  , positionRead
#endif
  ) where

import qualified Data.ByteString as BS
import Network.Socket (Socket)

#ifdef WINDOWS
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Ptr (plusPtr)
import qualified System.IO as IO
#else
import qualified UnliftIO
import Foreign.C.Error (throwErrno)
import Foreign.C.Types
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Network.Sendfile
import Network.Wai.Handler.Warp.FdCache (openFile, closeFile)
import System.Posix.Types
#endif

import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | Function to send a file based on sendfile() for Linux\/Mac\/FreeBSD.
--   This makes use of the file descriptor cache.
--   For other OSes, this is identical to 'readSendFile'.
--
--   When built with @SENDFILEFD@ the buffer, buffer size and send-function
--   arguments are ignored: the file range is handed straight to
--   'sendfileFdWithHeader' (when the 'FileId' carries a descriptor) or to
--   'sendfileWithHeader' (when it only carries a path).
--
-- Since: 3.1.0
sendFile :: Socket -> Buffer -> BufSize -> (ByteString -> IO ()) -> SendFile
#ifdef SENDFILEFD
sendFile s _ _ _ fid off len act hdr = case mfid of
    -- settingsFdCacheDuration is 0
    Nothing -> sendfileWithHeader   s path range act hdr
    Just fd -> sendfileFdWithHeader s fd   range act hdr
  where
    mfid  = fileIdFd fid
    path  = fileIdPath fid
    -- The same byte range is used whichever way the file is addressed.
    range = PartOfFile off len
#else
sendFile _ = readSendFile
#endif

----------------------------------------------------------------

-- | Copy the header chunks into the buffer, starting at offset @n@.
--
--   A chunk that is strictly shorter than the remaining room is copied
--   whole and the offset advances by its length.  Otherwise only the
--   prefix that fills the buffer is copied, the full buffer (@siz@ bytes)
--   is passed to 'bufferIO' with the send function, the hook is run, and
--   packing continues at offset 0 with the rest of that chunk.
--
--   Returns the offset reached once all chunks are consumed, i.e. the
--   number of bytes in the buffer that have not yet been passed on.
packHeader :: Buffer -> BufSize -> (ByteString -> IO ())
           -> IO () -> [ByteString]
           -> Int
           -> IO Int
packHeader _   _   _    _    [] n = return n
packHeader buf siz send hook (bs:bss) n
  | len < room = do
      void $ copy dst bs
      packHeader buf siz send hook bss (n + len)
  | otherwise  = do
      -- The chunk does not fit strictly: fill the buffer, flush it, and
      -- carry the remainder over to the next round.
      let (bs1, bs2) = BS.splitAt room bs
      void $ copy dst bs1
      bufferIO buf siz send
      hook
      packHeader buf siz send hook (bs2:bss) 0
  where
    len  = BS.length bs
    room = siz - n
    dst  = buf `plusPtr` n

-- | Minimum of an 'Int' and an 'Integer', returned as an 'Int'.
--
--   The comparison is done in 'Integer', so the 'Integer' argument is only
--   narrowed to 'Int' when it is not larger than the 'Int' argument.
mini :: Int -> Integer -> Int
mini i n
  | fromIntegral i < n = i
  | otherwise          = fromIntegral n

-- | Function to send a file based on pread()\/send() for Unix.
--   This makes use of the file descriptor cache.
--   For Windows, this is emulated by 'Handle'.
--
-- Since: 3.1.0
#ifdef WINDOWS
readSendFile :: Buffer -> BufSize -> (ByteString -> IO ()) -> SendFile
-- Windows variant: the file is read through a binary 'IO.Handle'.
-- The headers are packed first; the first read lands directly behind them
-- so both go out together, and later reads reuse the whole buffer.
readSendFile buf siz send fid off0 len0 hook headers = do
    hdrLen <- packHeader buf siz send hook headers 0
    IO.withBinaryFile (fileIdPath fid) IO.ReadMode $ \hdl -> do
        IO.hSeek hdl IO.AbsoluteSeek off0
        got <- IO.hGetBufSome hdl (buf `plusPtr` hdrLen) (mini (siz - hdrLen) len0)
        bufferIO buf (hdrLen + got) send
        hook
        fptr <- newForeignPtr_ buf
        drain hdl fptr (len0 - fromIntegral got)
  where
    -- Send the rest of the requested range one buffer at a time.
    drain hdl fptr remaining
      | remaining <= 0 = return ()
      | otherwise = do
          got <- IO.hGetBufSome hdl buf (mini siz remaining)
          -- A zero-byte read ends the transfer.
          if got == 0
            then return ()
            else do
              send (PS fptr 0 got)
              hook
              drain hdl fptr (remaining - fromIntegral got)
#else
readSendFile :: Buffer -> BufSize -> (ByteString -> IO ()) -> SendFile
-- POSIX variant: the file is read with 'positionRead' at explicit offsets.
-- The headers are packed first; the first read lands directly behind them
-- so both go out together, and later reads reuse the whole buffer.
--
-- If the 'FileId' carries a descriptor it is used as is and left open;
-- otherwise the file is opened by path and closed again by the bracket,
-- even when an exception is raised.
readSendFile buf siz send fid off0 len0 hook headers =
  UnliftIO.bracket setup teardown $ \fd -> do
    hn <- packHeader buf siz send hook headers 0
    let room = siz - hn
        buf' = buf `plusPtr` hn
    n <- positionRead fd buf' (mini room len0) off0
    bufferIO buf (hn + n) send
    hook
    let n' = fromIntegral n
    loop fd (len0 - n') (off0 + n')
  where
    path = fileIdPath fid
    setup = case fileIdFd fid of
       Just fd -> return fd
       Nothing -> openFile path
    -- Close only the descriptors that 'setup' opened itself.
    teardown fd = case fileIdFd fid of
       Just _  -> return ()
       Nothing -> closeFile fd
    loop fd len off
      | len <= 0  = return ()
      | otherwise = do
          n <- positionRead fd buf (mini siz len) off
          -- A zero-byte read means the file ended before the requested
          -- range was exhausted; stop instead of retrying the same offset
          -- forever (this matches the Windows implementation).
          when (n /= 0) $ do
              bufferIO buf n send
              let n' = fromIntegral n
              hook
              loop fd (len - n') (off + n')

-- | Read up to @siz@ bytes from the descriptor at file offset @off@ into
--   the buffer, using the C @pread@ binding.
--
--   Returns the number of bytes read.  A negative result from @pread@ is
--   turned into an exception via 'throwErrno'.
positionRead :: Fd -> Buffer -> BufSize -> Integer -> IO Int
positionRead fd buf siz off = do
    bytes <- fromIntegral <$> c_pread fd (castPtr buf) (fromIntegral siz) (fromIntegral off)
    when (bytes < 0) $ throwErrno "positionRead"
    return bytes

-- Raw binding to the C @pread@ function, imported as an @unsafe@ ccall.
-- Arguments, in order: descriptor, destination pointer, byte count, file
-- offset.  'positionRead' treats a negative result as an error.
-- NOTE(review): per POSIX a result of 0 presumably means end of file —
-- confirm callers handle it.
foreign import ccall unsafe "pread"
  c_pread :: Fd -> Ptr CChar -> ByteCount -> FileOffset -> IO CSsize
#endif