{-# LANGUAGE UndecidableInstances #-}
module Servant.Server.Auth.Token.Persistent(
PersistentBackendT
, runPersistentBackendT
, liftDB
) where
import Control.Monad.Catch
import Control.Monad.Cont.Class (MonadCont(..))
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Control.Monad.RWS.Class (MonadRWS)
import Control.Monad.State.Class (MonadState(state))
import Control.Monad.Writer.Class (MonadWriter(..))
import Data.Aeson.WithField
import Database.Persist
import Database.Persist.Sql
import Servant.Server
import Servant.Server.Auth.Token.Config
import Servant.Server.Auth.Token.Model
import Servant.Server.Auth.Token.Monad
import qualified Servant.Server.Auth.Token.Persistent.Schema as S
-- | Internal monad stack wrapped by 'PersistentBackendT'.
--
-- A reader environment holding the auth configuration and the connection
-- pool, layered over a persistent SQL session ('SqlPersistT').
type PersistentBackendInternal m = ReaderT (AuthConfig, ConnectionPool) (SqlPersistT m)

-- | Monad transformer that implements the token server storage
-- ('HasStorage') on top of persistent. Run it with 'runPersistentBackendT'.
newtype PersistentBackendT m a = PersistentBackendT
  { unPersistentBackendT :: PersistentBackendInternal m a
  } deriving ( Functor
             , Applicative
             , Monad
             , MonadIO
             , MonadThrow
             , MonadCatch
             , MonadCont
             )
-- | Servant errors are raised and caught as ordinary exceptions through
-- "Control.Monad.Catch". This lets 'runPersistentBackendT', which catches
-- 'ServantErr', turn them into a 'Left' result.
instance MonadCatch m => MonadError ServantErr (PersistentBackendT m) where
  throwError err = throwM err
  catchError action handler = catch action handler
-- | The auth configuration is the first component of the backend's reader
-- environment.
instance Monad m => HasAuthConfig (PersistentBackendT m) where
  getAuthConfig = PersistentBackendT (fst <$> ask)
-- | Lift an action of the base monad through both reader layers
-- (configuration reader and SQL backend reader).
instance MonadTrans PersistentBackendT where
  lift action = PersistentBackendT (lift (lift action))
-- | Pass-through 'MonadReader' for the base monad's environment. The
-- backend's own configuration reader is not exposed here.
instance (MonadReader r m) => MonadReader r (PersistentBackendT m) where
  ask = PersistentBackendT (lift (lift ask))
  local f action = mapPersistentBackendT (local f) action
-- | Pass-through 'MonadState' to the base monad.
instance (MonadState s m) => MonadState s (PersistentBackendT m) where
  state f = lift (state f)
-- | Pass-through 'MonadWriter' to the base monad.
--
-- 'listen' and 'pass' reuse the instance of the internal stack, applied
-- under the newtype wrapper.
instance (MonadWriter w m) => MonadWriter w (PersistentBackendT m) where
  tell entry = lift (tell entry)
  writer pair = lift (writer pair)
  listen action = unwrapPersistentBackendT listen action
  pass action = unwrapPersistentBackendT pass action
-- | 'MonadRWS' has no methods of its own. This instance relies on the
-- 'MonadReader', 'MonadWriter' and 'MonadState' pass-through instances.
instance (MonadRWS r w s m) => MonadRWS r w s (PersistentBackendT m)
-- | Transform the base monad of a 'PersistentBackendT' computation.
--
-- The function is threaded through both reader layers: the configuration
-- reader and the SQL backend reader.
mapPersistentBackendT :: (m a -> n b) -> PersistentBackendT m a -> PersistentBackendT n b
mapPersistentBackendT = unwrapPersistentBackendT . mapReaderT . mapReaderT
-- | Apply a function to the internal monad stack of a 'PersistentBackendT'
-- computation and rewrap the result.
unwrapPersistentBackendT :: (PersistentBackendInternal m a -> PersistentBackendInternal n b)
  -> PersistentBackendT m a -> PersistentBackendT n b
unwrapPersistentBackendT f (PersistentBackendT inner) = PersistentBackendT (f inner)
-- | Run a 'PersistentBackendT' computation as one 'runSqlPool' session on
-- @pool@. The computation can read @cfg@ and @pool@ from its environment.
--
-- A 'ServantErr' thrown inside (for example via 'throwError') escapes
-- 'runSqlPool' and is returned as 'Left'. Exceptions of any other type
-- propagate unchanged.
runPersistentBackendT :: (MonadUnliftIO m, MonadCatch m) => AuthConfig -> ConnectionPool -> PersistentBackendT m a -> m (Either ServantErr a)
runPersistentBackendT cfg pool ma = try (runSqlPool session pool)
  where
    session = runReaderT (unPersistentBackendT ma) (cfg, pool)
-- | Convert a stored database entity into the API's id-tagged value.
-- Both the key and the record are converted with 'S.convertFrom'.
toWithId :: (S.ConvertStorage a' a, S.ConvertStorage (Key a') i) => Entity a' -> WithId i a
toWithId entity = case entity of
  Entity key val -> WithField (S.convertFrom key) (S.convertFrom val)
