{-# LANGUAGE UndecidableInstances #-}

module Servant.Server.Auth.Token.Persistent(
    PersistentBackendT
  , runPersistentBackendT
  , liftDB
  ) where

import Control.Monad.Cont.Class (MonadCont(..))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.RWS.Class (MonadRWS)
import Control.Monad.State.Class (MonadState(state))
import Control.Monad.Trans.Control
import Control.Monad.Writer.Class (MonadWriter(..))
import Data.Aeson.WithField
import Database.Persist
import Database.Persist.Sql
import Servant.Server
import Servant.Server.Auth.Token.Monad
import Servant.Server.Auth.Token.Model
import Servant.Server.Auth.Token.Config

import qualified Servant.Server.Auth.Token.Persistent.Schema as S

-- | Monad transformer implementing the authentication storage backend on top
-- of @persistent@.
--
-- It wraps 'PersistentBackendInternal'. From the outermost layer inwards, that
-- stack is: a reader holding the 'AuthConfig' and the 'ConnectionPool', an
-- 'ExceptT' over 'ServantErr', and 'SqlPersistT' over the base monad @m@.
newtype PersistentBackendT m a = PersistentBackendT
  { unPersistentBackendT :: PersistentBackendInternal m a
  } deriving (Functor, Applicative, Monad, MonadIO, MonadError ServantErr, MonadCont)

-- | The concrete transformer stack hidden behind 'PersistentBackendT'.
type PersistentBackendInternal m =
  ReaderT (AuthConfig, ConnectionPool) (ExceptT ServantErr (SqlPersistT m))

-- | The 'AuthConfig' is the first component of the reader environment of the
-- internal stack.
instance Monad m => HasAuthConfig (PersistentBackendT m) where
  getAuthConfig = PersistentBackendT (fst <$> ask)

-- | Lifting a base-monad action passes it through all three layers of the
-- internal stack (SQL reader, 'ExceptT', config/pool reader).
instance MonadTrans PersistentBackendT where
  lift baseAction = PersistentBackendT (lift (lift (lift baseAction)))

-- | Reader effects come from the base monad @m@. 'ask' is lifted. 'local'
-- changes the base monad's environment for the wrapped action via
-- 'mapPersistentBackendT'.
instance MonadReader r m => MonadReader r (PersistentBackendT m) where
  local modify = mapPersistentBackendT (local modify)
  ask = lift ask

-- | State effects are delegated to the base monad by lifting 'state'.
instance MonadState s m => MonadState s (PersistentBackendT m) where
  state transition = lift (state transition)

-- | Writer effects are delegated to the base monad.
--
-- 'tell' and 'writer' are lifted. 'listen' and 'pass' are run directly on the
-- internal transformer stack, using that stack's own 'MonadWriter' instance.
instance MonadWriter w m => MonadWriter w (PersistentBackendT m) where
  tell msg    = lift (tell msg)
  writer pair = lift (writer pair)
  listen (PersistentBackendT inner) = PersistentBackendT (listen inner)
  pass   (PersistentBackendT inner) = PersistentBackendT (pass inner)

-- | 'MonadRWS' has no methods of its own. This instance only bundles the
-- 'MonadReader', 'MonadWriter' and 'MonadState' instances of
-- 'PersistentBackendT'.
instance MonadRWS r w s m => MonadRWS r w s (PersistentBackendT m)

-- | Transform the base-monad computation that sits beneath every layer of the
-- internal stack.
--
-- The function is applied under both reader layers and the 'ExceptT' layer.
-- It therefore sees (and must produce) an action yielding
-- @Either ServantErr _@.
mapPersistentBackendT
  :: (m (Either ServantErr a) -> n (Either ServantErr b))
  -> PersistentBackendT m a
  -> PersistentBackendT n b
mapPersistentBackendT f (PersistentBackendT inner) =
  PersistentBackendT (mapReaderT (mapExceptT (mapReaderT f)) inner)

-- | Apply a function to the internal transformer stack of a
-- 'PersistentBackendT' and re-wrap the result.
unwrapPersistentBackendT
  :: (PersistentBackendInternal m a -> PersistentBackendInternal n b)
  -> PersistentBackendT m a
  -> PersistentBackendT n b
unwrapPersistentBackendT f (PersistentBackendT inner) = PersistentBackendT (f inner)

-- | Execute a backend action with the given configuration and connection pool.
--
-- The 'AuthConfig' and the pool form the reader environment. The whole action
-- is run by a single 'runSqlPool' call. Errors raised through 'MonadError'
-- come back as 'Left'.
--
-- NOTE(review): a 'Left' here is an ordinary return value for 'runSqlPool',
-- not an exception. Presumably the transaction is therefore committed, not
-- rolled back, when the action fails this way — confirm this is intended.
runPersistentBackendT
  :: MonadBaseControl IO m
  => AuthConfig
  -> ConnectionPool
  -> PersistentBackendT m a
  -> m (Either ServantErr a)
runPersistentBackendT cfg pool ma = runSqlPool dbAction pool
  where
    env = (cfg, pool)
    dbAction = runExceptT (runReaderT (unPersistentBackendT ma) env)

-- | Convert a persistent 'Entity' into the API's 'WithId' form. Both the key
-- and the stored record are converted with 'S.convertFrom'.
toWithId :: (S.ConvertStorage a' a, S.ConvertStorage (Key a') i) => Entity a' -> WithId i a
toWithId ent = WithField (S.convertFrom (entityKey ent)) (S.convertFrom (entityVal ent))

