module Web.Apiary.ClientSession.Internal where
import Language.Haskell.TH
import System.Directory
import "crypto-random" Crypto.Random
import Crypto.Random.AESCtr
import Control.Monad
import Control.Applicative
import Control.Monad.Apiary.Filter
import Control.Monad.Apiary.Action
import Web.Apiary.Wai
import Web.Apiary
import Web.Apiary.Cookie
import Web.ClientSession
import qualified Network.HTTP.Types as HTTP
import Data.Apiary.Compat
import Data.Apiary.Param
import Data.String
import Data.Maybe
import Data.Monoid
import Data.Time
import Data.Default.Class
import Data.Binary
import Data.Bits (xor, (.|.))
import Data.IORef
import Text.Blaze.Html
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as SC
import qualified Data.ByteString.Lazy as L
-- | Runtime state shared by the session helpers in this module.
data Session = Session
    { sessionKey :: Key
      -- ^ Key used to encrypt cookie values ('mkSessionCookie') and to
      -- decrypt them again ('getSessionValue').
    , tokenGen :: IORef AESRNG
      -- ^ Random generator that 'newToken' draws CSRF token bytes from.
    , sessionConfig :: SessionConfig
      -- ^ The configuration this session was built from.
    }
-- | Where the encryption key of a 'Session' comes from.
data KeySource
    = KeyFile FilePath
      -- ^ A file path; 'makeSession' passes it to 'getKey'.
    | KeyByteString S.ByteString
      -- ^ Raw key bytes; 'makeSession' parses them with 'initKey'.
-- | A string literal used as a 'KeySource' is treated as the raw key bytes
-- ('KeyByteString'), not as a file path.
instance IsString KeySource where
    fromString str = KeyByteString (fromString str)
-- | Template Haskell helper that embeds a session key into the program.
--
-- At compile time the key is read from @keyfile@. If the file is missing, or
-- its contents are rejected by 'initKey', a fresh key is generated with
-- 'randomKey' and written to @keyfile@ (overwriting whatever was there).
--
-- The splice expands to @def { sessionKeySource = KeyByteString ... }@.
embedKeyConfig :: FilePath -> ExpQ
embedKeyConfig keyfile = do
    bs <- runIO $ do
        exists <- doesFileExist keyfile
        if exists
            then do
                b <- S.readFile keyfile
                case initKey b of
                    -- Stored bytes are not a usable key: replace them.
                    Left _ -> newKey
                    Right _ -> return b
            else newKey
    let s = stringE $ SC.unpack bs
    -- The literal is packed explicitly so the generated expression
    -- type-checks whether or not the module containing the splice enables
    -- OverloadedStrings; a bare literal would only be a ByteString there
    -- when that extension is on.
    [| def { sessionKeySource = KeyByteString (SC.pack $s) } |]
  where
    -- Generate a new key and persist its serialized form to the key file.
    newKey = do
        (bs, _) <- randomKey
        S.writeFile keyfile bs
        return bs
-- | 'embedKeyConfig' applied to 'defaultKeyFile'.
embedDefaultKeyConfig :: ExpQ
embedDefaultKeyConfig = embedKeyConfig defaultKeyFile
-- | Settings for session cookies and CSRF tokens.
--
-- NOTE: the 'Default' instance applies this constructor positionally, so the
-- field order is part of the contract.
data SessionConfig = SessionConfig
    { sessionKeySource :: KeySource
      -- ^ Where 'makeSession' obtains the encryption key.
    , sessionMaxAge :: DiffTime
      -- ^ Cookie lifetime; used for both the Max-Age attribute and the expiry
      -- time stored inside the encrypted value ('mkSessionCookie').
    , sessionPath :: Maybe S.ByteString
      -- ^ Copied to the cookie's path attribute.
    , sessionDomain :: Maybe S.ByteString
      -- ^ Copied to the cookie's domain attribute.
    , sessionHttpOnly :: Bool
      -- ^ Copied to the cookie's HttpOnly attribute.
    , sessionSecure :: Bool
      -- ^ Copied to the cookie's Secure attribute.
    , sessionTimeoutAction :: forall exts prms m. MonadIO m => ActionT exts prms m ()
      -- ^ Run by 'checkToken' when the submitted token differs from the one
      -- stored in the cookie.
    , angularXsrfCookieName :: Maybe S.ByteString
      -- ^ When set, 'csrfToken' also emits a second cookie with this name that
      -- carries the unencrypted token and has HttpOnly turned off.
    , csrfTokenCookieName :: S.ByteString
      -- ^ Name of the cookie that holds the encrypted CSRF token.
    , csrfTokenCheckingName :: Either HTTP.HeaderName S.ByteString
      -- ^ Where 'checkToken' looks for the submitted token: 'Left' is a
      -- request header name, 'Right' is a request parameter name.
    , csrfTokenLength :: Int
      -- ^ Length handed to 'newToken', which passes it to 'withRandomBytes'
      -- before Base64-encoding the result.
    }
