{-# LANGUAGE PatternGuards  #-}

module Network.SCGI (
    runOnceSCGI
  , runSCGI
  , runSCGIConcurrent
  , runSCGIConcurrent'

  , module Network.CGI
  ) where

import qualified Control.Exception as E (SomeException, bracket, catch, finally)
import Control.Monad.Fix    (fix)
import Control.Concurrent
import Data.ByteString.Lazy.Char8 (ByteString)
import Network
import Network.CGI
import Network.CGI.Monad    (runCGIT)
import Network.CGI.Protocol (runCGIEnvFPS)
import System.IO            (Handle, hPutStrLn, stderr, hClose)

import qualified Data.ByteString.Lazy.Char8 as B

-- | Listen on the given port and serve SCGI requests sequentially, one
-- connection at a time, running the given CGI action for each request.
--
-- Any exception raised while handling a single request is written to
-- 'stderr' (prefixed with @scgi: @) and the loop moves on to the next
-- connection. The loop only ends if 'accept' itself throws.
--
-- Note: catching 'E.SomeException' also intercepts asynchronous exceptions
-- (for example a user interrupt) that arrive while a request is being
-- handled; those are logged like any other failure.
runSCGI        :: PortID -> CGI CGIResult -> IO ()
runSCGI port f = listen port $ fix $ \loop socket -> do
    (handle, _, _) <- accept socket
    -- The annotation gives 'E.catch' a concrete exception type to catch;
    -- without it the handler's type variable is ambiguous (base >= 4).
    E.catch
        (doSCGI f handle)
        (\e -> hPutStrLn stderr $ "scgi: " ++ show (e :: E.SomeException))
    loop socket

-- | Listen on the given port, serve exactly one SCGI request with the given
-- CGI action, then close the listening socket and return.
--
-- Unlike 'runSCGI', an exception raised while handling the request is not
-- caught here; it propagates to the caller.
runOnceSCGI        :: PortID -> CGI CGIResult -> IO ()
runOnceSCGI port f = listen port serveOne
  where
    -- Accept a single connection and hand it straight to the request handler.
    serveOne sock = do
        (conn, _, _) <- accept sock
        doSCGI f conn

-- | Serve SCGI requests concurrently, one thread per connection, with at most
-- the given number of requests in flight at once.
--
-- This is 'runSCGIConcurrent'' specialised to 'forkOS'. According to
-- "Control.Concurrent", 'forkOS' needs the threaded runtime.
runSCGIConcurrent :: Int               -- ^ Maximum number of concurrent threads
                  -> PortID
                  -> CGI CGIResult
                  -> IO ()
runSCGIConcurrent maxThreads port f = runSCGIConcurrent' forkOS maxThreads port f

-- | Listen on the given port and serve SCGI requests concurrently. Each
-- accepted connection is handled in a thread started by the supplied fork
-- function.
--
-- A 'QSem' initialised to @maxThreads@ caps the number of in-flight
-- requests. A slot is taken before each 'accept' and released when that
-- request's handler finishes, whether normally or by exception. Exceptions
-- raised while handling a request are written to 'stderr' (prefixed with
-- @scgi: @) after the slot is released, and they do not stop the server.
runSCGIConcurrent' :: (IO () -> IO a)  -- ^ Fork function
                  -> Int               -- ^ Maximum number of concurrent threads
                  -> PortID
                  -> CGI CGIResult
                  -> IO ()
runSCGIConcurrent' fork maxThreads port f = do
    qsem <- newQSem maxThreads
    listen port $ fix $ \loop socket -> do
        waitQSem qsem
        (handle, _, _) <- accept socket
        -- The fork function's result (e.g. a ThreadId) is not needed;
        -- binding it to _ makes the discard explicit.
        _ <- fork $
            -- The annotation gives 'E.catch' a concrete exception type;
            -- without it the handler's type variable is ambiguous (base >= 4).
            E.catch
                (E.finally (doSCGI f handle) (signalQSem qsem))
                (\e -> hPutStrLn stderr $ "scgi: " ++ show (e :: E.SomeException))
        loop socket

-- | Run an action on a handle, then close the handle whether the action
-- completes normally or throws.
withHandle :: Handle -> (Handle -> IO ()) -> IO ()
withHandle h action = action h `E.finally` hClose h

-- | Serve one SCGI request on a connected handle. It reads the raw request
-- from the handle, splits it into headers and body, runs the CGI action
-- against them, and writes the resulting output back to the same handle.
-- The handle is closed afterwards, even if an exception is raised.
doSCGI          :: CGI CGIResult -> Handle -> IO ()
doSCGI f handle = withHandle handle serve
  where
    serve conn = do
        raw <- B.hGetContents conn
        let (hdrs, body) = request raw
        response <- runCGIEnvFPS hdrs body (runCGIT f)
        B.hPut conn response

-- | Open a listening socket on the given port and pass it to the given
-- action. The socket is closed with 'sClose' when the action returns or
-- throws.
--
-- Everything runs inside 'withSocketsDo', which the network library requires
-- for socket initialisation on some platforms (Windows).
listen           :: PortID -> (Socket -> IO ()) -> IO ()
listen port body = withSocketsDo (E.bracket (listenOn port) sClose body)

-- | Split a raw SCGI request into its header list and the remaining body.
--
-- The request begins with a netstring whose payload holds NUL-separated,
-- alternating header names and values. Everything after that netstring is
-- returned untouched as the body.
request     :: ByteString -> ([(String, String)], ByteString)
request str =
    let (rawHeaders, rest) = netstring str
    in  (headers rawHeaders, rest)

-- | Parse one netstring (@<length>:<payload>,@) from the front of the input.
-- Returns the payload as a 'String' together with everything after the
-- byte that terminates the payload.
--
-- Malformed input is rejected with 'error', and each message names the
-- specific problem. The byte after the payload is dropped without checking
-- that it is a comma. Each check only fires when the part of the result
-- that depends on it is demanded, so the result stays lazy.
netstring    :: ByteString -> (String, ByteString)
netstring cs = (B.unpack payload, dropTerminator trailer)
  where
    (lenText, afterLen) = B.span (/= ':') cs
    -- Same acceptance rules as 'read' (surrounding whitespace allowed), but
    -- with a descriptive failure instead of "Prelude.read: no parse".
    len = case [n | (n, t) <- reads (B.unpack lenText), ("", "") <- lex t] of
        [n] -> n
        _   -> error "Network.SCGI.netstring: malformed length prefix"
    (payload, trailer)
        | B.null afterLen = error "Network.SCGI.netstring: missing ':' after length"
        | otherwise       = B.splitAt len (B.tail afterLen)
    dropTerminator rest
        | B.null rest = error "Network.SCGI.netstring: missing ',' terminator"
        | otherwise   = B.tail rest

-- | Decode a header block of alternating names and values separated by NUL
-- characters into name/value pairs. A trailing unpaired field, such as the
-- empty string produced by a final NUL, is dropped.
headers :: String -> [(String, String)]
headers block = pairs (split '\NUL' block)

-- | Group a list into consecutive pairs: @[a, b, c, d] -> [(a, b), (c, d)]@.
-- A leftover final element with no partner is discarded.
pairs           :: [a] -> [(a, a)]
pairs xs = case xs of
    (k : v : rest) -> (k, v) : pairs rest
    _              -> []

-- | Split a list on every occurrence of a delimiter, keeping empty fields.
-- For example @split ',' "a,,b," == ["a", "", "b", ""]@, and the empty list
-- yields a single empty field.
split            :: Eq a => a -> [a] -> [[a]]
split delim str = case break (== delim) str of
    (token, [])       -> [token]
    (token, _ : more) -> token : split delim more