module Network.SCGI (SCGIT, SCGI, runRequest, header, allHeaders, body, method, path, setHeader, responseHeader, Headers, Body, Status, Response(..), negotiate) where
import Control.Applicative (Applicative, (<$>), (<*>), (<*), (*>))
import Control.Arrow (first)
import Control.Exception (SomeException)
import Control.Monad (liftM, liftM2)
import Control.Monad.CatchIO (MonadCatchIO(..))
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT, runReaderT, MonadReader, asks)
import Control.Monad.State (StateT, runStateT, MonadState, modify, gets)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Data.Attoparsec.ByteString.Char8 (Parser, IResult(..), parseOnly, parseWith, char, string, skipSpace, decimal, take, takeTill, inClass, rational)
import Data.Attoparsec.Combinator (many1, sepBy, option)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.UTF8 as BLU
import qualified Data.ByteString.Lazy.Char8 ()
import Data.Char (toUpper)
import Data.CaseInsensitive (CI)
import qualified Data.CaseInsensitive as CI
import Data.Function (on)
import Data.List (sortBy, find, maximumBy)
import Data.Maybe (fromMaybe, mapMaybe)
import qualified Data.Map.Lazy as M
import qualified System.FilePath.Glob as G
import System.IO (Handle)
import Prelude hiding (take)
-- | Header map used both for the incoming request (the parsed SCGI header
-- block) and for the response headers a handler sets. Keys are compared
-- case-insensitively.
type Headers = M.Map (CI B.ByteString) B.ByteString
-- | A request or response body.
type Body = BL.ByteString
-- | Text written after @Status: @ in the response, e.g.
-- @"500 Internal Server Error"@.
type Status = BL.ByteString
-- | Result of a handler: the response status and body. Response headers are
-- not part of this value; they are accumulated separately in the 'SCGIT'
-- state.
data Response = Response Status Body
-- | The handler monad transformer.
--
-- The reader environment holds the parsed request (headers and body); the
-- state holds the response headers set so far.
--
-- 'Functor' and 'Applicative' are derived alongside 'Monad' because GHC
-- (since 7.10) requires them as superclasses of 'Monad'.
newtype SCGIT m a = SCGIT (ReaderT (Headers, Body) (StateT Headers m) a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadState Headers
           , MonadReader (Headers, Body)
           , MonadIO
           , MonadCatchIO
           )

-- | 'SCGIT' over 'IO'.
type SCGI = SCGIT IO

instance MonadTrans SCGIT where
  -- Lift through both the StateT and the ReaderT layers.
  lift = SCGIT . lift . lift
-- | Run a handler against a parsed request. Returns the handler's 'Response'
-- together with the response headers it set; the header state starts out
-- empty.
runSCGIT :: MonadIO m => Headers -> Body -> SCGIT m Response -> m (Response, Headers)
runSCGIT reqHeaders reqBody (SCGIT action) =
  flip runStateT M.empty $ runReaderT action (reqHeaders, reqBody)
-- | Look up a request header by name. The lookup is case-insensitive.
header :: Monad m
       => B.ByteString
       -> SCGIT m (Maybe B.ByteString)
header name = liftM (M.lookup (CI.mk name)) (asks fst)
-- | All request headers as @(name, value)@ pairs. Names keep their original
-- spelling and come out in the header map's key order.
allHeaders :: Monad m => SCGIT m [(B.ByteString, B.ByteString)]
allHeaders = do
  reqHeaders <- asks fst
  return [ (CI.original name, value) | (name, value) <- M.toList reqHeaders ]
-- | The request body, including any bytes that arrived with the header
-- netstring.
body :: Monad m
     => SCGIT m (BL.ByteString)
body = do
  (_, reqBody) <- asks id
  return reqBody
-- | The request method (the @REQUEST_METHOD@ header), upper-cased, or
-- 'Nothing' if the header is absent.
method :: Monad m => SCGIT m (Maybe B.ByteString)
method = do
  rawMethod <- header "REQUEST_METHOD"
  return (B8.map toUpper <$> rawMethod)
