{-# LANGUAGE PatternGuards #-}

module Network.SCGI (
    runOnceSCGI
  , runSCGI
  , runSCGIConcurrent
  , runSCGIConcurrent'

  , module Network.CGI
  ) where

import Control.Exception.Extensible (SomeException, bracket, catch, finally)
import Control.Monad
import Control.Monad.Fix    (fix)
import Control.Concurrent
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as L
import Network
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
import Network.CGI
import Network.CGI.Monad    (runCGIT)
import Network.CGI.Protocol (runCGIEnvFPS)
import Prelude hiding (catch)
import System.IO            (hPutStrLn, stderr)


-- | Serve SCGI requests on the given port forever, one connection at a time.
--
-- An exception raised while handling a connection is written to 'stderr'
-- (prefixed with @scgi: @) and the server carries on with the next
-- connection.  An exception raised by 'NS.accept' itself is not caught.
runSCGI :: PortID -> CGI CGIResult -> IO ()
runSCGI port f = listen port acceptLoop
  where
    acceptLoop listener = forever $ do
        (conn, _) <- NS.accept listener
        doSCGI f conn `catch` report

    report :: SomeException -> IO ()
    report err = hPutStrLn stderr $ "scgi: " ++ show err

-- | Accept a single connection on the given port and handle it as one SCGI
-- request.
--
-- Unlike 'runSCGI', nothing is caught here: an exception raised while
-- handling the connection propagates to the caller.
runOnceSCGI :: PortID -> CGI CGIResult -> IO ()
runOnceSCGI port f = listen port $ \listener ->
    NS.accept listener >>= \(conn, _) -> doSCGI f conn

-- | Like 'runSCGIConcurrent'', using 'forkOS' as the fork function.
runSCGIConcurrent :: Int               -- ^ Maximum number of concurrent threads
                  -> PortID
                  -> CGI CGIResult
                  -> IO ()
runSCGIConcurrent maxThreads port f =
    runSCGIConcurrent' forkOS maxThreads port f

-- | Serve SCGI requests on the given port, handling every accepted
-- connection in a thread created by the supplied fork function.
--
-- A 'QSem' initialised to the thread limit is waited on before each
-- 'NS.accept' and signalled when a handler finishes, whether it returned or
-- threw.  Handler exceptions are written to 'stderr', prefixed with
-- @scgi: @.
runSCGIConcurrent' :: (IO () -> IO a)  -- ^ Fork function
                   -> Int              -- ^ Maximum number of concurrent threads
                   -> PortID
                   -> CGI CGIResult
                   -> IO ()
runSCGIConcurrent' fork maxThreads port f = do
    slots <- newQSem maxThreads
    listen port $ \listener -> forever $ do
        waitQSem slots
        (conn, _) <- NS.accept listener
        fork (serve slots conn) >> return ()
  where
    -- Release the slot first, then report whatever went wrong.
    serve slots conn =
        (doSCGI f conn `finally` signalQSem slots) `catch` report

    report :: SomeException -> IO ()
    report err = hPutStrLn stderr $ "scgi: " ++ show err

-- | Run an action on a socket and close the socket with 'sClose' afterwards,
-- whether the action returned or threw.
withSocket :: Socket -> (Socket -> IO ()) -> IO ()
withSocket sock doit = doit sock `finally` sClose sock

-- | Unwrap the leading run of 'Just' values, discarding the first 'Nothing'
-- and everything after it.
--
-- The result is produced lazily, so this also works on infinite lists and
-- never inspects anything past the first 'Nothing'.
stopAtNothing :: [Maybe a] -> [a]
stopAtNothing = foldr step []
  where
    step Nothing  _    = []
    step (Just a) rest = a : rest

-- | Read everything the peer sends on the socket as a lazy 'L.ByteString',
-- returned together with the 'ThreadId' of the background thread that
-- feeds it.
--
-- This function replaces Data.ByteString.hGetContents, because the latter
-- (on GHC 6.10.4 and GHC 6.12.2, with network-2.2.1.7) does not deliver a
-- block as soon as it is received over the TCP connection, which causes the
-- SCGI server to stall.  It is not known when that behaviour changed.
--
-- The reader thread receives blocks of up to 16384 bytes and passes them
-- through a 'Chan'.  It stops at the first empty block, which is how
-- 'NSB.recv' reports that the peer has closed its side of the connection;
-- looping on regardless would spin and queue empty chunks until the thread
-- was killed.  However the thread ends (end of stream, exception, or
-- 'killThread'), it writes a final 'Nothing' so the lazy stream terminates.
lazyContents :: Socket -> IO (ThreadId, L.ByteString)
lazyContents s = do
    ch <- newChan
    let pump = do
            blk <- NSB.recv s 16384
            -- An empty block means end of stream: stop instead of retrying.
            unless (B.null blk) $ do
                writeChan ch (Just blk)
                pump
    tid <- forkIO $ pump `finally` writeChan ch Nothing
    blks <- getChanContents ch
    return (tid, L.fromChunks (stopAtNothing blks))

-- | Handle one SCGI request on an accepted connection.
--
-- The request is read lazily from the socket, split into headers and body
-- with 'request', run through the CGI action, and the output is sent back
-- chunk by chunk.  The reader thread started by 'lazyContents' is killed
-- once the response has been sent or an exception has been raised, and
-- 'withSocket' closes the socket in either case.
doSCGI :: CGI CGIResult -> Socket -> IO ()
doSCGI f sock = withSocket sock $ \conn -> do
    (reader, input) <- lazyContents conn
    let (env, body) = request input
        respond = do
            output <- runCGIEnvFPS env body (runCGIT f)
            mapM_ (sendFully conn) (L.toChunks output)
    respond `finally` killThread reader

-- | Send the whole of a strict 'B.ByteString' on the socket.
--
-- A single 'NSB.send' may write only part of its argument, so this
-- delegates to 'NSB.sendAll', the library's own keep-sending-until-done
-- loop, instead of re-implementing it here.
sendFully :: Socket -> B.ByteString -> IO ()
sendFully s bs = NSB.sendAll s bs

-- | Open a listening socket on the given port with 'listenOn', run the
-- action on it, and close it with 'sClose' when the action returns or
-- throws.  Everything runs inside 'withSocketsDo'.
listen :: PortID -> (Socket -> IO ()) -> IO ()
listen port loop =
    withSocketsDo (bracket acquire release loop)
  where
    acquire = listenOn port
    release = sClose

-- | Split a raw SCGI request into its decoded header pairs and the
-- remaining body.
--
-- The header block is taken off the front with 'netstring' and decoded
-- with 'headers'; the body is whatever 'netstring' leaves over.  Nothing is
-- forced here, so input is only consumed as the results are demanded.
request :: L.ByteString -> ([(String, String)], L.ByteString)
request str = (headers rawHeaders, payload)
  where
    parsed     = netstring str
    rawHeaders = fst parsed
    payload    = snd parsed

-- | Split a netstring-style frame off the front of the input.
--
-- The input is expected to look like @\<length\>:\<payload\>\<sep\>\<rest\>@,
-- where @\<length\>@ is the decimal byte count of the payload and @\<sep\>@
-- is a single byte (a comma in a netstring).  The result is the payload as
-- a 'String' and everything after the separator byte.  The separator's
-- value is not checked.
--
-- The tuple is returned lazily; malformed input raises an 'error' only when
-- the affected component is demanded.  The error names this function and
-- says what was wrong: an unparsable or negative length, a missing @:@, or
-- input that ends before the separator byte.
netstring    :: L.ByteString -> (String, L.ByteString)
netstring cs = (L.unpack payload, body)
  where
    (lenField, afterLen) = L.span (/= ':') cs

    -- Same acceptance rule as 'read' (surrounding whitespace allowed), but
    -- total, and negative lengths are rejected rather than mis-split.
    size = case [n | (n, t) <- reads (L.unpack lenField), ("", "") <- lex t] of
        [n] | n >= 0 -> n
        _            -> malformed ("bad length "
                                   ++ show (take 32 (L.unpack lenField)))

    -- Drop the ':' and cut the payload off the front of what follows.
    (payload, afterPayload) = case L.uncons afterLen of
        Just (_, rest) -> L.splitAt size rest
        Nothing        -> malformed "missing ':' after length"

    -- Drop the single separator byte that ends the payload.
    body = case L.uncons afterPayload of
        Just (_, rest) -> rest
        Nothing        -> malformed "input ends before the payload terminator"

    malformed :: String -> a
    malformed why = error ("Network.SCGI.netstring: " ++ why)

-- | Decode a header block: the text is cut at every NUL character and the
-- resulting fields are grouped two by two into @(first, second)@ pairs.
headers :: String -> [(String, String)]
headers raw = pairs (split '\NUL' raw)

-- | Group a list into consecutive pairs.  A leftover element at the end of
-- an odd-length list is dropped.
pairs :: [a] -> [(a, a)]
pairs items = case items of
    first : second : rest -> (first, second) : pairs rest
    _                     -> []

-- | Cut a list at every occurrence of the delimiter, which is not kept.
--
-- The result always has at least one element: an empty input gives a single
-- empty token, and a trailing delimiter gives a trailing empty token.
split :: Eq a => a -> [a] -> [[a]]
split delim str = case span (/= delim) str of
    (token, [])       -> [token]
    (token, _ : more) -> token : split delim more