--------------------------------------------------------------------------------
-- |
-- Module      :  Happstack.Auth.Internal
-- Copyright   :  (c) Nils Schweinsberg 2010
-- License     :  BSD3 (see LICENSE file)
--
-- Maintainer  :  mail@n-sch.de
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Internal representation of state functions.
--
--------------------------------------------------------------------------------

{-# LANGUAGE TemplateHaskell, TypeSynonymInstances, MultiParamTypeClasses,
             FlexibleContexts, FlexibleInstances, DeriveDataTypeable #-}
{-# OPTIONS -fno-warn-orphans #-}

module Happstack.Auth.Internal
    ( buildSaltAndHash
    , AskUsers (..)
    , AddUser (..)
    , GetUser (..)
    , GetUserById (..)
    , DelUser (..)
    , AuthUser (..)
    , IsUser (..)
    , ListUsers (..)
    , NumUsers (..)
    , UpdateUser (..)
    , SetPassword (..)
    , ChangePassword (..)
    , ClearAllSessions (..)
    , SetSession (..)
    , GetSession (..)
    , GetSessions (..)
    , NewSession (..)
    , DelSession (..)
    , NumSessions (..)
    , ClearExpiredSessions (..)
    , UpdateTimeout (..)
    ) where

import Control.Monad.Reader
import Control.Monad.State (modify,get,gets)
import Data.Bits ((.|.), xor)
import Data.List (foldl')
import Data.Maybe
import Numeric
import System.Random
import qualified Data.Map as M

import Codec.Utils (Octet, listToOctets)
import Data.ByteString.Internal
import Data.Digest.SHA512 (hash)

import Happstack.Data.IxSet hiding (null)
import Happstack.State
import Happstack.State.ClockTime

import Happstack.Auth.Internal.Data hiding (Username, User, SessionData)
import qualified Happstack.Auth.Internal.Data as D

--------------------------------------------------------------------------------
-- Password generation

-- | Number of random hexadecimal digits in a freshly generated salt.
saltLength :: Num t => t
saltLength = 16

-- | Convert a 'String' to octets via 'c2w' and 'listToOctets'.
--
-- NOTE(review): 'c2w' keeps only the low 8 bits of each 'Char', so characters
-- outside Latin-1 are truncated before hashing. Changing this would
-- invalidate every stored hash, so it is deliberately left as is.
strToOctets :: String -> [Octet]
strToOctets = listToOctets . map c2w

-- | Apply SHA-512 to the input 512 times in succession (presumably to make
-- brute-force guessing more expensive than a single hash).
slowHash :: [Octet] -> [Octet]
slowHash a = iterate hash a !! 512

-- | Generate a salt consisting of 'saltLength' random hexadecimal digits.
--
-- NOTE(review): 'randomRIO' draws from the global 'StdGen', which is not a
-- cryptographically secure generator.
randomSalt :: IO String
randomSalt = liftM concat $ replicateM saltLength randomHexDigit
  where
    -- One value in [0,15], rendered as a single hex digit.
    randomHexDigit = liftM (`showHex` "") (randomRIO (0 :: Int, 15))

-- | Hash a password with a fresh random salt.
--
-- The resulting 'SaltedHash' holds the salt octets followed by
-- @slowHash (salt ++ password)@. Returns 'Nothing' for an empty password.
buildSaltAndHash :: String -> IO (Maybe SaltedHash)
buildSaltAndHash str
    | null str  = return Nothing
    | otherwise = do
        salt <- randomSalt
        let salt' = strToOctets salt
            str'  = strToOctets str
            h     = slowHash (salt' ++ str')
        return . Just $ SaltedHash $ salt' ++ h

-- | Check a plain-text password against a stored 'SaltedHash'.
--
-- The salt is taken to be the first 'saltLength' octets of the stored value,
-- matching the layout produced by 'buildSaltAndHash'.
checkSalt :: String -> SaltedHash -> Bool
checkSalt str (SaltedHash stored) =
    stored `eqOctets` (salt ++ slowHash (salt ++ strToOctets str))
  where
    salt = take saltLength stored

-- | Equality on octet strings that inspects every byte instead of stopping at
-- the first mismatch (as '==' on lists does). This way the comparison time
-- does not reveal how long a matching prefix is.
eqOctets :: [Octet] -> [Octet] -> Bool
eqOctets a b = length a == length b && foldl' (.|.) 0 (zipWith xor a b) == 0

--------------------------------------------------------------------------------
-- State functions: Users

-- | The whole user database.
askUsers :: Query AuthState UserDB
askUsers = asks users

-- | Look up a user by name. 'Nothing' unless exactly one user matches.
getUser :: D.Username -> Query AuthState (Maybe D.User)
getUser un = do
    udb <- askUsers
    return $ getOne $ udb @= un

-- | Look up a user by id. 'Nothing' unless exactly one user matches.
getUserById :: D.UserId -> Query AuthState (Maybe D.User)
getUserById uid = do
    udb <- askUsers
    return $ getOne $ udb @= uid

-- | Apply a function to the user database, leaving sessions and the uid
-- counter untouched.
modUsers :: (UserDB -> UserDB) -> Update AuthState ()
modUsers f = modify (\s -> (AuthState (sessions s) (f $ users s) (nextUid s)))

-- | Return the next free user id and advance the counter by one.
getAndIncUid :: Update AuthState D.UserId
getAndIncUid = do
    uid <- gets nextUid
    modify (\s -> (AuthState (sessions s) (users s) (uid+1)))
    return uid

-- | Whether exactly one user with the given name exists.
isUser :: D.Username -> Query AuthState Bool
isUser name = do
    us <- askUsers
    return $ isJust $ getOne $ us @= name

-- | Create a new user with a fresh id.
--
-- Returns 'Nothing' if the name is empty or a user with that name already
-- exists. Otherwise it returns the inserted user.
addUser :: D.Username -> SaltedHash -> Update AuthState (Maybe D.User)
addUser name pass
    | null (unUser name) = return Nothing
    | otherwise = do
        s <- get
        let exists = isJust $ getOne $ (users s) @= name
        if exists
          then return Nothing
          else do
            u <- newUser name pass
            modUsers $ insert u
            return $ Just u
  where
    newUser u p = do
        uid <- getAndIncUid
        return $ D.User uid u p

-- | Remove the user with the given name, if exactly one matches.
delUser :: D.Username -> Update AuthState ()
delUser name = modUsers del
  where
    del db = case getOne (db @= name) of
               Just u  -> delete u db
               Nothing -> db

-- | Replace the stored user that has the same 'userid' as the given one.
updateUser :: D.User -> Update AuthState ()
updateUser u = modUsers (updateIx (userid u) u)

-- | Return the user if the name exists and the password matches its stored
-- salted hash, otherwise 'Nothing'.
authUser :: String -> String -> Query AuthState (Maybe D.User)
authUser name pass = do
    udb <- askUsers
    let u = getOne $ udb @= (D.Username name)
    case u of
      Just v  -> return $ if checkSalt pass (userpass v) then u else Nothing
      Nothing -> return Nothing

-- | Names of all registered users.
listUsers :: Query AuthState [D.Username]
listUsers = do
    udb <- askUsers
    return $ map username $ toList udb

-- | Number of registered users.
numUsers :: Query AuthState Int
numUsers = liftM length listUsers

-- | Overwrite a user's password hash without checking the old password.
-- Returns whether the user was found.
setPassword :: D.Username -> SaltedHash -> Update AuthState Bool
setPassword un h = do
    mu <- runQuery $ getUser un
    case mu of
      Just u -> do
        updateUser u { userpass = h }
        return True
      _ -> return False

-- | Replace a user's password hash after verifying the old password.
-- Returns whether authentication (and therefore the change) succeeded.
changePassword :: String      -- ^ Username
               -> String      -- ^ Old password
               -> SaltedHash  -- ^ New password
               -> Update AuthState Bool
changePassword un op s = do
    mu <- runQuery $ authUser un op
    case mu of
      Just u -> do
        updateUser u { userpass = s }
        return True
      _ -> return False

--------------------------------------------------------------------------------
-- State functions: Sessions

-- | All stored sessions.
askSessions :: Query AuthState (Sessions D.SessionData)
askSessions = asks sessions

-- | Apply a function to the session store, leaving users and the uid counter
-- untouched.
modSessions :: (Sessions D.SessionData -> Sessions D.SessionData)
            -> Update AuthState ()
modSessions f = modify (\s -> (AuthState (f $ sessions s) (users s) (nextUid s)))

-- | Store session data under a key, replacing any existing entry.
setSession :: SessionKey -> D.SessionData -> Update AuthState ()
setSession key u = modSessions $ Sessions . M.insert key u . unsession

-- | Store session data under a fresh random key and return that key.
--
-- Keys already in use are rejected and a new one is drawn. Without this
-- check, a collision would silently overwrite another live session and hand
-- its holder this session's data.
newSession :: D.SessionData -> Update AuthState SessionKey
newSession u = do
    key   <- getRandom
    taken <- gets (M.member key . unsession . sessions)
    if taken
      then newSession u
      else do
        setSession key u
        return key

-- | Remove the session stored under the given key, if any.
delSession :: SessionKey -> Update AuthState ()
delSession key = modSessions $ Sessions . M.delete key . unsession

-- | Remove every session.
clearAllSessions :: Update AuthState ()
clearAllSessions = modSessions $ const (Sessions M.empty)

-- | Session data stored under the given key, if any.
getSession :: SessionKey -> Query AuthState (Maybe D.SessionData)
getSession key = liftM (M.lookup key . unsession) askSessions

-- | All stored sessions.
getSessions :: Query AuthState (Sessions D.SessionData)
getSessions = askSessions

-- | Number of stored sessions.
numSessions :: Query AuthState Int
numSessions = liftM (M.size . unsession) askSessions

-- | Keep only sessions whose 'sesTimeout' is strictly later than the given
-- time. All others are dropped.
clearExpiredSessions :: ClockTime -> Update AuthState ()
clearExpiredSessions c =
    modSessions $ Sessions . M.filter ((c <) . sesTimeout) . unsession

-- | Set the timeout of the session under the given key. Does nothing if no
-- such session exists.
updateTimeout :: SessionKey -> ClockTime -> Update AuthState ()
updateTimeout sid c =
    modSessions $ Sessions . M.adjust (\sd -> sd { sesTimeout = c }) sid . unsession

--------------------------------------------------------------------------------
-- Generate Methods

$(mkMethods ''AuthState
    [ 'askUsers
    , 'addUser
    , 'getUser
    , 'getUserById
    , 'delUser
    , 'authUser
    , 'isUser
    , 'listUsers
    , 'numUsers
    , 'updateUser
    , 'setPassword
    , 'changePassword

    , 'clearAllSessions
    , 'setSession
    , 'getSession
    , 'getSessions
    , 'newSession
    , 'delSession
    , 'numSessions
    , 'clearExpiredSessions
    , 'updateTimeout
    ])