{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Snap.Snaplet.SqliteSimple.JwtAuth.JwtAuth (
    SqliteJwt(..)
  , User(..)
  , AuthFailure(..)
  , defaults
  , sqliteJwtInit
  , requireAuth
  , registerUser
  , loginUser
  , createUser
  , login
  , jsonResponse
  , writeJSON
  , reqJSON
  ) where

------------------------------------------------------------------------------
import           Control.Lens hiding ((.=), (??))
import           Control.Monad.Except
import           Control.Monad.State (gets)
import           Control.Error hiding (err)
import qualified Crypto.BCrypt as BC
import           Data.Aeson
import           Data.Aeson.Types (parseEither)
import qualified Data.Attoparsec.ByteString.Char8 as AP
import           Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS8
import           Data.Maybe
import           Data.Map as M
import           Data.HashMap.Strict as HM
import qualified Data.Text as T
import qualified Data.Text.Encoding as LT
import           Data.Time.Clock.POSIX (POSIXTime, getPOSIXTime)
import           Snap
import           Snap.Snaplet.SqliteSimple (Sqlite, sqliteConn)
import qualified Web.JWT as JWT

import           Snap.Snaplet.SqliteSimple.JwtAuth.Util
import           Snap.Snaplet.SqliteSimple.JwtAuth.Types
import qualified Snap.Snaplet.SqliteSimple.JwtAuth.Db as Db

-- | Default 'Options' for the snaplet:
--
-- * passwords are hashed with 'BC.fastBcryptHashingPolicy',
--
-- * the JWT signing key is kept in the file @jwt_secret.txt@,
--
-- * issued tokens are valid for at most two weeks ('maxTokenExpiration' is
--   added to the current POSIX time, so it is measured in seconds).
--
-- NOTE(review): 'BC.fastBcryptHashingPolicy' uses a low bcrypt cost factor
-- (4, per the crypto-bcrypt docs); confirm it is acceptable for production.
defaults :: Options
defaults = Options
  { signingKeyFilename = "jwt_secret.txt"
  , maxTokenExpiration = 14 * 24 * 60 * 60 -- two weeks, in seconds
  , hashingPolicy      = BC.fastBcryptHashingPolicy
  }

-------------------------------------------------------------------------
-- | Initializer for the sqlite-simple JwtAuth snaplet.
--
-- The site signing key is loaded with 'getKey' from the file named by the
-- 'signingKeyFilename' option (a relative name is resolved against the
-- current working directory).  If that file doesn't exist, a new random key
-- is generated; otherwise the existing key is loaded.  This key is used to
-- sign the JWTs generated by the login procedure.
--
-- Initialization calls 'Db.createTableIfMissing' on the sqlite connection to
-- set up the SQL tables used to store user accounts.  NOTE(review): that
-- function is also meant to upgrade the schema when necessary; confirm in
-- the Db module.
sqliteJwtInit
  :: Options        -- ^ Site policy options
  -> Snaplet Sqlite -- ^ The sqlite-simple snaplet
  -> SnapletInit b SqliteJwt
sqliteJwtInit options db = makeSnaplet "sqlite-simple-jwt" description Nothing $ do
    k <- liftIO $ JWT.binarySecret <$> getKey (signingKeyFilename options)
    let conn = sqliteConn $ db ^# snapletValue
    liftIO $ Db.createTableIfMissing conn
    return $ SqliteJwt k conn options
  where
    description = "sqlite-simple jwt auth"

-------------------------------------------------------------------------
-- | Create a new user.
--
-- Returns @'Left' 'DuplicateLogin'@ if a user with the same login name is
-- already stored.  Otherwise the password is hashed with the configured
-- 'hashingPolicy', the user is inserted, and the user is returned as read
-- back from the database.
--
-- Instead of crashing in 'fromJust', this throws a descriptive 'IOError' in
-- two cases: when bcrypt hashing fails (which happens when the configured
-- 'hashingPolicy' is invalid), and when the freshly inserted user cannot be
-- read back.
--
-- NOTE(review): the existence check and the insert are not atomic, so two
-- concurrent registrations of the same login name can race.  Confirm whether
-- the users table enforces a UNIQUE constraint on the login column.
createUser
  :: T.Text -- ^ Login name of the user to be created
  -> T.Text -- ^ Password of the new user
  -> Handler b SqliteJwt (Either AuthFailure User)
createUser loginName password = do
  existing <- Db.queryUser loginName
  case existing of
    Just _ ->
      return (Left DuplicateLogin)
    Nothing -> do
      hashPolicy <- hashingPolicy <$> gets options
      mHashed    <- liftIO $ BC.hashPasswordUsingPolicy hashPolicy (LT.encodeUtf8 password)
      -- hashPasswordUsingPolicy yields Nothing for an unusable policy.
      hashedPass <- maybe (failWith "bcrypt hashing failed; check the configured hashingPolicy")
                          return
                          mHashed
      Db.insertUser loginName hashedPass
      inserted   <- Db.queryUser loginName
      maybe (failWith "newly inserted user could not be read back")
            (return . Right . Db.fromDbUser)
            inserted
  where
    -- Abort the handler with an IOError carrying a descriptive message.
    failWith :: String -> Handler b SqliteJwt a
    failWith msg = liftIO . ioError . userError $ "createUser: " ++ msg

-------------------------------------------------------------------------
-- | Check a user's credentials.
--
-- The user is looked up by login name and the supplied password is validated
-- against the stored bcrypt hash.  The result is:
--
-- * @'Left' 'UnknownUser'@ when no user with that login name exists,
--
-- * @'Left' 'WrongPassword'@ when the password does not match,
--
-- * the stored 'User' otherwise.
--
-- No token is produced here; 'loginUser' issues the JWT on success.
--
-- NOTE(review): an unknown login name returns without running bcrypt, so the
-- response time may reveal whether a login name exists.
login
  :: T.Text -- ^ Login name of the user logging in
  -> T.Text -- ^ Password
  -> Handler b SqliteJwt (Either AuthFailure User)
login loginName password = checkCredentials <$> Db.queryUser loginName
  where
    checkCredentials Nothing = Left UnknownUser
    checkCredentials (Just stored)
      | BC.validatePassword (Db.dbuserHashedPass stored) (LT.encodeUtf8 password)
                  = Right (Db.fromDbUser stored)
      | otherwise = Left WrongPassword

-- | Extract the token from an @Authorization@ header value of the form
-- @Bearer \<token\>@.
--
-- The auth-scheme is matched case-insensitively and may be followed by one
-- or more spaces (RFC 7235 Section 2.1, RFC 6750 Section 2.1).  The token
-- must be made up of base64/base64url characters and dots, optionally
-- followed by trailing whitespace.  Any other trailing input is a parse
-- error rather than being silently dropped from the token.
parseBearerJwt :: ByteString -> Either String T.Text
parseBearerJwt = AP.parseOnly bearer
  where
    bearer = AP.stringCI "Bearer"
          *> AP.skipMany1 (AP.char ' ')
          *> payload
          <* AP.skipSpace
          <* AP.endOfInput
    -- Only ASCII bytes can match 'base64', so decodeUtf8 cannot fail here.
    payload = LT.decodeUtf8 <$> AP.takeWhile1 (AP.inClass base64)
    base64  = "A-Za-z0-9+/_=.-"

-- | Build an HS256-signed JWT for a user.
--
-- The user's id and login name are stored as the unregistered claims
-- @"id"@ and @"login"@, and the @exp@ claim is set from @expiresOn@.  The
-- token is signed with the snaplet's 'siteSecret'.
jwtFromUser :: User -> POSIXTime -> Handler b SqliteJwt JWT.JSON
jwtFromUser (User ident loginText) expiresOn = sign <$> gets siteSecret
  where
    sign secret = JWT.encodeSigned JWT.HS256 secret claimsSet
    claimsSet = JWT.def
      { JWT.exp                = JWT.intDate expiresOn
      , JWT.unregisteredClaims = M.fromList
          [ ("id",    Number (fromIntegral ident))
          , ("login", String loginText)
          ]
      }

-- | Run a handler with the currently logged in user.
--
-- Verify authentication from the JWT passed in the Authorization header, and
-- run the user-provided @action@ with the logged in user.
--
-- Use the following syntax for constructing the Authorization header:
--
-- @
-- Bearer \<JWT\>
-- @
--
-- where \<JWT\> is obtained from the @token@ field of a successful call to
-- 'registerUser' or 'loginUser'.
--
-- A token is accepted only if both of these hold:
--
-- * its signature verifies against the site secret,
--
-- * the current time is strictly before its @exp@ claim (RFC 7519
--   Section 4.1.4).
--
-- The 'User' is then decoded from the token's unregistered claims.
--
-- On errors (a missing or malformed JWT, a JWT that fails verification, a
-- missing @exp@, an expired token, or undecodable claims), error out early
-- via 'finishEarly' with HTTP status 401 and the error message.
requireAuth :: (User -> Handler b SqliteJwt a) -> Handler b SqliteJwt a
requireAuth action = do
  key <- gets siteSecret
  req <- getRequest
  res <- runExceptT $ do
    curTime    <- liftIO getPOSIXTime
    authHdr    <- getHeader "Authorization" (rqHeaders req) ?? "missing Authorization header"
    encPayload <- hoistEither . parseBearerJwt $ authHdr
    jwt        <- JWT.decode encPayload           ?? "malformed JWT"
    verifJwt   <- JWT.verify key jwt              ?? "JWT verification failed"
    let verifiedClaims = JWT.claims verifJwt
    expiresAt  <- JWT.exp verifiedClaims          ?? "exp not set in JWT"
    -- RFC 7519: the token must not be accepted on or after its exp time.
    assertZ (curTime < JWT.secondsSinceEpoch expiresAt) ?? "token has expired"
    hoistEither . parseEither parseJSON . toObject $ JWT.unregisteredClaims verifiedClaims
  either (finishEarly 401 . BS8.pack) action res
  where
    toObject = Object . HM.fromList . M.toList

-- | Report a failed login or registration.
--
-- The handler does three things, in order:
--
-- 1. calls 'jsonResponse',
--
-- 2. sets HTTP status 401 with reason \"bad login\",
--
-- 3. writes @{"error": \<message\>}@ with 'writeJSON'.
--
-- 'UnknownUser' and 'WrongPassword' map to the same message, so the client
-- cannot tell which of the two occurred.
handleLoginError :: AuthFailure -> H b ()
handleLoginError = failLogin . messageFor
  where
    messageFor :: AuthFailure -> T.Text
    messageFor DuplicateLogin = "Duplicate login"
    messageFor UnknownUser    = "Unknown user or wrong password"
    messageFor WrongPassword  = "Unknown user or wrong password"

    failLogin :: T.Text -> H b ()
    failLogin msg = do
      jsonResponse
      modifyResponse (setResponseStatus 401 "bad login")
      writeJSON (object ["error" .= msg])

-- | Reply to a successful login or registration with a freshly issued JWT.
--
-- The token expires 'maxTokenExpiration' seconds from the current POSIX
-- time.  The response status is left unchanged, and the body is
-- @{"token": \<JWT\>}@.  Like the error path in 'handleLoginError', the
-- handler calls 'jsonResponse' before 'writeJSON', so success and failure
-- replies are prepared the same way.
loginOK :: User -> Handler b SqliteJwt ()
loginOK user = do
  expiresIn <- maxTokenExpiration <$> gets options
  curTime   <- liftIO getPOSIXTime
  jwt       <- jwtFromUser user (curTime + fromIntegral expiresIn)
  jsonResponse
  writeJSON $ object [ "token" .= jwt ]

-- | Register a new user and log them in.
--
-- Use a POST request for this handler with password and login encoded as JSON
-- in the body of the request:
--
-- @
-- {
--    "login": \<login name\>,
--    "pass": \<password\>
-- }
-- @
--
-- A successful registration replies with HTTP status 200 (the status is left
-- at Snap's default) and the following JSON:
--
-- @
-- {
--    "token": \<JWT\>
-- }
-- @
--
-- The returned token can be used to make authenticated requests.
--
-- Registration errors (such as an already taken login name) are reported
-- with HTTP status 401, and the error message is sent as a JSON object:
--
-- @
-- {
--   "error": \<message\>
-- }
-- @
registerUser :: Handler b SqliteJwt ()
registerUser = method POST newUser
  where
    newUser = do
      params    <- reqJSON
      userOrErr <- createUser (lpLogin params) (lpPass params)
      either handleLoginError loginOK userOrErr

-- | Login an existing user.
--
-- Use a POST request for this handler with password and login encoded as JSON
-- in the body of the request:
--
-- @
-- {
--   "login": \<login name\>,
--   "pass": \<password\>
-- }
-- @
--
-- A successful login replies with HTTP status 200 (the status is left at
-- Snap's default) and the following JSON:
--
-- @
-- {
--   "token": \<JWT\>
-- }
-- @
--
-- The returned token can be used to make authenticated requests.
--
-- Login errors (unknown user or wrong password) are reported with HTTP
-- status 401, and the error message is sent as a JSON object:
--
-- @
-- {
--   "error": \<message\>
-- }
-- @
loginUser :: Handler b SqliteJwt ()
loginUser = method POST $ do
  params    <- reqJSON
  userOrErr <- login (lpLogin params) (lpPass params)
  either handleLoginError loginOK userOrErr