{-# LANGUAGE OverloadedStrings #-}

module CJ.Auth.Token
  ( UserBearerToken(..)
  , JSONToken(..)
  , decodeToken
  , encodeToken
  ) where

import qualified Data.ByteString.Base64.URL as BS64
import qualified Data.Map as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Web.JWT as JWT
import Control.Monad
import Data.Aeson (FromJSON (..), Value, toJSON)
import Data.Aeson.Types (parseMaybe)
import Data.Either.Combinators
import Data.Maybe
import Data.Time.Clock.POSIX (POSIXTime)
import Web.JWT
  ( Algorithm(HS256)
  , JWTClaimsSet(JWTClaimsSet)
  , Secret
  , claims
  , decodeAndVerifySignature
  , def
  , encodeSigned
  , numericDate
  , secondsSinceEpoch
  , unregisteredClaims
  )

-- | Types whose values can be stored in, and read back from, the claims set
-- of a JSON Web Token.
--
-- NOTE(review): instances presumably should satisfy
-- @fromClaims (toClaims x) == Just x@; this is not enforced by the class.
class JSONToken a where
  -- | Recover a value from a claims set, or 'Nothing' when the claims do not
  -- describe a value of this type.
  fromClaims :: JWTClaimsSet -> Maybe a
  -- | Render a value as a claims set.
  toClaims :: a -> JWTClaimsSet

-- | Payload of a bearer token: an application id and an optional user id.
data UserBearerToken = UserBearerToken
  { _bearerUserId :: Maybe T.Text -- ^ User id; 'Nothing' when no user is bound.
  , _bearerAppId  :: T.Text       -- ^ Application id; always present.
  } deriving (Show, Eq)

instance JSONToken UserBearerToken where
  -- Both fields live in the unregistered (private) claims. A missing user id
  -- is serialised via 'toJSON' of 'Nothing', so the "userId" key is always
  -- written.
  toClaims token =
    def { unregisteredClaims = M.fromList
            [ ("userId", toJSON (_bearerUserId token))
            , ("appId",  toJSON (_bearerAppId token))
            ] }

  -- Both keys must be present and must decode to the field's type; a JSON
  -- null under "userId" decodes to 'Nothing'.
  fromClaims JWTClaimsSet{ unregisteredClaims = claimMap } = do
    userId <- M.lookup "userId" claimMap >>= fromValue
    appId  <- M.lookup "appId" claimMap >>= fromValue
    pure (UserBearerToken userId appId)

-- | Decode an Aeson 'Value' with the target type's 'FromJSON' instance,
-- discarding any parse error message.
fromValue :: FromJSON a => Value -> Maybe a
fromValue value = parseMaybe parseJSON value

-- | Sign a claims set with HS256 (HMAC-SHA256) using the given secret and
-- render the resulting JWT as text.
encodeJWT :: Secret -> JWTClaimsSet -> T.Text
encodeJWT key claimsSet = encodeSigned HS256 key claimsSet

-- | Cheap pre-check on a raw token: it must have at least two dot-separated
-- segments, and the first two (header and claims) must each leniently
-- base64url-decode to valid UTF-8.
--
-- NOTE(review): presumably this shields 'decodeAndVerifySignature' from
-- segments that are not valid UTF-8; confirm against the jwt version in use.
decodableToken :: T.Text -> Bool
decodableToken token =
    case T.splitOn "." token of
      (header : claimsPayload : _) -> all decodesToUtf8 [header, claimsPayload]
      _                            -> False
  where
    decodesToUtf8 segment =
      either (const False) (const True)
        (TE.decodeUtf8' (BS64.decodeLenient (TE.encodeUtf8 segment)))

-- | Sign a token's claims with HS256 after setting the @exp@ claim to
-- @expTime@. Any @exp@ produced by 'toClaims' is overwritten.
--
-- 'numericDate' returns 'Nothing' for negative times (per the jwt package
-- docs). Passing its result straight through would silently drop the @exp@
-- claim and yield a token that never expires. The time is therefore clamped
-- to the epoch, so such a token is already expired (fail closed).
--
-- TODO: move this out of the module interface?
encodeToken :: (JSONToken a) => Secret -> a -> POSIXTime -> T.Text
encodeToken secret token expTime = encodeJWT secret expiringClaims
  where baseClaims = toClaims token
        expiringClaims = baseClaims { JWT.exp = numericDate (max 0 expTime) }

-- | Decode a token produced by 'encodeToken'.
--
-- Returns 'Nothing' if:
--
-- * the raw text fails the 'decodableToken' pre-check,
-- * the HS256 signature does not verify against @secret@,
-- * the @exp@ claim shows the token has expired relative to @currentTime@, or
-- * the claims do not decode to the requested type.
decodeToken :: (JSONToken a) => Secret -> T.Text -> POSIXTime -> Maybe a
decodeToken secret token currentTime = do
    guard (decodableToken token)
    verified <- decodeAndVerifySignature secret token
    fresh    <- verifyFresh currentTime (claims verified)
    fromClaims fresh

-- | Reject a claims set whose @exp@ claim is at or before @currentTime@.
--
-- RFC 7519 §4.1.4 requires the current time to be strictly /before/ the
-- expiration time, so a token is already stale at the exact instant it
-- expires. A claims set with no @exp@ claim is returned unchanged.
verifyFresh :: POSIXTime -> JWTClaimsSet -> Maybe JWTClaimsSet
verifyFresh currentTime clms@JWTClaimsSet{ JWT.exp = (Just expirationTime) }
  | currentTime >= secondsSinceEpoch expirationTime = Nothing
  | otherwise                                       = Just clms
verifyFresh _ clms = Just clms