{-# Language RankNTypes #-}
{-# Language FlexibleInstances #-}
{-# Language ScopedTypeVariables #-}
{-# Language UndecidableInstances #-}
{-# Language MultiParamTypeClasses #-}

{-|
Module     : Control.Monad.Trans.Choice.Invariant
Copyright  : (c) Eamon Olive, 2020
             (c) Louis Hyde,  2020
License    : AGPL-3
Maintainer : ejolive97@gmail.com
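Stability  : experimental

A choice monad transformer.  Computations request selections with
'Control.Monad.Class.Choice.choose', and the chooser supplied to
'runChoiceT' resolves each of those selections in the underlying monad.
-}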
module Control.Monad.Trans.Choice.Invariant
  ( ChoiceT
    ( ChoiceT
    )
  , runChoiceT
  , invmapChoiceT
  ) where

-- Internal imports

import Control.Monad.Class.Choice
  ( MonadChoice
  , choose
  )

-- External imports

import Control.Applicative
  ( Alternative
    ( empty
    , (<|>)
    )
  )
import Control.Monad
  ( MonadPlus
  )
import Control.Monad.Except
  ( MonadError
    ( throwError
    , catchError
    )
  )
import Control.Monad.IO.Class
  ( MonadIO
    ( liftIO
    )
  )
import Control.Monad.Reader.Class
  ( MonadReader
    ( ask
    , local
    , reader
    )
  )
import Control.Monad.State.Class
  ( MonadState
    ( get
    , put
    , state
    )
  )
import Control.Monad.Trans
  ( MonadTrans
  , lift
  )
import Control.Monad.Writer.Class
  ( MonadWriter
    ( listen
    , pass
    , tell
    , writer
    )
  )
import Data.Functor.Contravariant
  ( Contravariant
    ( contramap
    )
  )
import Data.Functor.Invariant
  ( Invariant
    ( invmap
    )
  )

-- | The choice monad transformer.
--
-- A @'ChoiceT' f m a@ is a computation in @m@ that can make selections
-- described by values of type @f x@.  It is represented as a function that,
-- given a chooser able to resolve any such selection inside @m@, yields the
-- final @m a@.
newtype ChoiceT f m a =
  ChoiceT
    { _runChoiceT :: (forall x . f x -> m x) -> m a
      -- ^ Internal accessor, not exported.  Use 'runChoiceT' instead, which
      -- takes the chooser as its first argument.
    }

-- | Change the underlying monad of a 'ChoiceT' computation.
--
-- Unlike the @mapXT@ functions of most monad transformers, this also needs
-- a natural transformation pointing the other way, from the new monad @n@
-- back to the old monad @m@.  The chooser sits in a contravariant position
-- inside 'ChoiceT', so a chooser for @n@ has to be turned back into a
-- chooser for @m@.  This makes 'ChoiceT' an invariant, rather than a
-- covariant, functor over the category of monads.
invmapChoiceT ::
  (forall x . n x -> m x)
  -- ^ Brings each result of the new chooser back into the original monad.
    -> (m a -> n b)
    -- ^ Rewrites the whole underlying computation.
      -> ChoiceT f m a
        -> ChoiceT f n b
invmapChoiceT backward forward (ChoiceT run) =
  ChoiceT (\ outerChooser -> forward (run (\ request -> backward (outerChooser request))))

-- | Run a 'ChoiceT' computation, resolving every selection it makes with
-- the given chooser.
--
-- The chooser is taken first, unlike the @runXT@ functions of many other
-- monad transformers, which is convenient for partial application.
runChoiceT ::
  (forall x . f x -> m x)
  -- ^ The chooser.  It turns a selection over @f@, for any element type @x@,
  -- into an action in @m@, and it answers every 'choose' the computation
  -- performs.
  --
  -- Because @x@ is universally quantified, parametricity guarantees (for
  -- lawful functors @f@ and @m@) that the chooser commutes with 'fmap':
  --
  -- > chooser (fmap g x) = fmap g (chooser x)
    -> ChoiceT f m a
    -- ^ The computation whose choices are to be resolved.
      -> m a
      -- ^ The underlying computation, with every choice resolved.
runChoiceT chooser (ChoiceT run) = run chooser

-- | Transform the underlying @m@ computation while keeping the same chooser.
--
-- This is a shallower map than 'fmap': it rewrites the whole @m a@ action
-- instead of only the value that action produces.
--
-- prop> lowMap g . lift = lift . g
--
-- prop> lowMap id = id
--
lowMap ::
  (m a -> m b)
    -> (ChoiceT f m a -> ChoiceT f m b)
lowMap transform (ChoiceT run) = ChoiceT (\ chooser -> transform (run chooser))


-- | Maps over the result of the underlying computation.  The chooser is
-- passed through unchanged.
instance Functor m => Functor (ChoiceT f m) where
  fmap g (ChoiceT run) = ChoiceT (\ chooser -> g <$> run chooser)

-- | Delegates 'contramap' to the underlying computation, once the chooser
-- has been supplied.
instance Contravariant m => Contravariant (ChoiceT f m) where
  contramap g (ChoiceT run) = ChoiceT (\ chooser -> contramap g (run chooser))

-- | Delegates 'invmap' to the underlying computation, once the chooser has
-- been supplied.
instance Invariant m => Invariant (ChoiceT f m) where
  invmap forth back (ChoiceT run) = ChoiceT (\ chooser -> invmap forth back (run chooser))

-- | 'pure' makes no selections.  '<*>' runs both sides with the same
-- chooser and combines them with the underlying '<*>'.
instance Applicative m => Applicative (ChoiceT f m) where
  pure x = ChoiceT (\ _chooser -> pure x)
  (ChoiceT runF) <*> (ChoiceT runX) =
    ChoiceT (\ chooser -> runF chooser <*> runX chooser)

-- | Alternatives are delegated to the underlying monad.  Both branches are
-- run with the same chooser and combined with @m@'s '<|>'.
instance Alternative m => Alternative (ChoiceT f m) where
  -- Use a lambda rather than @const empty@.  @const@ would have to be
  -- instantiated at the polytype @forall x . f x -> m x@.  GHC >= 9.0
  -- (simplified subsumption) rejects that without ImpredicativeTypes,
  -- whereas a lambda binder may have the rank-2 type directly.
  empty = ChoiceT (\ _ -> empty)
  m1 <|> m2 = ChoiceT (\ chooser -> runChoiceT chooser m1 <|> runChoiceT chooser m2)

-- | Sequencing threads the same chooser through both the first computation
-- and the continuation.
instance Monad m => Monad (ChoiceT f m) where
  (ChoiceT run) >>= next =
    ChoiceT (\ chooser -> do
      x <- run chooser
      runChoiceT chooser (next x))

-- | Relies on the default 'MonadPlus' methods, which are defined in terms of
-- the 'Alternative' instance for 'ChoiceT' ('empty' and '<|>').
instance (MonadPlus m) => MonadPlus (ChoiceT f m) where

-- | Lifting ignores the chooser, so a lifted action makes no selections.
--
-- 'ChoiceT' is a functor on the category of monads, but an invariant one
-- rather than a covariant one.  That is why the traditional @mapChoiceT@
-- cannot be implemented; 'invmapChoiceT' takes its place.
instance MonadTrans (ChoiceT f) where
  lift action = ChoiceT (\ _ignoredChooser -> action)

-- | Each selection is handed directly to the chooser that is supplied when
-- the computation is run with 'runChoiceT'.
instance Monad m => MonadChoice f (ChoiceT f m) where
  choose options = ChoiceT (\ chooser -> chooser options)

-- | Reader operations come from the underlying monad.  'ask' and 'reader'
-- are lifted, and 'local' modifies the environment of the whole wrapped
-- computation via 'lowMap'.
instance MonadReader r m => MonadReader r (ChoiceT f m) where
  ask = lift ask
  local f = lowMap (local f)
  -- Forward 'reader' directly, as mtl's own lifting instances do, so the
  -- underlying monad's implementation is used instead of the class default.
  reader f = lift (reader f)

-- | State operations are lifted from the underlying monad and ignore the
-- chooser.
instance MonadState s m => MonadState s (ChoiceT f m) where
  get = lift get
  put s = lift (put s)
  -- Forward 'state' as one lifted operation, as mtl's own lifting instances
  -- do.  The class default would rebuild it as a separate 'get' followed by
  -- a 'put', bypassing the underlying monad's own (possibly atomic or
  -- cheaper) 'state'.
  state f = lift (state f)

-- | Writer operations come from the underlying monad.  'tell' and 'writer'
-- are lifted, while 'listen' and 'pass' transform the whole wrapped
-- computation via 'lowMap'.
instance MonadWriter w m => MonadWriter w (ChoiceT f m) where
  tell w = lift (tell w)
  -- Forward 'writer' directly, as mtl's own lifting instances do, instead of
  -- relying on the class default built from 'tell'.
  writer aw = lift (writer aw)
  listen = lowMap listen
  pass   = lowMap pass

-- | 'IO' actions are lifted into the underlying monad first, then into
-- 'ChoiceT'.
instance MonadIO m => MonadIO (ChoiceT f m) where
  liftIO io = lift (liftIO io)

-- | Errors are thrown and caught in the underlying monad.  The handler's
-- recovery computation runs with the same chooser as the protected one.
instance MonadError e m => MonadError e (ChoiceT f m) where
  throwError err = lift (throwError err)
  catchError (ChoiceT run) handler =
    ChoiceT (\ chooser -> catchError (run chooser) (\ err -> runChoiceT chooser (handler err)))