module Network.SCGI (SCGIT, SCGI, runRequest, header, allHeaders, method, path, setHeader, Headers, Body, Status, Response(..)) where
import Control.Applicative (Applicative, (<$>), (<*>), (<*))
import Control.Arrow (first)
import Control.Monad (liftM, liftM2)
import Data.Char (toUpper)
import System.IO (Handle)
import Prelude hiding (take)

import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT, runReaderT, MonadReader, asks)
import Control.Monad.State (StateT, runStateT, MonadState, modify)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Data.Attoparsec.ByteString.Char8 (Parser, IResult(..), parseOnly, parseWith, char, decimal, take, takeTill)
import Data.Attoparsec.Combinator (many1)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 ()
import Data.CaseInsensitive (CI)
import qualified Data.CaseInsensitive as CI
import Data.Map (Map)
import qualified Data.Map as M
-- | Header map keyed by case-insensitive name.  Used both for the headers
-- of the incoming request and for the headers of the outgoing response.
type Headers = Map (CI B.ByteString) B.ByteString

-- | A message body (lazy bytes).
type Body = BL.ByteString

-- | The text written after @Status: @ at the start of the response.
type Status = BL.ByteString

-- | The result of a handler: a status and a body.  Response headers are
-- accumulated separately, through 'setHeader'.
data Response
  = Response Status Body
-- | The handler monad transformer.  It gives read-only access to the request
-- headers (via the 'ReaderT' layer) and carries a mutable map of response
-- headers (via the 'StateT' layer), on top of an inner monad @m@.
--
-- 'Functor' and 'Applicative' are derived alongside 'Monad'.  Since GHC 7.10
-- 'Applicative' is a superclass of 'Monad', so the 'Monad' instance cannot be
-- derived without them.
newtype SCGIT m a = SCGIT (ReaderT Headers (StateT Headers m) a)
  deriving (Functor, Applicative, Monad, MonadState Headers, MonadReader Headers, MonadIO)

-- | 'SCGIT' specialised to 'IO'.
type SCGI = SCGIT IO
-- | Lifting goes through both transformer layers: first into the 'StateT',
-- then into the 'ReaderT', and finally into the 'SCGIT' wrapper.
instance MonadTrans SCGIT where
  lift action = SCGIT (lift (lift action))
-- | Run a handler with the given request headers.  The response-header state
-- starts out empty.  Returns the handler's result together with the response
-- headers it set.
runSCGIT :: Monad m => Headers -> SCGIT m Response -> m (Response, Headers)
runSCGIT requestHeaders (SCGIT action) =
  let stateful = runReaderT action requestHeaders
  in runStateT stateful M.empty
-- | Look up a request header.  The name is matched case-insensitively.
header :: Monad m
       => B.ByteString                  -- ^ header name
       -> SCGIT m (Maybe B.ByteString)  -- ^ its value, if the request has it
header = asks . M.lookup . CI.mk
-- | Every request header as a @(name, value)@ pair.  Names are returned in
-- their original case, and pairs come in the map's ascending key order.
allHeaders :: Monad m => SCGIT m [(B.ByteString, B.ByteString)]
allHeaders = asks $ \hs -> [ (CI.original k, v) | (k, v) <- M.toList hs ]
-- | The @REQUEST_METHOD@ header converted to upper case, if present.
method :: Monad m => SCGIT m (Maybe B.ByteString)
method = do
  raw <- header "REQUEST_METHOD"
  return (fmap (B8.map toUpper) raw)
-- | The request path: @SCRIPT_NAME@ followed by @PATH_INFO@.
--
-- Servers do not always send both variables.  For example, @PATH_INFO@ may be
-- absent when the URI has no extra path after the script.  A missing
-- component is therefore treated as empty, and 'Nothing' is returned only
-- when neither variable is present.  Previously the whole path was lost
-- whenever either component was missing.
path :: Monad m => SCGIT m (Maybe B.ByteString)
path = do
  scriptName <- header "SCRIPT_NAME"
  pathInfo <- header "PATH_INFO"
  return $ case (scriptName, pathInfo) of
    (Nothing, Nothing) -> Nothing
    _ -> Just (orEmpty scriptName `B.append` orEmpty pathInfo)
  where
    orEmpty = maybe B.empty id
