{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Handler.Snap
    ( run
    ) where

import Control.Applicative ((<$>))
import Control.Arrow (first, (***))
import Control.Monad.IO.Class
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import qualified Data.Map as Map
import Data.Monoid (mempty)
import System.IO (hPutStr, stderr)

import Data.CIByteString
import qualified Data.Enumerator as E
import qualified Network.Wai as W
import Snap.Http.Server
import Snap.Types

-- | Serve a WAI 'W.Application' with Snap's HTTP server.
--
-- A single plain-HTTP listener is configured on the given port with bind
-- address @\"*\"@, on top of an otherwise default ('mempty') config.
-- Every request is handled by the application adapted with 'waiToSnap'.
run :: Int -> W.Application -> IO ()
run port app = httpServe config (waiToSnap app)
  where
    config = addListen (ListenHttp "*" port) mempty

-- | Run a WAI 'W.Application' as a Snap handler.
--
-- The Snap request and its body are fetched first; the body is obtained
-- with 'getRequestBody' as a lazy ByteString and handed to the
-- application via 'toWaiRequest'.  The WAI response's status and headers
-- are then copied onto the Snap response, and its body is emitted
-- according to its kind.
waiToSnap :: W.Application -> Snap ()
waiToSnap wapp = do
    snapReq  <- getRequest
    body     <- getRequestBody
    response <- liftIO (wapp (toWaiRequest body snapReq))
    modifyResponse (toSnapResponse response)
    sendWaiBody (W.responseBody response)
  where
    -- Files go through Snap's 'sendFile'; enumerators become the response
    -- body directly; lazy ByteStrings are appended with 'writeLBS'.
    sendWaiBody (W.ResponseFile path) = sendFile path
    sendWaiBody (W.ResponseEnumerator e) =
        modifyResponse (setResponseBody (toSnapEnum e))
    sendWaiBody (W.ResponseLBS lbs) = writeLBS lbs

-- | Build a WAI 'W.Request' from a Snap 'Request' and its request body.
--
-- * The method is the 'show'n form of Snap's 'rqMethod'.
-- * HTTP 0.9, 1.0 and 1.1 map onto WAI's predefined version constants;
--   any other version is rendered as @\"major.minor\"@.
-- * A leading @/@ is prepended to 'rqPathInfo' (presumably because Snap
--   stores it without one — NOTE(review): confirm for non-root contexts).
-- * Headers are flattened with 'toReqHeaders'; the body is exposed as a
--   'W.Source' over @reqBody@ via 'bsToSource'.
-- * 'W.errorHandler' writes the message to stderr instead of throwing, so
--   an application that reports a diagnostic does not abort its own
--   request with an exception.
toWaiRequest :: L.ByteString -> Request -> W.Request
toWaiRequest reqBody req = W.Request
  {  W.requestMethod  = S8.pack $ show $ rqMethod req
  ,  W.httpVersion    = case rqVersion req of
                            (0, 9) -> W.http09
                            (1, 0) -> W.http10
                            (1, 1) -> W.http11
                            (x, y) -> S8.pack
                                    $ show x ++ "." ++ show y
  ,  W.pathInfo       = S8.cons '/' $ rqPathInfo req
  ,  W.queryString    = rqQueryString req
  ,  W.serverName     = rqServerName req
  ,  W.serverPort     = rqServerPort req
  ,  W.requestHeaders = toReqHeaders $ headers req
  ,  W.isSecure       = rqIsSecure req
  ,  W.requestBody    = bsToSource reqBody
     -- Log rather than 'error': the handler is for reporting, not aborting.
  ,  W.errorHandler   = hPutStr stderr
  ,  W.remoteHost     = rqRemoteAddr req
  }

-- | Flatten Snap's header map (one key per header name, holding a list of
-- values) into WAI's association list, producing one pair per value.
--
-- Names appear in 'Map.toList' order, and the values of a single name keep
-- their list order.  Each key is converted to WAI's case-insensitive
-- representation once and shared by all of its values.
toReqHeaders :: Map.Map CIByteString [S8.ByteString]
             -> [(W.RequestHeader, S8.ByteString)]
toReqHeaders hdrs =
    [ (key, value)
    | (name, values) <- Map.toList hdrs
    , let key = W.mkCIByteString (unCI name)
    , value <- values
    ]

-- | Expose an already-obtained lazy 'L.ByteString' as a WAI 'W.Source'.
--
-- Unfortunately, WAI's 'W.Source' is not compatible with iteratees
-- (IterateeG), so the body cannot be streamed through directly; instead
-- each pull yields the next strict chunk of the input paired with the
-- source for the remainder, and 'Nothing' once the chunks run out.
bsToSource :: L.ByteString -> W.Source
bsToSource = foldr cons nil . L.toChunks
  where
    nil = W.Source (return Nothing)
    cons chunk rest = W.Source (return (Just (chunk, rest)))

-- | Copy a WAI response's status and headers onto a Snap 'Response'.
--
-- The Snap response's existing headers are discarded and replaced by the
-- WAI headers.  WAI gives headers as an association list in which a name
-- may repeat (e.g. several @Set-Cookie@ headers), while Snap keeps a map
-- from 'CIByteString' name to a list of values.  Repeated names are
-- therefore merged with 'Map.fromListWith', concatenating their values in
-- the original order, instead of letting later values overwrite earlier
-- ones as 'Map.fromList' would.
toSnapResponse :: W.Response -> Response -> Response
toSnapResponse wres =
    setResponseStatus (W.statusCode st) (W.statusMessage st)
  . updateHeaders (const newHeaders)
  where
    st = W.status wres
    -- 'fromListWith' calls its combiner as @f new old@; 'flip' keeps the
    -- earlier values first.
    newHeaders = Map.fromListWith (flip (++))
               $ map (go *** return) $ W.responseHeaders wres
    go = toCI . W.ciOriginal

-- | Adapt a WAI 'W.Enumerator' (a fold over response-body chunks) into a
-- "Data.Enumerator" enumerator usable as a Snap response body.
--
-- The fold's accumulator is the iteratee 'E.Step'.  While the step is
-- 'E.Continue', each chunk is fed to it and the resulting step is carried
-- forward ('Right').  Any other step stops the fold early ('Left').  Either
-- way, the final step is wrapped back up as an 'E.Iteratee'.
toSnapEnum :: W.Enumerator -> Enumerator S8.ByteString IO a
toSnapEnum (W.Enumerator enum) step0 =
    E.Iteratee (collapse <$> enum feed step0)
  where
    collapse (Left step)  = step
    collapse (Right step) = step

    feed (E.Continue k) chunk =
        Right <$> E.runIteratee (k (E.Chunks [chunk]))
    feed step _ = return (Left step)