-- | Run a raw 'SqlPersistT' action inside 'PersistentBackendT' by lifting it
-- past the 'ExceptT' layer and then the config/pool reader layer.
liftDB :: Monad m => SqlPersistT m a -> PersistentBackendT m a
liftDB action = PersistentBackendT (lift (lift action))

-- | Storage operations backed by @persistent@.
--
-- Each method lifts a database action with 'liftDB'. Values are converted
-- between API model types and schema types with 'S.convertTo' /
-- 'S.convertFrom', and entities with 'toWithId'.
instance (MonadIO m) => HasStorage (PersistentBackendT m) where
  getUserImpl = liftDB . fmap (fmap S.convertFrom) . get . S.convertTo
  getUserImplByLogin = liftDB . fmap (fmap toWithId) . getBy . S.UniqueLogin
  -- Ordering by primary key keeps pages stable between requests. Offset and
  -- limit go through 'toSqlOffset' / 'toSqlLimit', so large paging values can
  -- neither wrap around nor become negative SQL OFFSET/LIMIT values.
  listUsersPaged page size = liftDB $ do
    users <- selectList [] [Asc S.UserImplId, OffsetBy (toSqlOffset page size), LimitTo (toSqlLimit size)]
    total <- count ([] :: [Filter S.UserImpl])
    return (fmap toWithId users, fromIntegral total)
  getUserImplPermissions uid = liftDB . fmap (fmap toWithId) $ selectList [S.UserPermUser ==. S.convertTo uid] [Asc S.UserPermPermission]
  deleteUserPermissions uid = liftDB $ deleteWhere [S.UserPermUser ==. S.convertTo uid]
  insertUserPerm = liftDB . fmap S.convertFrom . insert . S.convertTo
  insertUserImpl = liftDB . fmap S.convertFrom . insert . S.convertTo
  replaceUserImpl i v = liftDB $ replace (S.convertTo i) (S.convertTo v)
  deleteUserImpl = liftDB . delete . S.convertTo
  -- Existence check: 'selectFirst' needs only one matching row, whereas
  -- 'count' makes the database count every matching row.
  hasPerm i p = liftDB $ maybe False (const True) <$>
    selectFirst [S.UserPermUser ==. S.convertTo i, S.UserPermPermission ==. p] []
  -- Takes the first permission row carrying @p@ and loads its user. Returns
  -- 'Nothing' when there is no such row or the referenced user is missing.
  getFirstUserByPerm p = liftDB . fmap (fmap toWithId) $ do
    mp <- selectFirst [S.UserPermPermission ==. p] []
    case mp of
      Nothing -> return Nothing
      Just (Entity _ S.UserPerm{..}) -> fmap (Entity userPermUser) <$> get userPermUser
  selectUserImplGroups i = liftDB . fmap (fmap toWithId) $ selectList [S.AuthUserGroupUsersUser ==. S.convertTo i] [Asc S.AuthUserGroupUsersGroup]
  clearUserImplGroups i = liftDB $ deleteWhere [S.AuthUserGroupUsersUser ==. S.convertTo i]
  insertAuthUserGroup = liftDB . fmap S.convertFrom . insert . S.convertTo
  insertAuthUserGroupUsers = liftDB . fmap S.convertFrom . insert . S.convertTo
  insertAuthUserGroupPerms = liftDB . fmap S.convertFrom . insert . S.convertTo
  getAuthUserGroup = liftDB . fmap (fmap S.convertFrom) . get . S.convertTo
  listAuthUserGroupPermissions i = liftDB . fmap (fmap toWithId) $ selectList [S.AuthUserGroupPermsGroup ==. S.convertTo i] [Asc S.AuthUserGroupPermsPermission]
  listAuthUserGroupUsers i = liftDB . fmap (fmap toWithId) $ selectList [S.AuthUserGroupUsersGroup ==. S.convertTo i] [Asc S.AuthUserGroupUsersUser]
  replaceAuthUserGroup i v = liftDB $ replace (S.convertTo i) (S.convertTo v)
  clearAuthUserGroupUsers i = liftDB $ deleteWhere [S.AuthUserGroupUsersGroup ==. S.convertTo i]
  clearAuthUserGroupPerms i = liftDB $ deleteWhere [S.AuthUserGroupPermsGroup ==. S.convertTo i]
  deleteAuthUserGroup = liftDB . delete . S.convertTo
  -- Same stable ordering and overflow-safe paging as 'listUsersPaged'.
  listGroupsPaged page size = liftDB $ do
    groups <- selectList [] [Asc S.AuthUserGroupId, OffsetBy (toSqlOffset page size), LimitTo (toSqlLimit size)]
    total <- count ([] :: [Filter S.AuthUserGroup])
    return (fmap toWithId groups, fromIntegral total)
  setAuthUserGroupName i n = liftDB $ update (S.convertTo i) [S.AuthUserGroupName =. n]
  setAuthUserGroupParent i mp = liftDB $ update (S.convertTo i) [S.AuthUserGroupParent =. fmap S.convertTo mp]
  insertSingleUseCode = liftDB . fmap S.convertFrom . insert . S.convertTo
  setSingleUseCodeUsed i mt = liftDB $ update (S.convertTo i) [S.UserSingleUseCodeUsed =. mt]
  -- Matches code @c@ of user @i@ that is not yet used and either never
  -- expires or expires at/after @t@. Codes sorting first by descending expiry
  -- win.
  getUnusedCode c i t = liftDB . fmap (fmap toWithId) $ selectFirst ([
          S.UserSingleUseCodeValue ==. c
        , S.UserSingleUseCodeUser ==. S.convertTo i
        , S.UserSingleUseCodeUsed ==. Nothing
        ] ++ (
            [S.UserSingleUseCodeExpire ==. Nothing]
        ||. [S.UserSingleUseCodeExpire >=. Just t]
    )) [Desc S.UserSingleUseCodeExpire]
  -- Marks every unused, non-expiring code of user @i@ as used at time @t@.
  invalidatePermanentCodes i t = liftDB $ updateWhere [
      S.UserSingleUseCodeUser ==. S.convertTo i
    , S.UserSingleUseCodeUsed ==. Nothing
    , S.UserSingleUseCodeExpire ==. Nothing
    ]
    [S.UserSingleUseCodeUsed =. Just t]
  selectLastRestoreCode i t = liftDB . fmap (fmap toWithId) $ selectFirst [S.UserRestoreUser ==. S.convertTo i, S.UserRestoreExpire >. t] [Desc S.UserRestoreExpire]
  insertUserRestore = liftDB . fmap S.convertFrom . insert . S.convertTo
  findRestoreCode i rc t = liftDB . fmap (fmap toWithId) $ selectFirst [S.UserRestoreUser ==. S.convertTo i, S.UserRestoreValue ==. rc, S.UserRestoreExpire >. t] [Desc S.UserRestoreExpire]
  replaceRestoreCode i v = liftDB $ replace (S.convertTo i) (S.convertTo v)
  -- When several tokens are still valid, pick the one expiring last. This
  -- makes the result deterministic and matches the restore-code lookups.
  findAuthToken i t = liftDB . fmap (fmap toWithId) $ selectFirst [S.AuthTokenUser ==. S.convertTo i, S.AuthTokenExpire >. t] [Desc S.AuthTokenExpire]
  findAuthTokenByValue t = liftDB . fmap (fmap toWithId) $ selectFirst [S.AuthTokenValue ==. t] []
  insertAuthToken = liftDB . fmap S.convertFrom . insert . S.convertTo
  replaceAuthToken i v = liftDB $ replace (S.convertTo i) (S.convertTo v)

-- | SQL OFFSET for a zero-based page number and a page size.
--
-- The product is computed in 'Integer' so it cannot wrap around. It is then
-- clamped into @[0, maxBound :: Int]@.
toSqlOffset :: (Integral p, Integral s) => p -> s -> Int
toSqlOffset page size = clampToInt (toInteger page * toInteger size)

-- | SQL LIMIT for a page size, clamped into @[0, maxBound :: Int]@.
toSqlLimit :: Integral s => s -> Int
toSqlLimit = clampToInt . toInteger

-- | Clamp an 'Integer' into the non-negative range of 'Int'.
clampToInt :: Integer -> Int
clampToInt = fromInteger . max 0 . min (toInteger (maxBound :: Int))