-- | Set a response header.  Any value already stored under the same
-- (case-insensitive) name is replaced.
setHeader :: Monad m
          => B.ByteString  -- ^ header name
          -> B.ByteString  -- ^ header value
          -> SCGIT m ()
setHeader name = modify . M.insert (CI.mk name)
-- | Serve one SCGI request on the handle.
--
-- Steps:
--
-- 1. Read the netstring-framed header block and the request body.
-- 2. Run the handler on the body.  Request headers are available through
--    'header', 'method', 'path' and 'allHeaders'.
-- 3. Write a @Status:@ line, the response headers set with 'setHeader', a
--    blank line, and the response body.
--
-- Malformed input is reported with 'error' in these cases: an unparsable
-- netstring or header block, a missing @CONTENT_LENGTH@, or a
-- @CONTENT_LENGTH@ that does not start with an integer.
runRequest :: MonadIO m
           => Handle
           -> (Body -> SCGIT m Response)
           -> m ()
runRequest h f = do
  (headerMap, body) <- liftIO $ readRequest h
  (Response status body', headers') <- runSCGIT headerMap (f body)
  liftIO $ writeResponse h status headers' body'

-- | Read and decode the header netstring from the handle, then read the body.
readRequest :: Handle -> IO (Headers, Body)
readRequest h = do
  result <- parseWith (B.hGetSome h 4096) netstringParser ""
  case result of
    Done rest headerString ->
      case parseOnly (many1 headerParser) headerString of
        Left e -> error e
        Right headers -> do
          let headerMap = M.fromList $ map (first CI.mk) headers
          -- 'rest' holds bytes already read past the netstring; they are
          -- the start of the body.
          body <- readBody h (contentLength headerMap) rest
          return (headerMap, body)
    _ -> error "Failed to parse SCGI request."

-- | The request's @CONTENT_LENGTH@.  Fails if the header is absent or
-- malformed.
contentLength :: Headers -> Int
contentLength headerMap =
  case B8.readInt (M.findWithDefault (error "CONTENT_LENGTH missing from request") "CONTENT_LENGTH" headerMap) of
    Just (len, _) -> len
    Nothing -> error "Failed to parse CONTENT_LENGTH."

-- | Assemble a body of @len@ bytes.  Bytes already buffered are used first,
-- and the remainder is read from the handle.
readBody :: Handle -> Int -> B.ByteString -> IO Body
readBody h len buffered = do
  -- Previously written as @len B.length rest@, which applied 'len' as a
  -- function; the intent is the number of bytes still to read.
  let missing = len - B.length buffered
  more <- if missing > 0 then BL.hGet h missing else return BL.empty
  -- Cap at CONTENT_LENGTH so bytes buffered beyond the body never leak in.
  return $ BL.take (fromIntegral len) (BL.fromChunks [buffered] `BL.append` more)

-- | Write the status line, the response headers, a blank separator line and
-- the body.
writeResponse :: Handle -> Status -> Headers -> Body -> IO ()
writeResponse h status headers body = do
  BL.hPutStr h $ BL.concat ["Status: ", status, "\r\n"]
  mapM_ (\(k, v) -> B.hPutStr h $ B.concat [CI.original k, ": ", v, "\r\n"]) (M.toList headers)
  BL.hPutStr h "\r\n"
  BL.hPutStr h body
-- | Parse a netstring of the form @<length>:<bytes>,@ and return the
-- payload bytes.
netstringParser :: Parser B.ByteString
netstringParser =
  (decimal <* char ':') >>= \len -> take len <* char ','
-- | Parse one header as a pair: a NUL-terminated name followed by a
-- NUL-terminated value.
headerParser :: Parser (B.ByteString, B.ByteString)
headerParser = do
  name <- cStringParser
  value <- cStringParser
  return (name, value)
-- | Consume bytes up to and including the next NUL.  Returns the bytes
-- before the NUL.
cStringParser :: Parser B.ByteString
cStringParser = do
  bytes <- takeTill (== '\NUL')
  _ <- char '\NUL'
  return bytes