-- | The request path: @SCRIPT_NAME@ followed by @PATH_INFO@.
--
-- Servers do not always send both variables. For example, a request with no
-- extra path information may arrive without @PATH_INFO@. A missing part is
-- therefore treated as empty, and 'Nothing' is returned only when both are
-- absent.
path :: Monad m => SCGIT m (Maybe B.ByteString)
path = do
  scriptName <- header "SCRIPT_NAME"
  pathInfo <- header "PATH_INFO"
  return $ case (scriptName, pathInfo) of
    (Nothing, Nothing) -> Nothing
    _ -> Just $ fromMaybe B.empty scriptName `B.append` fromMaybe B.empty pathInfo
-- | Set a response header, replacing any earlier value stored under the same
-- (case-insensitive) name.
setHeader :: Monad m
          => B.ByteString
          -> B.ByteString
          -> SCGIT m ()
setHeader name value = modify insertHeader
  where insertHeader = M.insert (CI.mk name) value
-- | Read back a response header set earlier with 'setHeader'. The lookup is
-- case-insensitive.
responseHeader :: Monad m
               => B.ByteString
               -> SCGIT m (Maybe B.ByteString)
responseHeader name = gets lookupHeader
  where lookupHeader = M.lookup (CI.mk name)
-- | Serve one SCGI request on a handle.
--
-- The steps are:
--
-- 1. Read the header netstring, parse its NUL-separated name/value pairs and
--    read a body of @CONTENT_LENGTH@ bytes.
-- 2. Run the handler.
-- 3. Write a @Status:@ line, the response headers, a blank line and the body
--    back to the handle.
--
-- An exception thrown while the handler action runs is caught and turned into
-- a @500 Internal Server Error@ response whose body is the exception's 'show'.
-- NOTE(review): exceptions hidden in lazily evaluated parts of the returned
-- 'Response' surface only while the response is written, outside the 'catch'.
--
-- A malformed request is reported with 'error'. This covers an unparsable
-- netstring or header block, and a missing or non-numeric @CONTENT_LENGTH@.
runRequest :: MonadCatchIO m
           => Handle
           -> SCGIT m Response
           -> m ()
runRequest h f = do
  result <- liftIO $ parseWith (B.hGetSome h 4096) netstringParser ""
  case result of
    Done rest headerString ->
      case parseOnly (many1 headerParser) headerString of
        Left e -> error e
        Right headers -> do
          let headerMap = M.fromList $ map (first CI.mk) headers
              len' = B8.readInt $ M.findWithDefault (error "CONTENT_LENGTH missing from request") "CONTENT_LENGTH" headerMap
          case len' of
            Just (len, _) -> do
              body' <- liftIO $ readBody len rest
              (Response status body'', headers') <- catch (runSCGIT headerMap body' f) handleException
              liftIO $ writeResponse status headers' body''
            _ -> error "Failed to parse CONTENT_LENGTH."
    _ -> error "Failed to parse SCGI request."
  where
    -- Build a body of (at most) @len@ bytes. The netstring parser may already
    -- have buffered part of it in @rest@: keep at most @len@ bytes of that,
    -- then read only what is still outstanding from the handle.
    readBody :: Int -> B.ByteString -> IO BL.ByteString
    readBody len rest = do
      let remaining = len - B.length rest
      more <- if remaining > 0 then BL.hGet h remaining else return BL.empty
      return $ BL.fromChunks [B.take len rest] `BL.append` more

    -- Emit a CGI-style response: status line, headers, blank line, body.
    writeResponse :: Status -> Headers -> Body -> IO ()
    writeResponse status headers' body'' = do
      BL.hPutStr h $ BL.concat ["Status: ", status, "\r\n"]
      mapM_ (\(k, v) -> B.hPutStr h $ B.concat [CI.original k, ": ", v, "\r\n"]) $ M.toList headers'
      BL.hPutStr h "\r\n"
      BL.hPutStr h body''

    handleException :: MonadIO m => SomeException -> m (Response, Headers)
    handleException e = return ( Response "500 Internal Server Error" (BLU.fromString $ show e)
                               , M.fromList [("Content-Type", "text/plain; charset=utf-8")] )
