{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

{-|

This module allows you to use the auth snaplet with your user database stored
in a PostgreSQL database.  When you run your application with this snaplet, a
config file will be copied into the @snaplets/postgresql-auth@ directory.
This file contains all of the configurable options for the snaplet and allows
you to change them without recompiling your application.

To use this snaplet in your application enable the session, postgres, and auth
snaplets as follows:

> data App = App
>     { ... -- your own application state here
>     , _sess :: Snaplet SessionManager
>     , _db   :: Snaplet Postgres
>     , _auth :: Snaplet (A.AuthManager App)
>     }

Then in your initializer you'll have something like this:

> d <- nestSnaplet "db" db pgsInit
> a <- nestSnaplet "auth" auth $ initPostgresAuth sess d

If you have not already created the database table for users, it will
automatically be created for you the first time you run your application.

-}

module Snap.Snaplet.Auth.Backends.PostgresqlSimple
  ( initPostgresAuth
  ) where

------------------------------------------------------------------------------
import           Control.Applicative
import qualified Control.Exception                      as E
import           Control.Lens                           ((^#))
import           Control.Monad                          (liftM, void, when)
import           Control.Monad.Trans                    (liftIO)
import qualified Data.Configurator                      as C
import qualified Data.HashMap.Lazy                      as HM
import           Data.Maybe                             (fromMaybe, listToMaybe)
import           Data.Text                              (Text)
import qualified Data.Text                              as T
import qualified Data.Text.Encoding                     as T
import qualified Database.PostgreSQL.Simple             as P
import           Database.PostgreSQL.Simple.FromField   (FromField, fromField)
import qualified Database.PostgreSQL.Simple.ToField     as P
import           Database.PostgreSQL.Simple.Types       (Query (Query))
import           Paths_snaplet_postgresql_simple
import           Prelude
import           Snap                                   (Snaplet, SnapletInit,
                                                         SnapletLens,
                                                         getSnapletUserConfig,
                                                         makeSnaplet,
                                                         snapletValue)
import qualified Snap.Snaplet.Auth                      as A
import           Snap.Snaplet.PostgresqlSimple          (FromRow, Only, ToRow,
                                                         field, fromRow)
import           Snap.Snaplet.PostgresqlSimple.Internal (Postgres,
                                                         withConnection)
import           Snap.Snaplet.Session                   (SessionManager, mkRNG)
import           Web.ClientSession                      (getKey)
------------------------------------------------------------------------------


-- | Backend state for the PostgreSQL auth snaplet: the table layout used to
-- build SQL and the postgres snaplet state used to run it.
data PostgresAuthManager = PostgresAuthManager
    { pamTable :: AuthTable
      -- ^ Table and column names that the queries in this module are built from.
    , pamConn  :: Postgres
      -- ^ Handed to 'withConnection' whenever a query is executed.
    }


------------------------------------------------------------------------------
-- | Initializer for the postgres backend to the auth snaplet.
--
-- The table name is read from the @authTable@ key of the snaplet's user
-- config (falling back to @snap_auth_user@); every other setting comes from
-- 'A.authSettingsFromConfig'.  The user table is created on start-up when it
-- is missing.
initPostgresAuth
  :: SnapletLens b SessionManager  -- ^ Lens to the session snaplet
  -> Snaplet Postgres  -- ^ The postgres snaplet
  -> SnapletInit b (A.AuthManager b)
initPostgresAuth sess db =
    makeSnaplet "postgresql-auth" summary resourceDir $ do
        userCfg  <- getSnapletUserConfig
        tblTitle <- liftIO (C.lookupDefault "snap_auth_user" userCfg "authTable")
        settings <- A.authSettingsFromConfig
        secret   <- liftIO (getKey (A.asSiteKey settings))
        let pam = PostgresAuthManager
                    { pamTable = defAuthTable { tblName = tblTitle }
                    , pamConn  = db ^# snapletValue
                    }
        -- Make sure the table exists before the backend is handed out.
        liftIO (createTableIfMissing pam)
        gen <- liftIO mkRNG
        return A.AuthManager
          { backend               = pam
          , session               = sess
          , activeUser            = Nothing
          , minPasswdLen          = A.asMinPasswdLen settings
          , rememberCookieName    = A.asRememberCookieName settings
          , rememberCookieDomain  = Nothing
          , rememberPeriod        = A.asRememberPeriod settings
          , siteKey               = secret
          , lockout               = A.asLockout settings
          , randomNumberGenerator = gen
          }
  where
    summary     = "A PostgreSQL backend for user authentication"
    resourceDir = Just ((++ "/resources/auth") <$> getDataDir)


------------------------------------------------------------------------------
-- | Create the user table (and an index on its e-mail column) if it doesn't
-- exist.
--
-- Existence is decided by looking up the table's name, with any
-- @schema.@ prefix stripped, in @pg_class@.  When no row comes back the
-- @CREATE TABLE@ and @CREATE INDEX@ statements are executed in one call.
--
-- NOTE(review): the existence check strips a schema prefix, but the CREATE
-- statements quote the full 'tblName' as a single identifier -- confirm that
-- schema-qualified table names are meant to be supported here.
createTableIfMissing :: PostgresAuthManager -> IO ()
createTableIfMissing PostgresAuthManager{..} =
    withConnection pamConn $ \conn -> do
        -- The table name comes from the snaplet config, so it is passed as a
        -- query parameter instead of being spliced into a quoted SQL literal
        -- (a name containing @'@ would otherwise break the statement).
        res <- P.query conn
                 "select relname from pg_class where relname = ?"
                 [schemaless (tblName pamTable)]
        when (null (res :: [Only T.Text])) $
          void (P.execute_ conn (Query $ T.encodeUtf8 q))
  where
    -- Everything after the last '.', i.e. the name without its schema.
    schemaless = T.reverse . T.takeWhile (/='.') . T.reverse
    q = T.concat
          [ "CREATE TABLE \""
          , tblName pamTable
          , "\" ("
          , T.intercalate "," (map (fDesc . ($ pamTable) . fst) colDef)
          , "); "
          , "CREATE INDEX email_idx ON \""
          , tblName pamTable
          , "\" ("
            -- Use the configured column name rather than a literal so the
            -- index always matches the column created above.
          , fst (colEmail pamTable)
          , ");"
          ]

-- | Wrap a numeric id as an 'A.UserId' holding its decimal rendering.
buildUid :: Int -> A.UserId
buildUid n = A.UserId (T.pack (show n))


-- | Reads the column as an 'Int' and converts it with 'buildUid'.
instance FromField A.UserId where
    fromField f v = fmap buildUid (fromField f v)

-- | Whatever is stored in the column is wrapped as 'A.Encrypted'.
instance FromField A.Password where
    fromField f v = fmap A.Encrypted (fromField f v)

-- | Decodes an 'A.AuthUser' from eighteen positional columns, in the order
-- the queries in this module select them (the order of 'colDef').  Roles and
-- meta are not read from the row; they are always empty.
instance FromRow A.AuthUser where
    fromRow = A.AuthUser
        <$> field            -- userId
        <*> field            -- userLogin
        <*> field            -- userEmail
        <*> field            -- userPassword
        <*> field            -- userActivatedAt
        <*> field            -- userSuspendedAt
        <*> field            -- userRememberToken
        <*> field            -- userLoginCount
        <*> field            -- userFailedLoginCount
        <*> field            -- userLockedOutUntil
        <*> field            -- userCurrentLoginAt
        <*> field            -- userLastLoginAt
        <*> field            -- userCurrentLoginIp
        <*> field            -- userLastLoginIp
        <*> field            -- userCreatedAt
        <*> field            -- userUpdatedAt
        <*> field            -- userResetToken
        <*> field            -- userResetRequestedAt
        <*> pure []          -- userRoles
        <*> pure HM.empty    -- userMeta


-- | Run a parameterised query and return its first row, or 'Nothing' when
-- the result set is empty.
querySingle :: (ToRow q, FromRow a)
            => Postgres -> Query -> q -> IO (Maybe a)
querySingle pc q ps = withConnection pc $ \conn -> do
    rows <- P.query conn q ps
    return (listToMaybe rows)

-- | Run a parameterised statement and discard the value 'P.execute' returns.
authExecute :: ToRow q
            => Postgres -> Query -> q -> IO ()
authExecute pc q ps =
    void (withConnection pc (\conn -> P.execute conn q ps))

-- | Either constructor is written to the database as the value it wraps.
instance P.ToField A.Password where
    toField pw = P.toField $ case pw of
        A.ClearText raw -> raw
        A.Encrypted enc -> enc


-- | Datatype containing the names of the columns for the authentication table.
--
-- Every @col*@ field is a pair.  The first component is the column name used
-- when building queries; 'fDesc' joins both components with a space to form
-- the column's definition in the @CREATE TABLE@ statement.
data AuthTable
  =  AuthTable
  {  tblName             :: Text
     -- ^ Name of the user table; overridden from the snaplet config by
     -- 'initPostgresAuth'.
  ,  colId               :: (Text, Text)
  ,  colLogin            :: (Text, Text)
  ,  colEmail            :: (Text, Text)
  ,  colPassword         :: (Text, Text)
  ,  colActivatedAt      :: (Text, Text)
  ,  colSuspendedAt      :: (Text, Text)
  ,  colRememberToken    :: (Text, Text)
  ,  colLoginCount       :: (Text, Text)
  ,  colFailedLoginCount :: (Text, Text)
  ,  colLockedOutUntil   :: (Text, Text)
  ,  colCurrentLoginAt   :: (Text, Text)
  ,  colLastLoginAt      :: (Text, Text)
  ,  colCurrentLoginIp   :: (Text, Text)
  ,  colLastLoginIp      :: (Text, Text)
  ,  colCreatedAt        :: (Text, Text)
  ,  colUpdatedAt        :: (Text, Text)
  ,  colResetToken       :: (Text, Text)
  ,  colResetRequestedAt :: (Text, Text)
  ,  rolesTable          :: Text
     -- ^ NOTE(review): no query visible in this module reads this field --
     -- confirm whether it is still needed.
  }

-- | Default authentication table layout.
--
-- Columns that share an SQL type are built with the local helpers below; the
-- id and login columns carry extra constraints and are spelled out in full.
defAuthTable :: AuthTable
defAuthTable = AuthTable
    { tblName             = "snap_auth_user"
    , colId               = ("uid", "SERIAL PRIMARY KEY")
    , colLogin            = ("login", "text UNIQUE NOT NULL")
    , colEmail            = textCol "email"
    , colPassword         = textCol "password"
    , colActivatedAt      = timeCol "activated_at"
    , colSuspendedAt      = timeCol "suspended_at"
    , colRememberToken    = textCol "remember_token"
    , colLoginCount       = countCol "login_count"
    , colFailedLoginCount = countCol "failed_login_count"
    , colLockedOutUntil   = timeCol "locked_out_until"
    , colCurrentLoginAt   = timeCol "current_login_at"
    , colLastLoginAt      = timeCol "last_login_at"
    , colCurrentLoginIp   = textCol "current_login_ip"
    , colLastLoginIp      = textCol "last_login_ip"
    , colCreatedAt        = timeCol "created_at"
    , colUpdatedAt        = timeCol "updated_at"
    , colResetToken       = textCol "reset_token"
    , colResetRequestedAt = timeCol "reset_requested_at"
    , rolesTable          = "user_roles"
    }
  where
    textCol, timeCol, countCol :: Text -> (Text, Text)
    textCol  name = (name, "text")
    timeCol  name = (name, "timestamptz")
    countCol name = (name, "integer NOT NULL")

-- | Render a column as @name type@, the form used inside @CREATE TABLE@.
fDesc :: (Text, Text) -> Text
fDesc (colName, colType) = T.concat [colName, " ", colType]

-- | Column table: for every persisted column, the 'AuthTable' accessor that
-- names it paired with a function rendering the matching 'A.AuthUser' field
-- as a query parameter.
--
-- The id column must stay first: 'saveQuery' skips the head of this list when
-- building the writable column set.
colDef :: [(AuthTable -> (Text, Text), A.AuthUser -> P.Action)]
colDef =
  [ column colId               (fmap A.unUid . A.userId)
  , column colLogin            A.userLogin
  , column colEmail            A.userEmail
  , column colPassword         A.userPassword
  , column colActivatedAt      A.userActivatedAt
  , column colSuspendedAt      A.userSuspendedAt
  , column colRememberToken    A.userRememberToken
  , column colLoginCount       A.userLoginCount
  , column colFailedLoginCount A.userFailedLoginCount
  , column colLockedOutUntil   A.userLockedOutUntil
  , column colCurrentLoginAt   A.userCurrentLoginAt
  , column colLastLoginAt      A.userLastLoginAt
  , column colCurrentLoginIp   A.userCurrentLoginIp
  , column colLastLoginIp      A.userLastLoginIp
  , column colCreatedAt        A.userCreatedAt
  , column colUpdatedAt        A.userUpdatedAt
  , column colResetToken       A.userResetToken
  , column colResetRequestedAt A.userResetRequestedAt
  ]
  where
    -- Pair a column accessor with the 'P.toField' rendering of a user field.
    column
      :: P.ToField a
      => (AuthTable -> (Text, Text))
      -> (A.AuthUser -> a)
      -> (AuthTable -> (Text, Text), A.AuthUser -> P.Action)
    column getCol project = (getCol, P.toField . project)

-- | Build the SQL text and parameters that persist a user.
--
-- A user without an id yields an @INSERT@; a user with an id yields an
-- @UPDATE ... WHERE id = ?@ with the id appended as the last parameter.  Both
-- statements end in @RETURNING@ every column of 'colDef', so the caller can
-- decode the stored row.  The id column (the head of 'colDef') is never
-- written by either statement.
saveQuery :: AuthTable -> A.AuthUser -> (Text, [P.Action])
saveQuery atable u = maybe insertQuery updateQuery (A.userId u)
  where
    nameOf (getCol, _) = fst (getCol atable)

    -- Column list shared by both RETURNING clauses.
    returning = T.intercalate "," (map nameOf colDef)

    -- Everything but the id column; 'drop' is total, unlike 'tail'.
    writable = drop 1 colDef
    cols     = map nameOf writable
    params   = map (\(_, render) -> render u) writable

    insertQuery =
        ( T.concat
            [ "INSERT INTO "
            , tblName atable
            , " ("
            , T.intercalate "," cols
            , ") VALUES ("
            , T.intercalate "," (map (const "?") cols)
            , ") RETURNING "
            , returning
            ]
        , params
        )

    updateQuery uid =
        ( T.concat
            [ "UPDATE "
            , tblName atable
            , " SET "
            , T.intercalate "," (map (`T.append` " = ?") cols)
            , " WHERE "
            , fst (colId atable)
            , " = ? RETURNING "
            , returning
            ]
        , params ++ [P.toField (A.unUid uid)]
        )


-- | Turn any exception into an 'A.AuthError' carrying its 'show' rendering.
onFailure :: Monad m => E.SomeException -> m (Either A.AuthFailure a)
onFailure = return . Left . A.AuthError . show

------------------------------------------------------------------------------
-- | 'A.IAuthBackend' implementation that keeps users in the table described
-- by the manager's 'AuthTable'.
instance A.IAuthBackend PostgresAuthManager where
    -- Runs the statement built by 'saveQuery' and returns the row the database
    -- hands back (or the input user when no row is returned).  Exceptions are
    -- reported through 'onFailure'.
    save (PostgresAuthManager tbl pg) u@A.AuthUser{} =
        E.catch persist onFailure
      where
        (sql, params) = saveQuery tbl u
        persist = withConnection pg $ \conn -> do
            rows <- P.query conn (Query (T.encodeUtf8 sql)) params
            return (Right (fromMaybe u (listToMaybe rows)))

    lookupByUserId mgr uid = lookupByColumn mgr colId (A.unUid uid)

    lookupByLogin mgr login = lookupByColumn mgr colLogin login

#if MIN_VERSION_snap(1,1,0)
    lookupByEmail mgr email = lookupByColumn mgr colEmail email
#endif

    lookupByRememberToken mgr token =
        lookupByColumn mgr colRememberToken token

    -- Users are deleted by login, not by id.
    destroy (PostgresAuthManager tbl pg) user@A.AuthUser{} =
        authExecute pg (Query (T.encodeUtf8 sql)) [A.userLogin user]
      where
        sql = T.concat
                [ "delete from "
                , tblName tbl
                , " where "
                , fst (colLogin tbl)
                , " = ?"
                ]


-- | Select the first user whose given column equals the given value.  All
-- columns of 'colDef' are selected, in order, so the row decodes with the
-- 'FromRow' instance for 'A.AuthUser'.
lookupByColumn
    :: P.ToField v
    => PostgresAuthManager
    -> (AuthTable -> (Text, Text))  -- ^ Accessor for the column to match on
    -> v                            -- ^ Value the column must equal
    -> IO (Maybe A.AuthUser)
lookupByColumn (PostgresAuthManager tbl pg) keyCol keyVal =
    querySingle pg (Query (T.encodeUtf8 sql)) [keyVal]
  where
    names = map (\(getCol, _) -> fst (getCol tbl)) colDef
    sql = T.concat
            [ "select ", T.intercalate "," names, " from "
            , tblName tbl
            , " where "
            , fst (keyCol tbl)
            , " = ?"
            ]