{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | PostgreSQL implementation of 'UserStorageBackend' on top of a
-- @postgresql-simple@ 'Connection'.
--
-- The module exports nothing by name; importing it brings the
-- @UserStorageBackend Connection@ instance into scope.
module Web.Users.Postgresql () where

import Web.Users.Types

import Control.Monad
import Control.Monad.Except
import Data.Aeson
import Data.Int
import Data.Maybe
import Data.Monoid
import Data.Time.Clock
import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.SqlQQ
import Database.PostgreSQL.Simple.Types
import qualified Data.ByteString.Char8 as BSC
import qualified Data.Text as T
import qualified Data.UUID as UUID

-- | DDL for the @login@ table (one row per user).
--
-- @created_at@ defaults to @CURRENT_TIMESTAMP@ rather than @CURRENT_DATE@:
-- the column is a @TIMESTAMPTZ@, and @CURRENT_DATE@ would store midnight
-- instead of the moment of creation. Because of @IF NOT EXISTS@ this only
-- affects tables created from now on.
createUsersTable :: Query
createUsersTable =
    [sql|
        CREATE TABLE IF NOT EXISTS login (
            lid             SERIAL UNIQUE,
            created_at      TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
            username        VARCHAR(64) NOT NULL UNIQUE,
            password        VARCHAR(255) NOT NULL,
            email           VARCHAR(64) NOT NULL UNIQUE,
            is_active       BOOLEAN NOT NULL DEFAULT FALSE,
            more            JSON,
            CONSTRAINT "l_pk" PRIMARY KEY (lid)
        );
    |]

-- | DDL for the @login_token@ table (session, activation and
-- password-reset tokens). Rows are removed together with their owning
-- @login@ row (@ON DELETE CASCADE@).
--
-- @created_at@ defaults to @CURRENT_TIMESTAMP@ for the same reason as in
-- 'createUsersTable'.
createUserTokenTable :: Query
createUserTokenTable =
    [sql|
        CREATE TABLE IF NOT EXISTS login_token (
            ltid             SERIAL UNIQUE,
            token            UUID UNIQUE,
            token_type       VARCHAR(64) NOT NULL,
            lid              INTEGER NOT NULL,
            created_at       TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
            valid_until      TIMESTAMPTZ NOT NULL,
            CONSTRAINT "lt_pk" PRIMARY KEY (ltid),
            CONSTRAINT "lt_lid_fk" FOREIGN KEY (lid) REFERENCES login ON DELETE CASCADE
        );
    |]

-- | Check whether a relation named @idx@ exists in the @public@ schema.
--
-- The lookup goes through @pg_class@, so it matches any relation with that
-- name, not only indexes.
doesIndexExist :: Connection -> String -> IO Bool
doesIndexExist conn idx = do
    (resultSet :: [Only Int]) <-
        query conn
            [sql|
                SELECT 1
                FROM pg_class c
                JOIN pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relname = ? AND n.nspname = 'public';
            |]
            (Only idx)
    return (not (null resultSet))

-- | Run the action only when the monadic check yields 'False'.
unlessM :: Monad m => m Bool -> m () -> m ()
unlessM check a = do
    r <- check
    unless r a

instance UserStorageBackend Connection where
    type UserId Connection = Int64

    -- Create the required extensions, both tables and their indexes.
    -- Every step is guarded (IF NOT EXISTS / doesIndexExist), so running it
    -- against an already initialised database does not recreate anything.
    initUserBackend conn = do
        void $ execute_ conn [sql|CREATE EXTENSION IF NOT EXISTS pgcrypto;|]
        void $ execute_ conn [sql|CREATE EXTENSION IF NOT EXISTS "uuid-ossp";|]
        void $ execute_ conn createUsersTable
        void $ execute_ conn createUserTokenTable
        unlessM (doesIndexExist conn "l_username") $
            void $ execute_ conn [sql|CREATE INDEX l_username ON login USING btree(username);|]
        unlessM (doesIndexExist conn "l_email") $
            void $ execute_ conn [sql|CREATE INDEX l_email ON login USING btree(email);|]
        unlessM (doesIndexExist conn "lt_token_type") $
            void $ execute_ conn [sql|CREATE INDEX lt_token_type ON login_token USING btree(token_type);|]
        unlessM (doesIndexExist conn "lt_token") $
            void $ execute_ conn [sql|CREATE INDEX lt_token ON login_token USING btree(token);|]

    -- Drop both tables. The token table goes first because it references
    -- the login table.
    destroyUserBackend conn = do
        void $ execute_ conn [sql|DROP TABLE login_token;|]
        void $ execute_ conn [sql|DROP TABLE login;|]

    -- Remove every token whose validity has already ended.
    housekeepBackend conn =
        void $ execute_ conn [sql|DELETE FROM login_token WHERE valid_until < NOW();|]

    -- Fetch a single user. Yields Nothing when no row matches or when the
    -- stored JSON in "more" cannot be decoded into the requested type.
    getUserById conn userId = do
        resultSet <-
            query conn
                [sql|SELECT username, email, is_active, more FROM login WHERE lid = ? LIMIT 1;|]
                (Only userId)
        case resultSet of
            (userTuple : _) -> return $ convertUserTuple userTuple
            _ -> return Nothing

    -- List users ordered by id. With `Just (start, count)` the first
    -- `start` rows are skipped and at most `count` rows are returned.
    -- Rows whose "more" column cannot be decoded are silently dropped.
    listUsers conn mLimit = do
        let limitPart =
                case mLimit of
                    Nothing -> ""
                    Just (start, count) ->
                        -- PostgreSQL does not accept MySQL's "LIMIT start, count";
                        -- the offset has to be given with OFFSET.
                        Query $ BSC.pack $ " LIMIT " ++ show count ++ " OFFSET " ++ show start
            -- ORDER BY keeps pages stable; without it the rows selected by
            -- LIMIT/OFFSET are unspecified.
            baseQuery =
                [sql|SELECT lid, username, email, is_active, more FROM login ORDER BY lid|]
            fullQuery = baseQuery <> limitPart
            convertUser (lid, username, email, isActive, more) = do
                user <- convertUserTuple (username, email, isActive, more)
                return (lid, user)
        resultSet <- query_ conn fullQuery
        return $ mapMaybe convertUser resultSet

    -- Total number of rows in the login table.
    countUsers conn = do
        [(Only count)] <- query_ conn [sql|SELECT COUNT(lid) FROM login;|]
        return count

    -- Insert a new user. Only plain-text passwords are accepted; they are
    -- hashed in the database with crypt/gen_salt('bf', 8).
    -- Returns Left UsernameOrEmailAlreadyTaken when the name or email is in
    -- use and Left InvalidPassword for any other password constructor.
    createUser conn user =
        case u_password user of
            PasswordPlain p -> do
                [(Only counter)] <-
                    query conn
                        [sql|SELECT COUNT(lid) FROM login WHERE username = ? OR email = ?;|]
                        (u_name user, u_email user)
                if (counter :: Int64) /= 0
                    then return $ Left UsernameOrEmailAlreadyTaken
                    else do
                        [(Only userId)] <-
                            query conn
                                [sql|
                                    INSERT INTO login (username, password, email, is_active, more)
                                    VALUES (?, crypt(?, gen_salt('bf', 8)), ?, ?, ?)
                                    RETURNING lid
                                |]
                                (u_name user, p, u_email user, u_active user, toJSON $ u_more user)
                        return $ Right userId
            _ -> return $ Left InvalidPassword

    -- Apply `updateFun` to the stored user and write the result back.
    -- A changed username or email is first checked for uniqueness. The
    -- password column is only rewritten when the updated user carries a
    -- PasswordPlain value.
    updateUser conn userId updateFun = do
        mUser <- getUserById conn userId
        case mUser of
            Nothing -> return $ Left UserDoesntExit
            Just origUser -> runExceptT $ do
                let newUser = updateFun origUser
                when (u_name newUser /= u_name origUser) $ do
                    [(Only counter)] <-
                        liftIO $
                            query conn
                                [sql|SELECT COUNT(lid) FROM login WHERE username = ?;|]
                                (Only $ u_name newUser)
                    when ((counter :: Int64) /= 0) $
                        throwError UsernameOrEmailAlreadyExists
                when (u_email newUser /= u_email origUser) $ do
                    [(Only counter)] <-
                        liftIO $
                            query conn
                                [sql|SELECT COUNT(lid) FROM login WHERE email = ?;|]
                                (Only $ u_email newUser)
                    when ((counter :: Int64) /= 0) $
                        throwError UsernameOrEmailAlreadyExists
                liftIO $ do
                    void $
                        execute conn
                            [sql|
                                UPDATE login
                                SET username = ?, email = ?, is_active = ?, more = ?
                                WHERE lid = ?;
                            |]
                            (u_name newUser, u_email newUser, u_active newUser, toJSON $ u_more newUser, userId)
                    case u_password newUser of
                        PasswordPlain p ->
                            void $
                                execute conn
                                    [sql|UPDATE login SET password = crypt(?, gen_salt('bf', 8)) WHERE lid = ?;|]
                                    (p, userId)
                        _ -> return ()

    -- Delete the user row (tokens follow through ON DELETE CASCADE).
    deleteUser conn userId =
        void $ execute conn [sql|DELETE FROM login WHERE lid = ?;|] (Only userId)

    -- Authenticate by username or email plus password. The password is
    -- verified in the database by re-hashing it with the stored hash as
    -- salt. On success a fresh "session" token is created and returned.
    authUser conn username password sessionTtl = do
        resultSet <-
            query conn
                [sql|
                    SELECT lid
                    FROM login
                    WHERE (username = ? OR email = ?) AND crypt(?, password) = password
                    LIMIT 1;
                |]
                (username, username, password)
        case resultSet of
            ((Only userId) : _) -> do
                sessionToken <- createToken conn "session" userId sessionTtl
                return $ Just $ SessionId sessionToken
            _ -> return Nothing

    -- Resolve a session id to its user and, when it is still valid, push
    -- its expiry further out by `extendTime`.
    verifySession conn (SessionId sessionId) extendTime = do
        mUser <- getTokenOwner conn "session" sessionId
        case mUser of
            Nothing -> return Nothing
            Just userId -> do
                extendToken conn "session" sessionId extendTime
                return (Just userId)

    -- Invalidate a session by deleting its token.
    destroySession conn (SessionId sessionId) =
        deleteToken conn "session" sessionId

    -- Issue a "password_reset" token for the user.
    requestPasswordReset conn userId timeToLive = do
        token <- createToken conn "password_reset" userId timeToLive
        return $ PasswordResetToken token

    -- Issue an "activation" token for the user.
    requestActivationToken conn userId timeToLive = do
        token <- createToken conn "activation" userId timeToLive
        return $ ActivationToken token

    -- Mark the token's owner as active and consume the token.
    -- NOTE(review): the result of updateUser is discarded, so a failed
    -- update still consumes the token and reports success.
    activateUser conn (ActivationToken token) = do
        mUser <- getTokenOwner conn "activation" token
        case mUser of
            Nothing -> return $ Left TokenInvalid
            Just userId -> do
                _ <- updateUser conn userId $ \(user :: User Value) -> user { u_active = True }
                deleteToken conn "activation" token
                return $ Right ()

    -- Return the user a still-valid password-reset token belongs to.
    verifyPasswordResetToken conn (PasswordResetToken token) = do
        mUser <- getTokenOwner conn "password_reset" token
        case mUser of
            Nothing -> return Nothing
            Just userId -> getUserById conn userId

    -- Set a new password for the token's owner and consume the token.
    -- NOTE(review): as in activateUser, the updateUser result is discarded.
    applyNewPassword conn (PasswordResetToken token) password = do
        mUser <- getTokenOwner conn "password_reset" token
        case mUser of
            Nothing -> return $ Left TokenInvalid
            Just userId -> do
                _ <- updateUser conn userId $ \(user :: User Value) -> user { u_password = PasswordPlain password }
                deleteToken conn "password_reset" token
                return $ Right ()

-- | Round a time-to-live to whole seconds for use in SQL intervals.
convertTtl :: NominalDiffTime -> Int
convertTtl = round

-- | Insert a new token of the given type for the user, valid for
-- @timeToLive@ from now, and return it rendered as text.
--
-- NOTE(review): the @?@ sits inside a quoted SQL literal; this presumably
-- relies on postgresql-simple substituting placeholders textually so the
-- literal becomes e.g. @'3600 seconds'@ — confirm against the library docs
-- before touching the query.
createToken :: Connection -> String -> Int64 -> NominalDiffTime -> IO T.Text
createToken conn tokenType userId timeToLive = do
    [(Only sessionToken)] <-
        query conn
            [sql|
                INSERT INTO login_token (token, token_type, lid, valid_until)
                VALUES (uuid_generate_v4(), ?, ?, NOW() + '? seconds')
                RETURNING token;
            |]
            (tokenType, userId :: Int64, convertTtl timeToLive)
    return (T.pack $ UUID.toString sessionToken)

-- | Delete the token of the given type. A token that does not parse as a
-- UUID cannot exist in the table, so it is ignored.
deleteToken :: Connection -> String -> T.Text -> IO ()
deleteToken conn tokenType token =
    case UUID.fromString (T.unpack token) of
        Nothing -> return ()
        Just uuid ->
            void $
                execute conn
                    [sql|DELETE FROM login_token WHERE token_type = ? AND token = ?;|]
                    (tokenType, uuid)

-- | Add @timeToLive@ to the current expiry of the token of the given type.
-- A token that does not parse as a UUID is ignored.
extendToken :: Connection -> String -> T.Text -> NominalDiffTime -> IO ()
extendToken conn tokenType token timeToLive =
    case UUID.fromString (T.unpack token) of
        Nothing -> return ()
        Just uuid ->
            void $
                execute conn
                    [sql|
                        UPDATE login_token
                        SET valid_until = valid_until + '? seconds'
                        WHERE token_type = ? AND token = ?;
                    |]
                    (convertTtl timeToLive, tokenType, uuid)

-- | Look up the user owning a token of the given type. Yields 'Nothing'
-- when the token is not a UUID, is unknown, or has expired.
getTokenOwner :: Connection -> String -> T.Text -> IO (Maybe Int64)
getTokenOwner conn tokenType token =
    case UUID.fromString (T.unpack token) of
        Nothing -> return Nothing
        Just uuid -> do
            resultSet <-
                query conn
                    [sql|
                        SELECT lid
                        FROM login_token
                        WHERE token_type = ? AND token = ? AND valid_until > NOW()
                        LIMIT 1;
                    |]
                    (tokenType, uuid)
            case resultSet of
                ((Only userId) : _) -> return $ Just userId
                _ -> return Nothing

-- | Build a 'User' from a database row, decoding the @more@ JSON value.
-- The password is never read back, so it is always 'PasswordHidden'.
-- A JSON decoding error is reported through the monad's 'fail'.
convertUserTuple :: (FromJSON a, Monad m) => (T.Text, T.Text, Bool, Value) -> m (User a)
convertUserTuple (username, email, isActive, more) =
    case fromJSON more of
        Error e -> fail e
        Success val -> return $ User username email PasswordHidden isActive val