{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Web.Spock.Internal.Wire where

import Control.Applicative
import Control.Arrow ((***))
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception
import Control.Monad.RWS.Strict
-- mtl >= 2.2 ships ExceptT in Control.Monad.Except; for older releases the
-- module falls back to Control.Monad.Error.
#if MIN_VERSION_mtl(2,2,0)
import Control.Monad.Except
#else
import Control.Monad.Error
#endif
import Control.Monad.Reader.Class ()
import Control.Monad.Trans.Resource
import Data.Hashable
import Data.IORef
import Data.Maybe
import Data.Typeable
import Data.Word
import GHC.Generics
import Network.HTTP.Types.Header (ResponseHeaders)
import Network.HTTP.Types.Method
import Network.HTTP.Types.Status
-- For base < 4.6 the Prelude export of catch is hidden
-- (presumably so the one from Control.Exception is used unambiguously).
#if MIN_VERSION_base(4,6,0)
import Prelude
#else
import Prelude hiding (catch)
#endif
import System.Directory
import Web.Routing.Router
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Lazy.Char8 as BSLC
import qualified Data.ByteString.SuperBuffer as SB
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Vault.Lazy as V
import qualified Network.Wai as Wai
import qualified Network.Wai.Parse as P

-- | Wrapper around 'StdMethod' that carries a 'Hashable' instance.
newtype HttpMethod
    = HttpMethod { unHttpMethod :: StdMethod }
    deriving (Show, Eq, Generic)

-- | Hashes the wrapped 'StdMethod' through its 'Enum' index.
instance Hashable HttpMethod where
    hashWithSalt = hashUsing (fromEnum . unHttpMethod)

-- | The 'SpockMethod' allows safe use of http verbs via the 'MethodStandard' constructor and 'StdMethod',
-- and custom verbs via the 'MethodCustom' constructor.
data SpockMethod
    -- | Standard HTTP Verbs from 'StdMethod'
    = MethodStandard !HttpMethod
    -- | Custom HTTP Verbs using 'T.Text'
    | MethodCustom !T.Text
    deriving (Eq, Generic)

instance Hashable SpockMethod

-- | Description of one uploaded file: its name, its content type and the
-- path of the temporary file holding its content.
data UploadedFile
    = UploadedFile
    { uf_name :: !T.Text
    , uf_contentType :: !T.Text
    , uf_tempLocation :: !FilePath
    } deriving Show

-- | Interface for modifying the vault and looking up keys in it.
data VaultIf
    = VaultIf
    { vi_modifyVault :: (V.Vault -> V.Vault) -> IO ()
    , vi_lookupKey :: forall a. V.Key a -> IO (Maybe a)
    }

-- | A value that is computed on first demand by 'cv_makeVal' and then kept
-- in 'cv_value'. The raw result type @r@ is hidden; 'cv_read' projects it to
-- the visible type @v@, which is what makes the 'Functor' instance possible
-- without recomputing or duplicating the cached value.
data CacheVar v
    = forall r. CacheVar
    { cv_lock :: !(MVar ())
      -- ^ guards 'cv_value'; empty means free, full means held
    , cv_makeVal :: !(IO r)
      -- ^ action producing the raw value
    , cv_value :: !(IORef (Maybe r))
      -- ^ 'Nothing' until the value has been produced
    , cv_read :: r -> v
      -- ^ projection applied on every read
    }

-- | Post-composes the projection; lock, producer and cache cell are shared
-- with the original 'CacheVar'.
instance Functor CacheVar where
    fmap f (CacheVar lock makeVal valRef readV) =
        CacheVar
        { cv_lock = lock
        , cv_makeVal = makeVal
        , cv_value = valRef
        , cv_read = f . readV
        }

-- | Create an unpopulated 'CacheVar' around the given producer. The producer
-- is not run here.
newCacheVar :: IO v -> IO (CacheVar v)
newCacheVar makeVal =
    do lock <- newEmptyMVar
       valueR <- newIORef Nothing
       return (CacheVar lock makeVal valueR id)

-- | Read the cached value if there is one, without running the producer.
loadCacheVarOpt :: CacheVar v -> IO (Maybe v)
loadCacheVarOpt (CacheVar lock _ valRef readV) =
    -- putMVar acquires the (initially empty) lock, takeMVar releases it
    bracket_ (putMVar lock ()) (takeMVar lock) $
    fmap readV <$> readIORef valRef

-- | Read the cached value, running the producer and storing its result if
-- nothing has been cached yet. The lock is held for the whole operation.
loadCacheVar :: CacheVar v -> IO v
loadCacheVar (CacheVar lock makeVal valRef readV) =
    bracket_ (putMVar lock ()) (takeMVar lock) $
    do val <- readIORef valRef
       case val of
         Just v -> return (readV v)
         Nothing ->
             do v <- makeVal
                writeIORef valRef (Just v)
                return (readV v)

-- | On-demand views of the request body: the raw bytes, the decoded form
-- parameters and the uploaded files.
data RequestBody
    = RequestBody
    { rb_value :: CacheVar BS.ByteString
    , rb_postParams :: CacheVar [(T.Text, T.Text)]
    , rb_files :: CacheVar (HM.HashMap T.Text UploadedFile)
    }

-- | Everything an action can read about the current request.
data RequestInfo ctx
    = RequestInfo
    { ri_method :: !SpockMethod
    , ri_request :: !Wai.Request
    , ri_getParams :: ![(T.Text, T.Text)]
    , ri_reqBody :: !RequestBody
    , ri_vaultIf :: !VaultIf
    , ri_context :: !ctx
    }

-- | A response body, represented as a function from the final status and
-- headers to the 'Wai.Response'.
newtype ResponseBody = ResponseBody (Status -> ResponseHeaders -> Wai.Response)

-- | Headers stored with a list of values (see 'rs_multiResponseHeaders');
-- 'respStateToResponse' emits one header line per value.
data MultiHeader
    = MultiHeaderCacheControl
    | MultiHeaderConnection
    | MultiHeaderContentEncoding
    | MultiHeaderContentLanguage
    | MultiHeaderPragma
    | MultiHeaderProxyAuthenticate
    | MultiHeaderTrailer
    | MultiHeaderTransferEncoding
    | MultiHeaderUpgrade
    | MultiHeaderVia
    | MultiHeaderWarning
    | MultiHeaderWWWAuth
    | MultiHeaderSetCookie
    deriving (Show, Eq, Enum, Bounded, Generic)

instance Hashable MultiHeader

-- | Wire name of a 'MultiHeader'.
multiHeaderCI :: MultiHeader -> CI.CI BS.ByteString
multiHeaderCI mh =
    case mh of
      MultiHeaderCacheControl -> "Cache-Control"
      MultiHeaderConnection -> "Connection"
      MultiHeaderContentEncoding -> "Content-Encoding"
      MultiHeaderContentLanguage -> "Content-Language"
      MultiHeaderPragma -> "Pragma"
      MultiHeaderProxyAuthenticate -> "Proxy-Authenticate"
      MultiHeaderTrailer -> "Trailer"
      MultiHeaderTransferEncoding -> "Transfer-Encoding"
      MultiHeaderUpgrade -> "Upgrade"
      MultiHeaderVia -> "Via"
      MultiHeaderWarning -> "Warning"
      MultiHeaderWWWAuth -> "WWW-Authenticate"
      MultiHeaderSetCookie -> "Set-Cookie"

-- | Reverse lookup table from wire name to 'MultiHeader'.
multiHeaderMap :: HM.HashMap (CI.CI BS.ByteString) MultiHeader
multiHeaderMap =
    HM.fromList $ flip map allHeaders $ \mh ->
    (multiHeaderCI mh, mh)
    where
      -- this is a nasty hack until we know more about the origin of
      -- uncaught exception: ErrorCall (toEnum{MultiHeader}: tag (-12565) is outside of enumeration's range (0,12))
      -- see: https://ghc.haskell.org/trac/ghc/ticket/10792 and https://github.com/agrafix/Spock/issues/44
      allHeaders =
          [ MultiHeaderCacheControl
          , MultiHeaderConnection
          , MultiHeaderContentEncoding
          , MultiHeaderContentLanguage
          , MultiHeaderPragma
          , MultiHeaderProxyAuthenticate
          , MultiHeaderTrailer
          , MultiHeaderTransferEncoding
          , MultiHeaderUpgrade
          , MultiHeaderVia
          , MultiHeaderWarning
          , MultiHeaderWWWAuth
          , MultiHeaderSetCookie
          ]

-- | Outcome of request handling: either a fully described response state or
-- an action yielding a WAI application that produces the response itself.
data ResponseVal
    = ResponseValState !ResponseState
    | ResponseHandler !(IO Wai.Application)

-- | Mutable response state threaded through an action.
data ResponseState
    = ResponseState
    { rs_responseHeaders :: !(HM.HashMap (CI.CI BS.ByteString) BS.ByteString)
    , rs_multiResponseHeaders :: !(HM.HashMap MultiHeader [BS.ByteString])
    , rs_status :: !Status
    , rs_responseBody :: !ResponseBody
    }

-- | Reasons an action stops early; interpreted by 'applyAction'.
data ActionInterupt
    = ActionRedirect !T.Text
    | ActionTryNext
    | ActionError String
    | ActionDone
    | ActionMiddlewarePass
    | ActionMiddleware !(IO Wai.Middleware)
    | ActionApplication !(IO Wai.Application)
    deriving Typeable

-- Since base 4.11 'Semigroup' is a superclass of 'Monoid', so the 'Monoid'
-- instance below is rejected without this one. Same right-biased behaviour.
#if MIN_VERSION_base(4,11,0)
instance Semigroup ActionInterupt where
    _ <> a = a
#endif

-- | Right-biased: combining two interrupts keeps the second one.
instance Monoid ActionInterupt where
    mempty = ActionDone
    mappend _ a = a

#if MIN_VERSION_mtl(2,2,0)
type ErrorT = ExceptT

runErrorT :: ExceptT e m a -> m (Either e a)
runErrorT = runExceptT
#else
instance Error ActionInterupt where
    noMsg = ActionError "Unkown Internal Action Error"
    strMsg = ActionError
#endif

type ActionT = ActionCtxT ()

-- | The action monad: reads a 'RequestInfo', carries a 'ResponseState' and
-- can abort with an 'ActionInterupt'.
newtype ActionCtxT ctx m a
    = ActionCtxT
    { runActionCtxT :: ErrorT ActionInterupt (RWST (RequestInfo ctx) () ResponseState m) a
    } deriving ( Monad, Functor, Applicative, Alternative, MonadIO
               , MonadReader (RequestInfo ctx), MonadState ResponseState
               , MonadError ActionInterupt
               )

instance MonadTrans (ActionCtxT ctx) where
    lift = ActionCtxT . lift . lift

-- | Internal configuration: an optional request size limit and the
-- application used to render error statuses.
data SpockConfigInternal
    = SpockConfigInternal
    { sci_maxRequestSize :: Maybe Word64
    , sci_errorHandler :: Status -> IO Wai.Application
    }

-- | No size limit; the error handler always answers with status 500 and a
-- plain message naming the status code it was asked to render.
defaultSpockConfigInternal :: SpockConfigInternal
defaultSpockConfigInternal =
    SpockConfigInternal Nothing defaultErrorHandler
    where
      defaultErrorHandler status =
          return $ \_ respond ->
          do let errorMessage =
                     "Error handler failed with status code " ++ (show $ statusCode status)
             respond $ Wai.responseLBS status500 [] $ BSLC.pack errorMessage

-- | Render a 'ResponseValState' as a 'Wai.Response'. Single-valued headers
-- come first, followed by one entry per value of every multi header.
--
-- Partial: calls 'error' when given a 'ResponseHandler'.
respStateToResponse :: ResponseVal -> Wai.Response
respStateToResponse (ResponseValState (ResponseState headers multiHeaders status (ResponseBody body))) =
    body status outHeaders
    where
      mkMultiHeader (k, vals) =
          let kCi = multiHeaderCI k
          in map (\v -> (kCi, v)) vals
      outHeaders =
          HM.toList headers ++ (concatMap mkMultiHeader $ HM.toList multiHeaders)
respStateToResponse _ = error "ResponseState expected"

-- | Build an error response with the given status, a @text/html@ content
-- type and a body containing the message twice.
errorResponse :: Status -> BSL.ByteString -> ResponseVal
errorResponse s e =
    ResponseValState $
    ResponseState
    { rs_responseHeaders = HM.singleton "Content-Type" "text/html"
    , rs_multiResponseHeaders = HM.empty
    , rs_status = s
    , rs_responseBody =
          ResponseBody $ \status headers ->
          Wai.responseLBS status headers $
          -- NOTE(review): the literals around the message were split across
          -- physical lines, which is not valid Haskell; they are written as
          -- escapes here. Given the text/html content type they presumably
          -- once held HTML markup that has been lost - confirm and restore.
          BSL.concat
          [ ""
          , e
          , "\n\n"
          , e
          , "\n\n"
          ]
    }

-- | Initial response state of every action: no headers, status 200, empty body.
defResponse :: ResponseState
defResponse =
    ResponseState
    { rs_responseHeaders = HM.empty
    , rs_multiResponseHeaders = HM.empty
    , rs_status = status200
    , rs_responseBody =
          ResponseBody $ \status headers ->
          Wai.responseLBS status headers $ BSL.empty
    }

type SpockAllT n m a
    = RegistryT (ActionT n) () Wai.Middleware SpockMethod m a

-- | Turn a middleware into an application by running it on top of an
-- application that always answers 404.
middlewareToApp :: Wai.Middleware -> Wai.Application
middlewareToApp mw =
    mw fallbackApp
    where
      fallbackApp :: Wai.Application
      fallbackApp _ respond = respond notFound
      notFound = respStateToResponse $ errorResponse status404 "404 - File not found"

-- | Build the per-request environment.
--
-- Returns the 'RequestInfo' handed to actions, the 'TVar' holding the
-- request vault, and a clean-up action that deletes uploaded temp files (if
-- the upload map was ever computed).
--
-- The body, form parameters and uploads are wrapped in 'CacheVar's, so
-- nothing is read from the request body until an action asks for it.
--
-- Client-supplied bytes (query string, form fields, upload metadata) are
-- decoded leniently: invalid UTF-8 becomes U+FFFD instead of raising the
-- impure exception that 'T.decodeUtf8' throws.
makeActionEnvironment ::
    InternalState
    -> SpockMethod
    -> Wai.Request
    -> IO (RequestInfo (), TVar V.Vault, IO ())
makeActionEnvironment st stdMethod req =
    do vaultVar <- newTVarIO (Wai.vault req)
       let vaultIf =
               VaultIf
               { vi_modifyVault = atomically . modifyTVar' vaultVar
               , vi_lookupKey = \k -> V.lookup k <$> atomically (readTVar vaultVar)
               }
           getParams =
               map (\(k, mV) -> (decodeLenient k, decodeLenient $ fromMaybe BS.empty mV)) $
               Wai.queryString req
       rbValue <-
           newCacheVar $
           do let parseBody = Wai.requestBody req
                  bodyLength = Wai.requestBodyLength req
                  -- initial buffer size: the announced length when known
                  buffStart =
                      case bodyLength of
                        Wai.ChunkedBody -> 1024
                        Wai.KnownLength x -> fromIntegral x
              SB.withBuffer buffStart $ \sb ->
                  do let loop =
                             do b <- parseBody
                                -- an empty chunk marks the end of the body
                                if BS.null b
                                then pure ()
                                else (SB.appendBuffer sb b >> loop)
                     loop
       bodyTuple <-
           newCacheVar $
           case P.getRequestBodyType req of
             Nothing -> return ([], HM.empty)
             Just rbt ->
                 do bodyBs <- loadCacheVar rbValue
                    bodyRef <- newIORef (Just bodyBs)
                    -- hands out the buffered body once, then empty chunks
                    let loader =
                            do mb <- readIORef bodyRef
                               case mb of
                                 Just b -> writeIORef bodyRef Nothing >> pure b
                                 Nothing -> pure BS.empty
                    (bodyParams, bodyFiles) <-
                        P.sinkRequestBody (P.tempFileBackEnd st) rbt loader
                    let uploadedFiles =
                            HM.fromList $
                            flip map bodyFiles $ \(k, fileInfo) ->
                            ( decodeLenient k
                            , UploadedFile
                                  (decodeLenient $ P.fileName fileInfo)
                                  (decodeLenient $ P.fileContentType fileInfo)
                                  (P.fileContent fileInfo)
                            )
                        postParams =
                            map (decodeLenient *** decodeLenient) bodyParams
                    return (postParams, uploadedFiles)
       let reqBody =
               RequestBody
               { rb_value = rbValue
               , rb_files = fmap snd bodyTuple
               , rb_postParams = fmap fst bodyTuple
               }
       return
           ( RequestInfo
             { ri_method = stdMethod
             , ri_request = req
             , ri_getParams = getParams
             , ri_reqBody = reqBody
             , ri_vaultIf = vaultIf
             , ri_context = ()
             }
           , vaultVar
           , removeUploadedFiles (rb_files reqBody)
           )
    where
      -- Replace every invalid byte sequence with U+FFFD rather than throwing.
      decodeLenient :: BS.ByteString -> T.Text
      decodeLenient = T.decodeUtf8With (\_ _ -> Just '\xFFFD')

-- | Delete the temp files of all uploads, provided the upload map has
-- already been computed; otherwise do nothing. Files that no longer exist
-- are skipped.
removeUploadedFiles :: CacheVar (HM.HashMap k UploadedFile) -> IO ()
removeUploadedFiles uploadedFilesRef =
    do cvals <- loadCacheVarOpt uploadedFilesRef
       case cvals of
         Nothing -> return ()
         Just uploadedFiles ->
             forM_ (HM.elems uploadedFiles) $ \uploadedFile ->
             do stillThere <- doesFileExist (uf_tempLocation uploadedFile)
                when stillThere $ removeFile (uf_tempLocation uploadedFile)

-- | Run the candidate actions in order until one produces a result.
--
-- Returns 'Nothing' when an action asks to pass the request on
-- ('ActionMiddlewarePass'); otherwise the 'ResponseVal' to send. An empty
-- candidate list yields the configured 404 handler, 'ActionError' is logged
-- to stdout and yields the configured 500 handler.
applyAction :: MonadIO m
            => SpockConfigInternal
            -> Wai.Request
            -> RequestInfo ()
            -> [ActionT m ()]
            -> m (Maybe ResponseVal)
applyAction config _ _ [] =
    return $ Just $ getErrorHandler config status404
applyAction config req env (selectedAction : xs) =
    do (r, respState, _) <-
           runRWST (runErrorT $ runActionCtxT selectedAction) env defResponse
       case r of
         Left (ActionRedirect loc) ->
             return $ Just $ ResponseValState $
             respState
             { rs_status = status302
             , rs_responseBody =
                   ResponseBody $ \status headers ->
                   Wai.responseLBS status (("Location", T.encodeUtf8 loc) : headers) BSL.empty
             }
         Left ActionTryNext ->
             applyAction config req env xs
         Left (ActionError errorMsg) ->
             do liftIO $ putStrLn $
                    "Spock Error while handling "
                    ++ show (Wai.pathInfo req) ++ ": " ++ errorMsg
                return $ Just $ getErrorHandler config status500
         Left ActionDone ->
             return $ Just (ResponseValState respState)
         Left ActionMiddlewarePass ->
             return Nothing
         Left (ActionApplication app) ->
             return $ Just (ResponseHandler app)
         Left (ActionMiddleware getMiddleware) ->
             return $ Just $ ResponseHandler $
             -- the middleware is run on top of the configured 404 handler
             do errHandler <- sci_errorHandler config status404
                mw <- getMiddleware
                return $ mw errHandler
         Right () ->
             return $ Just (ResponseValState respState)
-- | Entry point for one request: installs the body size check when a limit
-- is configured, then delegates to 'handleRequest''.
handleRequest
    :: MonadIO m
    => SpockConfigInternal
    -> SpockMethod
    -> (forall a. m a -> IO a)
    -> [ActionT m ()]
    -> InternalState
    -> Wai.Application -> Wai.Application
handleRequest config stdMethod registryLift allActions st coreApp req respond =
    do reqGo <-
           case sci_maxRequestSize config of
             Nothing -> return req
             Just lim -> requestSizeCheck lim req
       handleRequest' config stdMethod registryLift allActions st coreApp reqGo respond

-- | Build the action environment, run the matching actions and send the
-- outcome.
--
-- * 'SizeException' at any stage is answered through the configured 413
--   handler.
-- * Any other exception raised while running the actions is logged to
--   stdout and answered through the configured 500 handler.
-- * When the actions pass ('Nothing'), the wrapped application is called
--   with the vault contents accumulated during handling merged in.
handleRequest' ::
    MonadIO m
    => SpockConfigInternal
    -> SpockMethod
    -> (forall a. m a -> IO a)
    -> [ActionT m ()]
    -> InternalState
    -> Wai.Application -> Wai.Application
handleRequest' config stdMethod registryLift allActions st coreApp req respond =
    do actEnv <-
           (Left <$> makeActionEnvironment st stdMethod req)
               `catch` \(_ :: SizeException) ->
                   return (Right $ getErrorHandler config status413)
       case actEnv of
         Left (mkEnv, vaultVar, cleanUp) ->
             do mRespState <-
                    registryLift (applyAction config req mkEnv allActions) `catches`
                      [ Handler $ \(_ :: SizeException) ->
                            return (Just $ getErrorHandler config status413)
                      , Handler $ \(e :: SomeException) ->
                            do putStrLn $
                                   "Spock Error while handling "
                                   ++ show (Wai.pathInfo req) ++ ": " ++ show e
                               return $ Just $ getErrorHandler config status500
                      ]
                cleanUp
                case mRespState of
                  Just respVal ->
                      sendResponseVal respVal
                  Nothing ->
                      do newVault <- atomically $ readTVar vaultVar
                         let req' = req { Wai.vault = V.union newVault (Wai.vault req) }
                         coreApp req' respond
         Right respVal ->
             -- 'getErrorHandler' always yields a 'ResponseHandler'; it must be
             -- run as an application, because 'respStateToResponse' calls
             -- 'error' on that constructor.
             sendResponseVal respVal
    where
      -- Send a 'ResponseVal': run handler applications against the current
      -- request, render plain response states directly.
      sendResponseVal :: ResponseVal -> IO Wai.ResponseReceived
      sendResponseVal val =
          case val of
            ResponseHandler responseHandler ->
                responseHandler >>= \app -> app req respond
            ResponseValState _ ->
                respond $ respStateToResponse val

-- | The configured error handler for a status, wrapped as a 'ResponseVal'.
getErrorHandler :: SpockConfigInternal -> Status -> ResponseVal
getErrorHandler config = ResponseHandler . sci_errorHandler config

-- | Thrown by the body reader installed by 'requestSizeCheck' once the
-- accumulated body size exceeds the limit.
data SizeException
    = SizeException
    deriving (Show, Typeable)

instance Exception SizeException

-- | Wrap the body reader of a request so that it counts the bytes handed
-- out and throws 'SizeException' as soon as the running total is larger
-- than @maxSize@. Nothing is read from the body here.
requestSizeCheck :: Word64 -> Wai.Request -> IO Wai.Request
requestSizeCheck maxSize req =
    do currentSize <- newIORef 0
       return $ req
           { Wai.requestBody =
                 do bs <- Wai.requestBody req
                    total <-
                        atomicModifyIORef currentSize $ \sz ->
                        let !nextSize = sz + fromIntegral (BS.length bs)
                        in (nextSize, nextSize)
                    if total > maxSize
                    then throwIO SizeException
                    else return bs
           }

-- | Run the registry once and assemble the resulting middleware: the
-- registered middlewares composed in registration order, wrapped around the
-- route-dispatching application. Each request is handled inside its own
-- 'runResourceT' scope.
buildMiddleware :: forall m. (MonadIO m)
                => SpockConfigInternal
                -> (forall a. m a -> IO a)
                -> SpockAllT m m ()
                -> IO Wai.Middleware
buildMiddleware config registryLift spockActions =
    do (_, getMatchingRoutes, middlewares) <-
           registryLift $ runRegistry spockActions
       let spockMiddleware = foldl (.) id middlewares
           app :: Wai.Application -> Wai.Application
           app coreApp req respond =
               withSpockMethod (Wai.requestMethod req) $ \method ->
               do let allActions = getMatchingRoutes method (Wai.pathInfo req)
                  runResourceT $
                      withInternalState $ \st ->
                      handleRequest config method registryLift allActions st coreApp req respond
       return $ spockMiddleware . app

-- | Classify a raw request method and pass it to the continuation:
-- standard verbs become 'MethodStandard', everything else 'MethodCustom'.
--
-- The method bytes come from the client, so custom methods are decoded
-- leniently (invalid UTF-8 becomes U+FFFD) instead of with the throwing
-- 'T.decodeUtf8'.
withSpockMethod :: forall t. Method -> (SpockMethod -> t) -> t
withSpockMethod method cnt =
    case parseMethod method of
      Left _ ->
          cnt (MethodCustom $ T.decodeUtf8With (\_ _ -> Just '\xFFFD') method)
      Right stdMethod ->
          cnt (MethodStandard $ HttpMethod stdMethod)