{-# LANGUAGE CPP #-}
module Servant.Auth.Server.Internal.Types where

import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Time
import Data.Monoid          (Monoid (..))
import Data.Semigroup       (Semigroup (..))
import Data.Time            (getCurrentTime)
import GHC.Generics         (Generic)
import Network.Wai          (Request)

import qualified Control.Monad.Fail as Fail

-- | The result of an authentication attempt.
--
-- Only 'Authenticated' carries a value; the other constructors are value-less
-- outcomes. The order of the constructors determines the derived 'Ord'
-- instance.
data AuthResult val
  -- | Presumably: credentials were supplied but the password did not match
  -- (inferred from the name; confirm with the code that produces it).
  = BadPassword
  -- | Presumably: no user matched the supplied credentials (inferred from the
  -- name; confirm with the code that produces it).
  | NoSuchUser
  -- | Authentication succeeded.
  | Authenticated val
  -- | If an authentication procedure cannot be carried out - if for example it
  -- expects a password and username in a header that is not present -
  -- @Indefinite@ is returned. This indicates that other authentication
  -- methods should be tried.
  | Indefinite
  deriving (Eq, Show, Read, Generic, Ord, Functor, Traversable, Foldable)

-- | Left-biased choice: the left result wins unless it is 'Indefinite', in
-- which case the right result is used.
instance Semigroup (AuthResult val) where
  lhs <> rhs = case lhs of
    Indefinite -> rhs
    _          -> lhs

-- | 'Indefinite' is the identity element: combining it with any other result
-- (on either side) yields that other result.
instance Monoid (AuthResult val) where
  mempty = Indefinite
  -- Canonical alias of the 'Semigroup' operation.
  mappend = (<>)

-- | 'pure' wraps a value in 'Authenticated'. '<*>' is derived from the
-- 'Monad' instance via 'ap', so it yields the first non-'Authenticated'
-- result it encounters, evaluating left to right.
instance Applicative AuthResult where
  pure = Authenticated
  (<*>) = ap

-- | Sequencing continues only after 'Authenticated'; every other result is
-- propagated unchanged and the continuation is not run.
instance Monad AuthResult where
  -- Canonical form: the definition lives in 'pure' and 'return' is its alias.
  return = pure
  Authenticated v >>= f = f v
  BadPassword     >>= _ = BadPassword
  NoSuchUser      >>= _ = NoSuchUser
  Indefinite      >>= _ = Indefinite

-- | Choice between results, expressed directly through the 'Monoid' and
-- 'Semigroup' instances: 'empty' is 'Indefinite' and '<|>' keeps the left
-- result unless it is 'Indefinite'.
instance Alternative AuthResult where
  empty = mempty
  (<|>) = (<>)

-- | 'mzero' is 'Indefinite'; 'mplus' is the left-biased '<>' on results.
instance MonadPlus AuthResult where
  mzero = Indefinite
  mplus lhs rhs = lhs <> rhs


-- | An @AuthCheck@ is the function used to decide the authentication status
-- (the 'AuthResult') of a request. Different @AuthCheck@s may be combined as a
-- Monoid or Alternative; the semantics of this is that the /first/
-- non-'Indefinite' result from left to right is used and the remaining checks
-- are not run.
--
-- 'runAuthCheck' unwraps the check into a plain function from a 'Request' to
-- an 'IO' action that yields the 'AuthResult'.
newtype AuthCheck val = AuthCheck
  { runAuthCheck :: Request -> IO (AuthResult val) }
  deriving (Generic, Functor)

-- | Run the left check first. Its result is final unless it is 'Indefinite',
-- in which case the right check is run against the same request.
instance Semigroup (AuthCheck val) where
  firstCheck <> secondCheck = AuthCheck $ \req ->
    runAuthCheck firstCheck req >>= \outcome -> case outcome of
      Indefinite -> runAuthCheck secondCheck req
      _          -> return outcome

-- | The identity check ignores the request and always answers 'Indefinite',
-- so any check combined with it decides the outcome.
instance Monoid (AuthCheck val) where
  mempty = AuthCheck (\_ -> return Indefinite)
  mappend = (<>)

-- | 'pure' builds a check that ignores the request and succeeds with the given
-- value. '<*>' is derived from the 'Monad' instance via 'ap'.
instance Applicative AuthCheck where
  pure val = AuthCheck $ \_ -> pure (Authenticated val)
  (<*>) = ap

-- | Sequencing runs the first check and, only if it is 'Authenticated', feeds
-- the value to the continuation and runs the resulting check against the same
-- request. Any other result is returned as is and the continuation is not run.
instance Monad AuthCheck where
  -- Canonical form: the definition lives in 'pure' and 'return' is its alias.
  return = pure
  AuthCheck ac >>= f = AuthCheck $ \req -> do
    aresult <- ac req
    case aresult of
      Authenticated usr -> runAuthCheck (f usr) req
      -- The failure constructors are rebuilt rather than passed through
      -- because the result type parameter changes across the bind.
      BadPassword       -> return BadPassword
      NoSuchUser        -> return NoSuchUser
      Indefinite        -> return Indefinite

  -- On base older than 4.13 the Monad class still has a fail method; it is
  -- delegated to the MonadFail instance.
#if !MIN_VERSION_base(4,13,0)
  fail = Fail.fail
#endif

-- | A failure discards its message and becomes a check that answers
-- 'Indefinite' for every request.
instance Fail.MonadFail AuthCheck where
  fail _ = AuthCheck $ \_ -> return Indefinite

-- | The environment is the request being checked: 'ask' yields it as an
-- 'Authenticated' value, and 'local' transforms the request before the
-- wrapped check sees it.
instance MonadReader Request AuthCheck where
  ask = AuthCheck (return . Authenticated)
  local modify (AuthCheck check) = AuthCheck (check . modify)

-- | Lift an 'IO' action into a check that ignores the request, runs the
-- action and wraps its result in 'Authenticated'.
instance MonadIO AuthCheck where
  liftIO action = AuthCheck $ \_ -> fmap Authenticated action

-- | The current time is read with 'getCurrentTime' and returned as an
-- 'Authenticated' value, independently of the request.
instance MonadTime AuthCheck where
  currentTime = AuthCheck $ \_ -> Authenticated <$> getCurrentTime

-- | Choice between checks, expressed directly through the 'Monoid' and
-- 'Semigroup' instances: 'empty' is the identity check and '<|>' tries the
-- left check before the right one.
instance Alternative AuthCheck where
  empty = mempty
  (<|>) = (<>)

-- | 'mzero' is the identity check from the 'Monoid' instance; 'mplus' combines
-- two checks with the left-biased '<>'.
instance MonadPlus AuthCheck where
  mzero = mempty
  mplus firstCheck secondCheck = firstCheck <> secondCheck