-- | Parse a netstring of the form @<length>:<payload>,@ and return the
-- payload bytes.
netstringParser :: Parser B.ByteString
netstringParser = do
  len <- decimal
  _ <- char ':'
  payload <- take len
  _ <- char ','
  return payload
-- | Parse one header entry: a NUL-terminated name followed by a
-- NUL-terminated value.
headerParser :: Parser (B.ByteString, B.ByteString)
headerParser = do
  name <- cStringParser
  value <- cStringParser
  return (name, value)
-- | Parse the bytes up to the next NUL, consume the NUL, and return the bytes
-- that preceded it.
cStringParser :: Parser B.ByteString
cStringParser = do
  str <- takeTill (== '\NUL')
  _ <- char '\NUL'
  return str
-- | Pick representations (media-type strings) using the request's
-- @HTTP_ACCEPT@ header.
--
-- Without that header every representation is returned unchanged. Otherwise
-- the result is the representations sharing the highest matching quality,
-- which may be none.
negotiate :: Monad m => [B.ByteString] -> SCGIT m [B.ByteString]
negotiate representations =
  liftM (maybe representations (best . matches representations)) (header "HTTP_ACCEPT")
-- | Quality value (the @q@ weight) attached to an entry of an Accept header.
type Quality = Double
-- | Parse an HTTP @Accept@ header into @(media range, quality)@ pairs.
--
-- Each media range may carry any number of @;name[=value]@ parameters. The
-- @q@ parameter (name matched case-insensitively) gives the quality. It
-- defaults to 1.0 when absent or not a number.
--
-- Other parameters, such as @level=1@, are skipped. Previously they stopped
-- the list parse, which silently dropped every later entry and misread a
-- following @q@.
acceptParser :: Parser [(B.ByteString, Quality)]
acceptParser = mediaRange `sepBy` (skipSpace *> char ',' <* skipSpace)
  where
    mediaRange = do
      range <- skipSpace *> takeTill (inClass ";, ")
      params <- option [] (many1 (skipSpace *> char ';' *> skipSpace *> param))
      return (range, quality params)
    -- A single parameter: "name" or "name=value".
    param = (,) <$> takeTill (inClass "=;, ")
                <*> option "" (char '=' *> takeTill (inClass ";, "))
    -- The first q parameter that parses as a number wins; otherwise 1.0.
    quality params =
      case [ v | (k, v) <- params, CI.mk k == "q" ] of
        (v : _) | Right q <- parseOnly rational v -> q
        _ -> 1.0
-- | Pair each available representation with the quality the @Accept@ header
-- assigns to it. The result keeps the order of @available@.
--
-- When several media ranges match a representation, the most specific one
-- (the fewest @*@ wildcards) decides its quality, as RFC 7231 §5.3.2
-- requires. Among equally specific ranges, the higher quality wins.
--
-- Representations whose quality works out to 0 are "not acceptable" (RFC 7231
-- §5.3.1) and are dropped. An unparsable header yields no matches.
matches :: [B.ByteString]
        -> B.ByteString
        -> [(B.ByteString, Quality)]
matches available accept =
  case parseOnly acceptParser accept of
    Left _ -> []
    Right acceptable ->
      filter ((> 0) . snd) $ mapMaybe (`match` bySpecificity acceptable) available
  where
    bySpecificity = sortBy (compare `on` specificity)
    -- Fewer wildcards sorts first; negating q puts higher qualities first
    -- among ranges with the same number of wildcards.
    specificity (range, q) = (B8.count '*' range, negate q)
-- | Find the first entry in @reps@ whose media range, used as a glob pattern,
-- matches @rep@, and pair @rep@ with that entry's quality.
match :: B.ByteString -> [(B.ByteString, Quality)] -> Maybe (B.ByteString, Quality)
match rep reps = fmap withQuality (find covers reps)
  where
    covers (range, _) = G.match (G.compile (B8.unpack range)) (B8.unpack rep)
    withQuality (_, q) = (rep, q)
-- | Keep the representations that share the highest quality, in their
-- original order. An empty input gives an empty result.
best :: [(B.ByteString, Quality)]
     -> [B.ByteString]
best [] = []
best ms = [ rep | (rep, q) <- ms, q == topQuality ]
  where topQuality = snd (maximumBy (compare `on` snd) ms)