{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

module Control.SafeAccess
  ( ensureAccess
  , Capability(..)
  , Capabilities
  , AccessDecision(..)
  , SafeAccessT(..)
  , AccessDescriptor(..)
  , MonadSafeAccess(..)
  , liftErrorT
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Error

import Data.List

-- | Check that the access described by @descr@ is legal, or deny it.
--
-- Every capability returned by 'getCapabilities' is applied to @descr@ and
-- the resulting decisions are merged with 'mergeDecisions', starting from
-- 'AccessDeniedSoft'. The access proceeds only if the merged verdict is
-- 'AccessGranted'; otherwise 'denyAccess' is called with @descr@. In
-- particular, an empty capability list denies every access.
ensureAccess :: MonadSafeAccess m d => d -> m ()
ensureAccess descr = do
  caps <- getCapabilities
  let consult acc cap = mergeDecisions acc (runCapability cap descr)
      verdict         = foldl' consult AccessDeniedSoft caps
  if verdict == AccessGranted
    then return ()
    else denyAccess descr

-- | A single access rule. See 'ensureAccess'.
--
-- @d@ is the type describing an access. 'runCapability' gives this rule's
-- 'AccessDecision' for a given access; 'ensureAccess' consults every
-- capability in force and merges their decisions.
newtype Capability d = MkCapability { runCapability :: d -> AccessDecision }

-- | The capabilities in force, as consulted by 'ensureAccess'.
-- An empty list grants nothing.
type Capabilities d = [Capability d]

-- | The verdict of a single 'Capability' on an access.
--
-- The constructors are listed in order of increasing precedence: when
-- decisions are combined, the one appearing later wins. For instance, if two
-- capabilities respectively return 'AccessGranted' and 'AccessDenied',
-- the final decision will be 'AccessDenied'. Only a final decision of
-- 'AccessGranted' lets the access through.
data AccessDecision
  = AccessDeniedSoft  -- ^ No, but another 'Capability' can still grant the access
  | AccessGranted     -- ^ Yes, unless another 'Capability' returns 'AccessDenied'
  | AccessDenied      -- ^ Final no: overrides every other decision
  deriving (Show, Eq)

-- | Combine two decisions, keeping the one with the higher precedence
-- ('AccessDeniedSoft' < 'AccessGranted' < 'AccessDenied').
--
-- The operation is commutative and associative, with 'AccessDeniedSoft' as
-- its identity, so the order in which capabilities are consulted does not
-- matter.
mergeDecisions :: AccessDecision -> AccessDecision -> AccessDecision
mergeDecisions a b
  | precedence a >= precedence b = a
  | otherwise                    = b
  where
    precedence :: AccessDecision -> Int
    precedence AccessDeniedSoft = 0
    precedence AccessGranted    = 1
    precedence AccessDenied     = 2

-- | A monad transformer ensuring that data are accessed legitimately.
--
-- 'runSafeAccessT' receives the capabilities in force for the whole
-- computation. Its result is either a descriptor (left) or the result of the
-- normal computation (right). The descriptor is either the one of an access
-- that was denied, or a 'fail' message converted with 'descrMsg'. A left
-- value short-circuits monadic sequencing ('>>=').
newtype SafeAccessT d m a
  = SafeAccessT { runSafeAccessT :: Capabilities d -> m (Either d a) }

-- Bind threads the same capability list through both computations and stops
-- at the first 'Left' (a denied access): the continuation is never run once
-- access has been denied. 'fail' turns its message into an access
-- descriptor via 'descrMsg' and short-circuits the same way.
instance (AccessDescriptor d, Monad m) => Monad (SafeAccessT d m) where
  return x = SafeAccessT $ \_ -> return (Right x)
  (SafeAccessT runFirst) >>= next = SafeAccessT $ \caps ->
    runFirst caps
      >>= either (return . Left) (\x -> runSafeAccessT (next x) caps)
  fail msg = SafeAccessT $ \_ -> return (Left (descrMsg msg))

-- Lifting never denies access: the base action's result is wrapped in
-- 'Right' and the capability list is ignored.
instance MonadTrans (SafeAccessT d) where
  lift action = SafeAccessT $ \_ -> do
    x <- action
    return (Right x)

-- Map over the successful result only; a denied access ('Left') is passed
-- through unchanged.
instance Functor f => Functor (SafeAccessT d f) where
  fmap g (SafeAccessT run) = SafeAccessT (fmap (fmap g) . run)

-- Sequencing must short-circuit exactly like '>>=': once the function
-- computation yields 'Left' (a denied access), the argument computation is
-- NOT run, so none of its effects in the base monad happen. This keeps
-- (<*>) = ap. Combinators such as (*>), traverse_, for_ and sequenceA_ are
-- built on (<*>); without short-circuiting, protected actions placed after
-- a failed 'ensureAccess' would still execute.
--
-- Deciding whether to run the second computation requires inspecting the
-- first one's result, hence the 'Monad' constraint on the base. This is the
-- same constraint transformers' ErrorT/ExceptT use for their Applicative
-- instances.
instance (Functor f, Monad f) => Applicative (SafeAccessT d f) where
  pure = SafeAccessT . const . return . Right
  safef <*> safea = SafeAccessT $ \caps -> do
    ef <- runSafeAccessT safef caps
    case ef of
      Left d  -> return (Left d)
      Right g -> liftM (fmap g) (runSafeAccessT safea caps)

-- Run an IO action in the base monad. This never denies access: the result
-- is wrapped in 'Right' and the capability list is ignored.
instance (AccessDescriptor d, MonadIO m) => MonadIO (SafeAccessT d m) where
  liftIO io = SafeAccessT $ \_ -> liftM Right (liftIO io)

-- | Types usable as access descriptors.
--
-- 'descrMsg' is needed by 'fail' on 'SafeAccessT': it converts the failure
-- message into a descriptor so it can be returned as the left side of the
-- @Either d a@ result.
class AccessDescriptor d where
  descrMsg :: String -> d

-- | Read the capability list the 'SafeAccessT' computation was run with.
getCapabilities' :: Monad m => SafeAccessT d m (Capabilities d)
getCapabilities' = SafeAccessT (\caps -> return (Right caps))

-- | Abort the 'SafeAccessT' computation, reporting @descr@ as the denied
-- access.
denyAccess' :: Monad m => d -> SafeAccessT d m ()
denyAccess' descr = SafeAccessT (\_ -> return (Left descr))

-- | Monads in which accesses described by @d@ can be checked with
-- 'ensureAccess'.
class (AccessDescriptor d, Monad m) => MonadSafeAccess m d where
  -- | The capabilities consulted by 'ensureAccess'.
  getCapabilities :: m (Capabilities d)
  -- | Reject the access described by the argument. 'ensureAccess' calls this
  -- whenever the merged decision is not 'AccessGranted'.
  -- NOTE(review): instances are presumably expected to abort the computation
  -- here (the 'SafeAccessT' instance does) — confirm for other instances.
  denyAccess      :: d -> m ()

-- 'SafeAccessT' reads the capability list it was run with, and denies an
-- access by short-circuiting with the access descriptor as a 'Left' result.
instance (AccessDescriptor d, Monad m) => MonadSafeAccess (SafeAccessT d m) d where
  getCapabilities = getCapabilities'
  denyAccess      = denyAccess'

-- | Lift an action from 'ErrorT' to 'SafeAccessT'.
--
-- An 'ErrorT' error becomes a denied access: its error value is reused as
-- the access descriptor (the left side of the result). The capability list
-- is ignored.
liftErrorT :: ErrorT d m a -> SafeAccessT d m a
liftErrorT action = SafeAccessT (\_ -> runErrorT action)