-- | Failure response used as the default 'sessionTimeoutAction': after
-- 'reset', it sets status 401, writes the body @"session timeout\\n"@ and
-- ends the action with 'stop'.
defaultCheckTokenFailAction :: Monad actM => ActionT exts prms actM ()
defaultCheckTokenFailAction =
    reset
        >> status status401
        >> bytes "session timeout\n"
        >> stop
-- | Defaults: key read from 'defaultKeyFile', cookies valid for one day,
-- HttpOnly and Secure both on, and the CSRF token kept in a cookie and a
-- request parameter both named @_token@.
instance Default SessionConfig where
    def = SessionConfig
        { sessionKeySource = KeyFile defaultKeyFile
        , sessionMaxAge = 24 * 60 * 60
        , sessionPath = Nothing
        , sessionDomain = Nothing
        , sessionHttpOnly = True
        , sessionSecure = True
        , sessionTimeoutAction = defaultCheckTokenFailAction
        , angularXsrfCookieName = Nothing
        , csrfTokenCookieName = "_token"
        , csrfTokenCheckingName = Right "_token"
        , csrfTokenLength = 40
        }
-- | Build a 'Session' from its configuration.
--
-- The key is loaded with 'getKey' for a 'KeyFile' source, or parsed with
-- 'initKey' for a 'KeyByteString' source (a parse error is raised through
-- 'fail'). A fresh generator from 'makeSystem' is stored for token creation.
makeSession :: MonadIO m => SessionConfig -> m Session
makeSession cfg = liftIO $ do
    key <- loadKey (sessionKeySource cfg)
    gen <- newIORef =<< makeSystem
    return (Session key gen cfg)
  where
    loadKey (KeyFile path) = getKey path
    loadKey (KeyByteString raw) = either fail return (initKey raw)
-- | Wrapper that gives 'UTCTime' a 'Binary' instance.
newtype BinUTCTime = BinUTCTime { getUTCTime :: UTCTime }

-- | Serialized as the modified Julian day number followed by the time of day
-- converted to a 'Rational'.
instance Binary BinUTCTime where
    put (BinUTCTime t) = put dayNumber >> put dayTime
      where
        dayNumber = toModifiedJulianDay (utctDay t)
        dayTime = toRational (utctDayTime t)
    get = BinUTCTime <$> (UTCTime <$> day <*> time)
      where
        day = ModifiedJulianDay <$> get
        time = fromRational <$> get
-- | Build the cookie named @k@ that carries @v@.
--
-- The cookie value is the encryption (under @skey@) of the binary encoding of
-- the pair (expiry time, @v@), where the expiry is the current time plus
-- 'sessionMaxAge'. The remaining cookie attributes are copied from @conf@.
mkSessionCookie :: SessionConfig -> Key -> S.ByteString -> S.ByteString -> IO SetCookie
mkSessionCookie conf skey k v = do
    now <- getCurrentTime
    let lifetime = sessionMaxAge conf
        expire = addUTCTime (realToFrac lifetime) now
        payload = L.toStrict (encode (BinUTCTime expire, v))
    sealed <- encryptIO skey payload
    return def
        { setCookieName = k
        , setCookieValue = sealed
        , setCookiePath = sessionPath conf
        , setCookieExpires = Just expire
        , setCookieMaxAge = Just lifetime
        , setCookieDomain = sessionDomain conf
        , setCookieHttpOnly = sessionHttpOnly conf
        , setCookieSecure = sessionSecure conf
        }
-- | Recover the plain value from an encrypted cookie value.
--
-- Yields 'Nothing' when decryption fails, when the decrypted bytes cannot be
-- decoded, or when the embedded expiry time is not later than the supplied
-- current time @c@.
getSessionValue :: Session -> UTCTime
                -> S.ByteString
                -> Maybe S.ByteString
getSessionValue Session{sessionKey = k} c s = do
    plain <- decrypt k s
    case decodeOrFail (L.fromStrict plain) of
        Right (_, _, (BinUTCTime expiry, v))
            | c < expiry -> Just v
            | otherwise -> Nothing
        Left _ -> Nothing
-- | Store @v@ under cookie name @k@: the cookie is built by 'mkSessionCookie'
-- from the session's configuration and key, then added to the response with
-- 'setCookie'.
setSession :: MonadIO m => Session -> S.ByteString -> S.ByteString -> ActionT exts prms m ()
setSession sess k v =
    liftIO (mkSessionCookie (sessionConfig sess) (sessionKey sess) k v) >>= setCookie
