{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
module Web.Users.Postgresql () where

import Web.Users.Types

import Control.Monad
#if MIN_VERSION_mtl(2,2,0)
import Control.Monad.Except
#else
import Control.Monad.Error
#endif
import Data.Int
import Data.Maybe
import Data.Monoid
import Data.Time.Clock
import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.SqlQQ
import Database.PostgreSQL.Simple.Types
import qualified Data.ByteString.Char8 as BSC
import qualified Data.Text as T
import qualified Data.UUID as UUID

-- | DDL for the @login@ table, one row per user account.
--
-- @created_at@ defaults to @NOW()@. The column is a TIMESTAMPTZ, and a
-- @CURRENT_DATE@ default would only ever record midnight of the creation day.
-- Because of @IF NOT EXISTS@, an already existing table (including its old
-- default) is left untouched.
-- NOTE: the UNIQUE constraint on @email@ is case-sensitive at the database
-- level.
createUsersTable :: Query
createUsersTable =
    [sql|
          CREATE TABLE IF NOT EXISTS login (
             lid             SERIAL UNIQUE,
             created_at      TIMESTAMPTZ NOT NULL DEFAULT NOW(),
             username        VARCHAR(64)    NOT NULL UNIQUE,
             password        VARCHAR(255)   NOT NULL,
             email           VARCHAR(64)   NOT NULL UNIQUE,
             is_active       BOOLEAN NOT NULL DEFAULT FALSE,
             CONSTRAINT "l_pk" PRIMARY KEY (lid)
          );
    |]

-- | DDL for the @login_token@ table: typed, expiring UUID tokens that belong
-- to a row of @login@ and are removed with it (@ON DELETE CASCADE@).
--
-- @created_at@ defaults to @NOW()@ rather than @CURRENT_DATE@, which would
-- truncate the TIMESTAMPTZ to midnight. Because of @IF NOT EXISTS@, an
-- already existing table is left untouched.
createUserTokenTable :: Query
createUserTokenTable =
    [sql|
          CREATE TABLE IF NOT EXISTS login_token (
             ltid             SERIAL UNIQUE,
             token            UUID UNIQUE,
             token_type       VARCHAR(64) NOT NULL,
             lid              INTEGER NOT NULL,
             created_at       TIMESTAMPTZ NOT NULL DEFAULT NOW(),
             valid_until      TIMESTAMPTZ NOT NULL,
             CONSTRAINT "lt_pk" PRIMARY KEY (ltid),
             CONSTRAINT "lt_lid_fk" FOREIGN KEY (lid) REFERENCES login ON DELETE CASCADE
          );
    |]

-- | Check whether a relation called @idx@ exists in the @public@ schema.
--
-- The lookup goes through @pg_class@ without filtering on @relkind@. So a
-- table or sequence with that name counts as well, not only an index.
doesIndexExist :: Connection -> String -> IO Bool
doesIndexExist conn idx =
    do (resultSet :: [Only Int]) <-
           query conn [sql|SELECT 1
                            FROM pg_class c
                            JOIN pg_namespace n ON n.oid = c.relnamespace
                            WHERE c.relname = ?
                            AND n.nspname = 'public'
                            LIMIT 1;
                      |] (Only idx)
       -- Only existence matters: test for emptiness instead of counting.
       return (not (null resultSet))

-- | 'unless' with a monadic condition. Runs @check@ first, then runs @a@
-- only when the condition yielded 'False'.
unlessM :: Monad m => m Bool -> m () -> m ()
unlessM check a = check >>= \holds -> unless holds a

#if MIN_VERSION_mtl(2,2,0)
-- With mtl >= 2.2 the code uses ExceptT (from Control.Monad.Except). The old
-- names are aliased so that the ErrorT / runErrorT code elsewhere in this
-- module compiles unchanged against either mtl version.
type ErrorT = ExceptT
runErrorT :: ErrorT e m a -> m (Either e a)
runErrorT = runExceptT
#else
-- a hack... :-(
-- With older mtl, ErrorT e m only has a Monad instance when e is an instance
-- of Error. UpdateUserError has no sensible conversion from a message, so
-- both methods simply call error. They are only meant to be reached through
-- fail / mzero inside an ErrorT computation, which this module does not
-- rely on. This is an orphan instance, which is why the module disables
-- orphan warnings in its OPTIONS_GHC pragma.
instance Error UpdateUserError where
    noMsg = error "Calling fail not supported"
    strMsg = error "Calling fail not supported"
#endif

-- | Name of the @login@ column that stores the given 'UserField'.
getSqlField :: UserField -> BSC.ByteString
getSqlField UserFieldId = "lid"
getSqlField UserFieldActive = "is_active"
getSqlField UserFieldEmail = "email"
getSqlField UserFieldName = "username"
getSqlField UserFieldPassword = "password"

-- | Render an @ORDER BY@ clause for the requested column and direction,
-- e.g. @ORDER BY username ASC@. No leading or trailing whitespace is added.
getOrderBy :: SortBy UserField -> BSC.ByteString
getOrderBy sb =
    "ORDER BY " <> getSqlField column <> direction
  where
    (column, direction) =
        case sb of
          SortAsc t -> (t, " ASC")
          SortDesc t -> (t, " DESC")

instance UserStorageBackend Connection where
    type UserId Connection = Int64
    -- | Create the @uuid-ossp@ extension, the @login@ and @login_token@
    -- tables and their secondary indexes. Every step is guarded (IF NOT
    -- EXISTS, or a 'doesIndexExist' check), so running it again is harmless.
    initUserBackend conn =
        do _ <- execute_ conn [sql|CREATE EXTENSION IF NOT EXISTS "uuid-ossp";|]
           _ <- execute_ conn createUsersTable
           _ <- execute_ conn createUserTokenTable
           forM_ secondaryIndexes $ \(indexName, createIndex) ->
               unlessM (doesIndexExist conn indexName) $
                   void $ execute_ conn createIndex
      where
        -- Index name paired with the statement creating it, in creation order.
        secondaryIndexes :: [(String, Query)]
        secondaryIndexes =
            [ ("l_username", [sql|CREATE INDEX l_username ON login USING btree(username);|])
            , ("l_email", [sql|CREATE INDEX l_email ON login USING btree(email);|])
            , ("l_lower_email", [sql|CREATE INDEX l_lower_email ON login USING btree(lower(email));|])
            , ("lt_token_type", [sql|CREATE INDEX lt_token_type ON login_token USING btree(token_type);|])
            , ("lt_token", [sql|CREATE INDEX lt_token ON login_token USING btree(token);|])
            ]
    -- | Drop both tables. @login_token@ goes first because it references
    -- @login@.
    destroyUserBackend conn =
        do _ <- execute_ conn [sql|DROP TABLE login_token;|]
           _ <- execute_ conn [sql|DROP TABLE login;|]
           return ()
    -- | Delete every token whose @valid_until@ lies in the past.
    housekeepBackend conn =
        do _ <- execute_ conn [sql|DELETE FROM login_token WHERE valid_until < NOW();|]
           return ()
    -- | Retrieve a user id from the database. The argument is matched (with
    -- plain @=@) against both the user name and the e-mail column. Returns
    -- 'Nothing' if neither matches.
    getUserIdByName conn username =
        listToMaybe . map fromOnly <$> query conn [sql|SELECT lid FROM login WHERE (username = ? OR email = ?) LIMIT 1;|] (username, username)
    -- | Load a user by id. The stored password is never returned; the user
    -- carries 'PasswordHidden' instead.
    getUserById conn userId =
        do resultSet <-
               query conn [sql|SELECT username, email, is_active FROM login WHERE lid = ? LIMIT 1;|] (Only userId)
           case resultSet of
             ((username, email, is_active) : _) ->
                 return $ Just $ convertUserTuple (username, PasswordHidden, email, is_active)
             _ -> return Nothing
    -- | List users with their ids, ordered by @sortField@. The optional pair
    -- is (offset, limit). Passwords are hidden.
    listUsers conn mLimit sortField =
        do let limitPart =
                   case mLimit of
                     Nothing -> ""
                     Just (start, count) ->
                         (Query $ BSC.pack $ " OFFSET " ++ show start ++ " LIMIT " ++ show count)
               sortPart =
                   Query $ " " <> getOrderBy sortField <> " "
               baseQuery =
                   [sql|SELECT lid, username, email, is_active FROM login|]
               fullQuery = baseQuery <> sortPart <> limitPart
               convertUser (lid, username, email, isActive) =
                   (lid, convertUserTuple (username, PasswordHidden, email, isActive))
           resultSet <-
               query_ conn fullQuery
           return $ map convertUser resultSet

    -- | Total number of rows in @login@.
    countUsers conn =
        do [(Only count)] <-
               query_ conn [sql|SELECT COUNT(lid) FROM login;|]
           return count
    -- | Insert a new user and return its id. Only a 'PasswordHash' is
    -- accepted (anything else gives 'InvalidPassword'). The user is rejected
    -- when the name, or the e-mail compared case-insensitively, is already
    -- in use.
    createUser conn user =
        case u_password user of
          PasswordHash p ->
              do [Only emailCount] <-
                     query conn [sql|SELECT COUNT(lid) FROM login WHERE lower(email) = lower(?);|] (Only $ u_email user)
                 [Only nameCount] <-
                     query conn [sql|SELECT COUNT(lid) FROM login WHERE username = ?;|] (Only $ u_name user)
                 -- Test for "> 0" rather than "== 1". The e-mail column is only
                 -- UNIQUE case-sensitively, so several rows can share the
                 -- same lower(email).
                 let emailTaken = (emailCount :: Int64) > 0
                     nameTaken = (nameCount :: Int64) > 0
                 case (emailTaken, nameTaken) of
                   (True, True)   -> return $ Left UsernameAndEmailAlreadyTaken
                   (True, False)  -> return $ Left EmailAlreadyTaken
                   (False, True)  -> return $ Left UsernameAlreadyTaken
                   (False, False) ->
                       do [(Only userId)] <-
                              query conn [sql|INSERT INTO login (username, password, email, is_active) VALUES (?, ?, ?, ?) RETURNING lid|]
                                    (u_name user, p, u_email user, u_active user)
                          return $ Right userId
          _ ->
              return $ Left InvalidPassword
    -- | Apply @updateFun@ to the stored user and persist the result.
    -- Fails with 'UsernameAlreadyExists' or 'EmailAlreadyExists' when a
    -- changed name or e-mail (compared case-insensitively) is already used by
    -- a different account. The password column is only written when the new
    -- value is a 'PasswordHash'. 'getUserById' supplies 'PasswordHidden', so
    -- a password left untouched by @updateFun@ is not modified.
    updateUser conn userId updateFun =
        do mUser <- getUserById conn userId
           case mUser of
             Nothing ->
                 return $ Left UserDoesntExist
             Just origUser ->
                 runErrorT $
                 do let newUser = updateFun origUser
                    -- Both checks exclude the row of this very user. Otherwise
                    -- changing only the capitalisation of the e-mail would
                    -- match itself in the lower(email) comparison and be
                    -- rejected.
                    when (u_name newUser /= u_name origUser) $
                         do [(Only counter)] <-
                                liftIO $ query conn [sql|SELECT COUNT(lid) FROM login WHERE username = ? AND lid <> ?;|] (u_name newUser, userId)
                            when ((counter :: Int64) /= 0) $ throwError UsernameAlreadyExists
                    when (u_email newUser /= u_email origUser) $
                         do [(Only counter)] <-
                                liftIO $ query conn [sql|SELECT COUNT(lid) FROM login WHERE lower(email) = lower(?) AND lid <> ?;|] (u_email newUser, userId)
                            when ((counter :: Int64) /= 0) $ throwError EmailAlreadyExists
                    liftIO $
                       do _ <-
                              execute conn [sql|UPDATE login SET username = ?, email = ?, is_active = ? WHERE lid = ?;|]
                                 (u_name newUser, u_email newUser, u_active newUser, userId)
                          case u_password newUser of
                            PasswordHash p ->
                                do _ <-
                                      execute conn [sql|UPDATE login SET password = ? WHERE lid = ?;|] (p, userId)
                                   return ()
                            _ -> return ()
                          return ()
    -- | Delete the user. Its tokens go with it via the ON DELETE CASCADE
    -- foreign key of @login_token@.
    deleteUser conn userId =
        do _ <- execute conn [sql|DELETE FROM login WHERE lid = ?;|] (Only userId)
           return ()
    -- | Check the password with 'verifyPassword'. On success, create a token
    -- of type @session@ that lives for @sessionTtl@.
    authUser conn username password sessionTtl =
        withAuthUser conn username (\user -> verifyPassword password $ u_password user) $ \userId ->
           SessionId <$> createToken conn "session" userId sessionTtl
    -- | Create a session for an existing user without checking a password.
    -- Returns 'Nothing' when no user has this id.
    createSession conn userId sessionTtl =
        do mUser <- getUserById conn userId
           case (mUser :: Maybe User) of
             Nothing -> return Nothing
             Just _ -> Just . SessionId <$> createToken conn "session" userId sessionTtl
    -- | Find the user by name or e-mail and hand it, including its stored
    -- password hash, to @authFn@. @action@ runs on the user id only when
    -- @authFn@ accepts.
    withAuthUser conn username authFn action =
        do resultSet <- query conn [sql|SELECT lid, username, password, email, is_active FROM login WHERE (username = ? OR email = ?) LIMIT 1;|] (username, username)
           case resultSet of
             ((userId, name, password, email, is_active) : _)
               -> do let user = convertUserTuple (name, PasswordHash password, email, is_active)
                     if authFn user
                        then Just <$> action userId
                        else return Nothing
             _ -> return Nothing
    -- | Resolve a valid session to its owner. The session is extended so it
    -- lives at least @extendTime@ longer; it is never shortened.
    verifySession conn (SessionId sessionId) extendTime =
        do mUser <- getTokenOwner conn "session" sessionId
           case mUser of
             Nothing -> return Nothing
             Just userId ->
                 do extendToken conn "session" sessionId extendTime
                    return (Just userId)
    -- | Delete the session token. Unknown or malformed ids are ignored.
    destroySession conn (SessionId sessionId) = deleteToken conn "session" sessionId
    -- | Issue a @password_reset@ token valid for @timeToLive@.
    requestPasswordReset conn userId timeToLive =
        do token <- createToken conn "password_reset" userId timeToLive
           return $ PasswordResetToken token
    -- | Issue an @activation@ token valid for @timeToLive@.
    requestActivationToken conn userId timeToLive =
        do token <- createToken conn "activation" userId timeToLive
           return $ ActivationToken token
    -- | Mark the token owner as active and consume the token.
    -- NOTE(review): the result of 'updateUser' is discarded, so a failed
    -- update still consumes the token and reports success.
    activateUser conn (ActivationToken token) =
        do mUser <- getTokenOwner conn "activation" token
           case mUser of
             Nothing ->
                 return $ Left TokenInvalid
             Just userId ->
                 do _ <-
                        updateUser conn userId $ \user -> user { u_active = True }
                    deleteToken conn "activation" token
                    return $ Right ()
    -- | Return the owner of a still-valid reset token without consuming it.
    verifyPasswordResetToken conn (PasswordResetToken token) =
        do mUser <- getTokenOwner conn "password_reset" token
           case mUser of
             Nothing -> return Nothing
             Just userId -> getUserById conn userId
    -- | Store @password@ for the token owner and consume the token.
    -- NOTE(review): 'updateUser' only writes a 'PasswordHash'. Any other
    -- value leaves the old password in place, yet the token is still
    -- consumed and the result of 'updateUser' is discarded.
    applyNewPassword conn (PasswordResetToken token) password =
        do mUser <- getTokenOwner conn "password_reset" token
           case mUser of
             Nothing ->
                 return $ Left TokenInvalid
             Just userId ->
                 do _ <-
                        updateUser conn userId $ \user -> user { u_password = password }
                    deleteToken conn "password_reset" token
                    return $ Right ()

-- | Turn a time-to-live into a whole number of seconds, for use as an
-- integer query parameter. 'round' is used, so sub-second precision is
-- dropped.
convertTtl :: NominalDiffTime -> Int
convertTtl ttl = round ttl

-- | Insert a new token of type @tokenType@ for @userId@ and return the
-- generated UUID as text. The token expires @timeToLive@ (rounded to whole
-- seconds) after now.
--
-- The lifetime is bound as an ordinary integer parameter and turned into an
-- interval by multiplication. The form @'? seconds'@ would put the
-- placeholder inside a SQL string literal and only work because an 'Int'
-- happens to be rendered unquoted.
createToken :: Connection -> String -> Int64 -> NominalDiffTime -> IO T.Text
createToken conn tokenType userId timeToLive =
    do [Only sessionToken] <-
           query conn [sql|INSERT INTO login_token (token, token_type, lid, valid_until)
                            VALUES (uuid_generate_v4(), ?, ?, NOW() + (? * INTERVAL '1 second'))
                                   RETURNING token;|]
                     (tokenType, userId :: Int64, convertTtl timeToLive)
       return (T.pack $ UUID.toString sessionToken)

-- | Delete the token of type @tokenType@ identified by @token@. Text that
-- does not parse as a UUID cannot match the UUID @token@ column, so it is
-- ignored without contacting the database.
deleteToken :: Connection -> String -> T.Text -> IO ()
deleteToken conn tokenType token =
    maybe (return ()) removeRow (UUID.fromString (T.unpack token))
  where
    removeRow uuid =
        void $ execute conn [sql|DELETE FROM login_token WHERE token_type = ? AND token = ?;|] (tokenType, uuid)

-- | Make the token of type @tokenType@ identified by @token@ valid for at
-- least @timeToLive@ (whole seconds) from now. An expiry already further in
-- the future is kept, so a token is never shortened. Non-UUID input is
-- ignored.
--
-- GREATEST replaces the former CASE expression (@valid_until@ is NOT NULL,
-- so the two are equivalent). The lifetime is bound once, as a number
-- multiplied by an interval, rather than substituted into a string literal.
extendToken :: Connection -> String -> T.Text -> NominalDiffTime -> IO ()
extendToken conn tokenType token timeToLive =
    case UUID.fromString (T.unpack token) of
      Nothing -> return ()
      Just uuid ->
          do _ <-
                  execute conn [sql|
                                   UPDATE login_token
                                   SET valid_until =
                                            GREATEST(valid_until, NOW() + (? * INTERVAL '1 second'))
                                   WHERE token_type = ?
                                   AND token = ?;|] (convertTtl timeToLive, tokenType, uuid)
             return ()

-- | Look up the user id owning a token of type @tokenType@. Returns
-- 'Nothing' in three cases: @token@ is not a UUID, no such token exists, or
-- its @valid_until@ is not in the future.
getTokenOwner :: Connection -> String -> T.Text -> IO (Maybe Int64)
getTokenOwner conn tokenType token =
    maybe (return Nothing) lookupOwner (UUID.fromString (T.unpack token))
  where
    lookupOwner uuid =
        listToMaybe . map fromOnly <$>
            query conn [sql|SELECT lid FROM login_token WHERE token_type = ? AND token = ? AND valid_until > NOW() LIMIT 1;|] (tokenType, uuid)

-- | Build a 'User' from a @(name, password, email, active)@ tuple. The tuple
-- holds the password second, but the positional 'User' constructor takes it
-- third, after the e-mail.
convertUserTuple :: (T.Text, Password, T.Text, Bool) -> User
convertUserTuple row =
    case row of
      (name, pw, mail, active) -> User name mail pw active