-- | Embed a raw persistent SQL action in the backend. The action runs in
-- the SQL session supplied when the backend is run.
liftDB :: Monad m => SqlPersistT m a -> PersistentBackendT m a
liftDB action = PersistentBackendT (lift action)
-- | Persistent (SQL) implementation of the token server storage interface.
--
-- Every method runs its query through 'liftDB', i.e. inside the SQL session
-- of the enclosing 'runPersistentBackendT'. Values are converted between
-- the API model types and the schema types with 'S.convertTo' and
-- 'S.convertFrom' (or 'toWithId' for entities).
instance (MonadIO m) => HasStorage (PersistentBackendT m) where
  -- Users
  getUserImpl = liftDB . fmap (fmap S.convertFrom) . get . S.convertTo
  getUserImplByLogin = liftDB . fmap (fmap toWithId) . getBy . S.UniqueLogin
  -- Page of users ordered by id, plus the total number of users.
  listUsersPaged page size = liftDB $ do
    users <- selectList []
      [ Asc S.UserImplId
      , OffsetBy (fromIntegral $ page * size)
      , LimitTo (fromIntegral size)
      ]
    total <- count ([] :: [Filter S.UserImpl])
    return (fmap toWithId users, fromIntegral total)

  -- User permissions
  getUserImplPermissions uid = liftDB . fmap (fmap toWithId) $
    selectList [S.UserPermUser ==. S.convertTo uid] [Asc S.UserPermPermission]
  deleteUserPermissions uid = liftDB $ deleteWhere [S.UserPermUser ==. S.convertTo uid]
  insertUserPerm = liftDB . fmap S.convertFrom . insert . S.convertTo
  insertUserImpl = liftDB . fmap S.convertFrom . insert . S.convertTo
  replaceUserImpl i v = liftDB $ replace (S.convertTo i) (S.convertTo v)
  deleteUserImpl = liftDB . delete . S.convertTo
  hasPerm i p = liftDB $ do
    c <- count [S.UserPermUser ==. S.convertTo i, S.UserPermPermission ==. p]
    return $ c > 0
  -- Without an ORDER BY the database may return any matching permission row,
  -- so "first" was unspecified. Ordering by user key makes the result
  -- deterministic: the lowest user key holding the permission.
  getFirstUserByPerm p = liftDB . fmap (fmap toWithId) $ do
    mp <- selectFirst [S.UserPermPermission ==. p] [Asc S.UserPermUser]
    case mp of
      Nothing -> return Nothing
      Just (Entity _ S.UserPerm{..}) -> fmap (Entity userPermUser) <$> get userPermUser

  -- User groups
  selectUserImplGroups i = liftDB . fmap (fmap toWithId) $
    selectList [S.AuthUserGroupUsersUser ==. S.convertTo i] [Asc S.AuthUserGroupUsersGroup]
  clearUserImplGroups i = liftDB $ deleteWhere [S.AuthUserGroupUsersUser ==. S.convertTo i]
  insertAuthUserGroup = liftDB . fmap S.convertFrom . insert . S.convertTo
  insertAuthUserGroupUsers = liftDB . fmap S.convertFrom . insert . S.convertTo
  insertAuthUserGroupPerms = liftDB . fmap S.convertFrom . insert . S.convertTo
  getAuthUserGroup = liftDB . fmap (fmap S.convertFrom) . get . S.convertTo
  listAuthUserGroupPermissions i = liftDB . fmap (fmap toWithId) $
    selectList [S.AuthUserGroupPermsGroup ==. S.convertTo i] [Asc S.AuthUserGroupPermsPermission]
  listAuthUserGroupUsers i = liftDB . fmap (fmap toWithId) $
    selectList [S.AuthUserGroupUsersGroup ==. S.convertTo i] [Asc S.AuthUserGroupUsersUser]
  replaceAuthUserGroup i v = liftDB $ replace (S.convertTo i) (S.convertTo v)
  clearAuthUserGroupUsers i = liftDB $ deleteWhere [S.AuthUserGroupUsersGroup ==. S.convertTo i]
  clearAuthUserGroupPerms i = liftDB $ deleteWhere [S.AuthUserGroupPermsGroup ==. S.convertTo i]
  deleteAuthUserGroup = liftDB . delete . S.convertTo
  -- Page of groups ordered by id, plus the total number of groups.
  listGroupsPaged page size = liftDB $ do
    groups <- selectList []
      [ Asc S.AuthUserGroupId
      , OffsetBy (fromIntegral $ page * size)
      , LimitTo (fromIntegral size)
      ]
    total <- count ([] :: [Filter S.AuthUserGroup])
    return (fmap toWithId groups, fromIntegral total)
  setAuthUserGroupName i n = liftDB $ update (S.convertTo i) [S.AuthUserGroupName =. n]
  setAuthUserGroupParent i mp = liftDB $
    update (S.convertTo i) [S.AuthUserGroupParent =. fmap S.convertTo mp]

  -- Single-use codes
  insertSingleUseCode = liftDB . fmap S.convertFrom . insert . S.convertTo
  setSingleUseCodeUsed i mt = liftDB $ update (S.convertTo i) [S.UserSingleUseCodeUsed =. mt]
  -- A code matches when it has the given value and user, is not used yet,
  -- and either has no expiry or expires at or after @t@. Matches are
  -- ordered by expiry descending; where NULL expiries sort is up to the
  -- database backend.
  getUnusedCode c i t = liftDB . fmap (fmap toWithId) $ selectFirst
    ([ S.UserSingleUseCodeValue ==. c
     , S.UserSingleUseCodeUser ==. S.convertTo i
     , S.UserSingleUseCodeUsed ==. Nothing
     ] ++ (
       [S.UserSingleUseCodeExpire ==. Nothing]
       ||. [S.UserSingleUseCodeExpire >=. Just t]
     ))
    [Desc S.UserSingleUseCodeExpire]
  -- Mark every unused, non-expiring code of the user as used at time @t@.
  invalidatePermanentCodes i t = liftDB $ updateWhere
    [ S.UserSingleUseCodeUser ==. S.convertTo i
    , S.UserSingleUseCodeUsed ==. Nothing
    , S.UserSingleUseCodeExpire ==. Nothing
    ]
    [S.UserSingleUseCodeUsed =. Just t]

  -- Password restore codes: only codes expiring strictly after @t@ are
  -- considered, latest expiry first.
  selectLastRestoreCode i t = liftDB . fmap (fmap toWithId) $
    selectFirst [S.UserRestoreUser ==. S.convertTo i, S.UserRestoreExpire >. t] [Desc S.UserRestoreExpire]
  insertUserRestore = liftDB . fmap S.convertFrom . insert . S.convertTo
  findRestoreCode i rc t = liftDB . fmap (fmap toWithId) $
    selectFirst
      [S.UserRestoreUser ==. S.convertTo i, S.UserRestoreValue ==. rc, S.UserRestoreExpire >. t]
      [Desc S.UserRestoreExpire]
  replaceRestoreCode i v = liftDB $ replace (S.convertTo i) (S.convertTo v)

  -- Auth tokens
  -- Previously unordered, which returned an arbitrary token whenever the
  -- user had several still-valid ones. Now the token that stays valid
  -- longest is returned, consistent with 'selectLastRestoreCode'.
  findAuthToken i t = liftDB . fmap (fmap toWithId) $
    selectFirst [S.AuthTokenUser ==. S.convertTo i, S.AuthTokenExpire >. t] [Desc S.AuthTokenExpire]
  findAuthTokenByValue t = liftDB . fmap (fmap toWithId) $ selectFirst [S.AuthTokenValue ==. t] []
  insertAuthToken = liftDB . fmap S.convertFrom . insert . S.convertTo
  replaceAuthToken i v = liftDB $ replace (S.convertTo i) (S.convertTo v)