-- | Draw @len@ random bytes from the shared generator and return them
-- Base64-encoded. The generator state is advanced atomically.
newToken :: Int -> IORef AESRNG -> IO S.ByteString
newToken len gen = atomicModifyIORef' gen $ \rng ->
    -- 'withRandomBytes' yields (result, next generator), while
    -- 'atomicModifyIORef'' wants (new state, result), so flip the pair.
    case withRandomBytes rng len Base64.encode of
        (token, rng') -> (rng', token)
-- | Issue a new CSRF token.
--
-- A fresh token is generated and stored (encrypted) in the cookie named by
-- 'csrfTokenCookieName'. When 'angularXsrfCookieName' is set, a second cookie
-- with that name is emitted as well; it reuses the same attributes but holds
-- the plain token and has HttpOnly switched off. The plain token is returned.
csrfToken :: MonadIO m => Session -> ActionT exts prms m S.ByteString
csrfToken Session{sessionKey = key, tokenGen = gen, sessionConfig = cfg} = do
    tok <- liftIO $ newToken (csrfTokenLength cfg) gen
    sc <- liftIO $ mkSessionCookie cfg key (csrfTokenCookieName cfg) tok
    setCookie sc
    case angularXsrfCookieName cfg of
        Nothing -> return ()
        Just ngName -> setCookie sc
            { setCookieName = ngName
            , setCookieValue = tok
            , setCookieHttpOnly = False
            }
    return tok
-- | Filter that reads the session cookie named by the type-level symbol @k@.
--
-- Every request cookie with that name is decrypted and expiry-checked with
-- 'getSessionValue', converted with 'readQuery', and the resulting list is
-- handed to the strategy @p@.
session :: (MonadIO actM, Strategy w, Has Session exts, KnownSymbol k, NotMember k prms, Query a)
        => proxy k -> w a -> ApiaryT exts (SNext w k a prms) actM m () -> ApiaryT exts prms actM m ()
session k p = focus (DocPrecondition $ toHtml (symbolVal k) <> " session cookie required") $ do
    sess <- getExt (Proxy :: Proxy Session)
    now <- liftIO getCurrentTime
    req <- getRequest
    let wanted = SC.pack (symbolVal k)
        found =
            [ readQuery (getSessionValue sess now value)
            | (name, value) <- cookie' req
            , wanted == name
            ]
    strategy p k found =<< getParams
-- | Filter that only lets a request through when it carries a valid CSRF
-- token.
--
-- The expected token is taken from the cookie named by 'csrfTokenCookieName'
-- (decrypted and expiry-checked by 'getSessionValue'). The submitted token is
-- looked up according to 'csrfTokenCheckingName': a request header for
-- 'Left', a query/body parameter for 'Right'.
--
-- * If either token is missing (or the cookie is invalid or expired) the
--   filter fails with 'mzero'.
-- * If both are present but differ, 'sessionTimeoutAction' is run.
-- * If they match, the request proceeds.
checkToken :: (MonadIO actM, Has Session exts)
           => ApiaryT exts prms actM m ()
           -> ApiaryT exts prms actM m ()
checkToken = focus (DocPrecondition "CSRF token required") $ do
    sess@Session{..} <- getExt (Proxy :: Proxy Session)
    qs <- getQueryParams
    r <- getRequest
    p <- getReqBodyParams
    t <- liftIO getCurrentTime
    let stok = getSessionValue sess t =<<
            lookup (csrfTokenCookieName sessionConfig) (cookie' r)
    guard (isJust stok)
    qtok <- return . join $ case csrfTokenCheckingName sessionConfig of
        Right name -> lookup name $ reqParams pByteString qs p []
        Left name -> lookup name $ map (\(k, v) -> (k, Just v)) $ requestHeaders r
    guard (isJust qtok)
    if sameToken qtok stok
        then getParams
        else sessionTimeoutAction sessionConfig >> mzero
  where
    -- Both tokens must be present and byte-for-byte equal.
    sameToken :: Maybe S.ByteString -> Maybe S.ByteString -> Bool
    sameToken (Just a) (Just b) = constantTimeEq a b
    sameToken _ _ = False

    -- Equality without an early exit: every byte pair is XORed and the
    -- differences are OR-ed together, so the running time does not reveal
    -- the position of the first mismatch (plain '==' on ByteString does).
    -- Only the length, which is not secret, is compared up front.
    constantTimeEq :: S.ByteString -> S.ByteString -> Bool
    constantTimeEq a b =
        S.length a == S.length b
            && foldr (.|.) 0 (S.zipWith xor a b) == 0