-- The use of ImpredicativeTypes here is safe; see the discussion under GitHub
-- issue #35. It is only needed to allow visible type application of a polytype.
{-# LANGUAGE ImpredicativeTypes #-}

module Ether.TaggedTrans
  ( TaggedTrans(..)
  ) where

import Control.Applicative
import Control.Monad (MonadPlus)
import Control.Monad.Fix (MonadFix)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Morph (MFunctor(..), MMonad(..))
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)

import qualified Control.Monad.Base as MB
import qualified Control.Monad.Trans.Control as MC

import qualified Control.Monad.Trans.Lift.StT    as Lift
import qualified Control.Monad.Trans.Lift.Local  as Lift
import qualified Control.Monad.Trans.Lift.Catch  as Lift
import qualified Control.Monad.Trans.Lift.Listen as Lift
import qualified Control.Monad.Trans.Lift.Pass   as Lift
import qualified Control.Monad.Trans.Lift.CallCC as Lift

import qualified Control.Monad.Cont.Class    as Mtl
import qualified Control.Monad.Reader.Class  as Mtl
import qualified Control.Monad.State.Class   as Mtl
import qualified Control.Monad.Writer.Class  as Mtl
import qualified Control.Monad.Error.Class   as Mtl

import GHC.Generics (Generic)
import Data.Coerce (coerce)

-- | A monad transformer @trans@ labelled with a phantom @tag@.
--
-- The @tag@ parameter does not occur in the wrapped field, so a value of
-- @TaggedTrans tag trans m a@ is a plain newtype wrapper around @trans m a@.
newtype TaggedTrans tag trans m a = TaggedTrans (trans m a)
  deriving
    ( Generic
    , Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadFix
    , MonadTrans
    , MonadIO
    , MonadThrow
    , MonadCatch
    , MonadMask
    )

-- | Type of a function that wraps an untagged @t m x@ into
-- @'TaggedTrans' tag t m x@.
type Pack tag t m x = t m x -> TaggedTrans tag t m x

-- | Type of a function that unwraps a @'TaggedTrans' tag t m x@ back to the
-- untagged @t m x@.
type Unpack tag t m x = TaggedTrans tag t m x -> t m x

-- | 'MB.MonadBase' for a tagged transformer stack: a base action is lifted
-- with the instance of the untagged @trans m@ and the result is wrapped in
-- the 'TaggedTrans' constructor.
instance
    ( MB.MonadBase b (trans m)
    ) => MB.MonadBase b (TaggedTrans tag trans m)
  where
    liftBase = TaggedTrans . MB.liftBase

-- | 'MC.MonadTransControl' for a tagged transformer, built from the instance
-- of the underlying @trans@ via 'MC.defaultLiftWith' and 'MC.defaultRestoreT'.
instance
    ( MC.MonadTransControl trans
    ) => MC.MonadTransControl (TaggedTrans tag trans)
  where
    -- The saved transformer state is exactly that of the underlying @trans@.
    type StT (TaggedTrans tag trans) a = MC.StT trans a

    -- The wrapping and unwrapping functions handed to the default helper are
    -- both 'coerce', specialised at the 'Pack' and 'Unpack' types.
    liftWith = MC.defaultLiftWith
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

    -- Restoring only needs the wrapping direction.
    restoreT = MC.defaultRestoreT
      (coerce :: Pack tag trans m a)

-- | Shape of a @liftBaseWith@-style function for base monad @base@ and monad
-- @m@: given a function from a 'MC.RunInBase' runner to a @base@ action, it
-- yields an @m@ action.
type LiftBaseWith base m r = (MC.RunInBase m base -> base r) -> m r

-- | Newtype wrapper around a 'LiftBaseWith' function.
--
-- 'coerceLiftBaseWith' wraps a function in 'LBW', coerces the wrapped value,
-- and unwraps it again with 'unLBW'.
-- NOTE(review): presumably the wrapper exists because the 'LiftBaseWith'
-- alias mentions 'MC.RunInBase', which a bare 'coerce' could not be used at
-- directly -- confirm against the monad-control definition.
newtype LiftBaseWith' b m a = LBW { unLBW :: LiftBaseWith b m a }

-- | Retarget a 'LiftBaseWith' function from the untagged @trans m@ to
-- @'TaggedTrans' tag trans m@.
--
-- The function is wrapped in 'LBW', converted with a single 'coerce' between
-- the two instantiations of 'LiftBaseWith'', and unwrapped with 'unLBW'.
coerceLiftBaseWith ::
  LiftBaseWith b                 (trans m) a ->
  LiftBaseWith b (TaggedTrans tag trans m) a
coerceLiftBaseWith lbw =
  unLBW (coerce (LBW lbw))

-- | 'MC.MonadBaseControl' for a tagged transformer stack, reusing the
-- instance of the untagged @trans m@.
instance
    ( MC.MonadBaseControl b (trans m)
    ) => MC.MonadBaseControl b (TaggedTrans tag trans m)
  where
    -- The saved monadic state is the same as for the untagged @trans m@.
    type StM (TaggedTrans tag trans m) a = MC.StM (trans m) a

    -- Converted with 'coerceLiftBaseWith' rather than a bare 'coerce'.
    liftBaseWith = coerceLiftBaseWith MC.liftBaseWith

    -- The annotation spells out both sides of the coercion: the argument type
    -- stays @MC.StM (trans m) a@ and only the result monad changes from
    -- @trans m@ to the tagged newtype.
    restoreM =
      (coerce :: forall a .
        (MC.StM (trans m) a ->                 trans m a) ->
        (MC.StM (trans m) a -> TaggedTrans tag trans m a))
      MC.restoreM

-- | The 'Lift.StT' state of a tagged transformer is the state of the
-- transformer it wraps.
type instance Lift.StT (TaggedTrans tag t) x = Lift.StT t x

-- | 'Lift.LiftLocal' for a tagged transformer: defers to the instance of
-- @trans@ through 'Lift.defaultLiftLocal', passing 'coerce' (at the 'Pack'
-- and 'Unpack' types) as the wrapping and unwrapping functions.
instance Lift.LiftLocal trans => Lift.LiftLocal (TaggedTrans tag trans) where
  liftLocal =
    Lift.defaultLiftLocal
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | 'Lift.LiftCatch' for a tagged transformer: defers to the instance of
-- @trans@ through 'Lift.defaultLiftCatch', passing 'coerce' (at the 'Pack'
-- and 'Unpack' types) as the wrapping and unwrapping functions.
instance Lift.LiftCatch trans => Lift.LiftCatch (TaggedTrans tag trans) where
  liftCatch =
    Lift.defaultLiftCatch
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | 'Lift.LiftListen' for a tagged transformer: defers to the instance of
-- @trans@ through 'Lift.defaultLiftListen', passing 'coerce' (at the 'Pack'
-- and 'Unpack' types) as the wrapping and unwrapping functions.
instance Lift.LiftListen trans => Lift.LiftListen (TaggedTrans tag trans) where
  liftListen =
    Lift.defaultLiftListen
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | 'Lift.LiftPass' for a tagged transformer: defers to the instance of
-- @trans@ through 'Lift.defaultLiftPass', passing 'coerce' (at the 'Pack'
-- and 'Unpack' types) as the wrapping and unwrapping functions.
instance Lift.LiftPass trans => Lift.LiftPass (TaggedTrans tag trans) where
  liftPass =
    Lift.defaultLiftPass
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)

-- | 'Lift.LiftCallCC' for a tagged transformer. Both methods defer to the
-- instance of @trans@ through the matching @default*@ helper, passing
-- 'coerce' (at the 'Pack' and 'Unpack' types) as the wrapping and unwrapping
-- functions.
instance Lift.LiftCallCC trans => Lift.LiftCallCC (TaggedTrans tag trans) where
  liftCallCC  =
    Lift.defaultLiftCallCC
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)
  liftCallCC' =
    Lift.defaultLiftCallCC'
      (coerce :: Pack tag trans m a)
      (coerce :: Unpack tag trans m a)


-- Instances for mtl classes

-- | 'Mtl.MonadCont' passes through the tagged transformer: 'Mtl.callCC' of
-- the inner monad @m@ is lifted with 'Lift.liftCallCC''.
instance
    ( Mtl.MonadCont m
    , Lift.LiftCallCC trans
    , Monad (trans m)
    ) => Mtl.MonadCont (TaggedTrans tag trans m)
  where
    callCC body = Lift.liftCallCC' Mtl.callCC body

-- | 'Mtl.MonadReader' passes through the tagged transformer to the inner
-- monad @m@: 'Mtl.ask' and 'Mtl.reader' are lifted with 'lift', and
-- 'Mtl.local' is lifted with 'Lift.liftLocal'.
instance
    ( Mtl.MonadReader r m
    , Lift.LiftLocal trans
    , Monad (trans m)
    ) => Mtl.MonadReader r (TaggedTrans tag trans m)
  where
    ask = lift Mtl.ask
    reader select = lift (Mtl.reader select)
    local modify = Lift.liftLocal Mtl.ask Mtl.local modify

-- | 'Mtl.MonadState' passes through the tagged transformer: every method is
-- the corresponding method of the inner monad @m@, lifted with 'lift'.
instance
    ( Mtl.MonadState s m
    , MonadTrans trans
    , Monad (trans m)
    ) => Mtl.MonadState s (TaggedTrans tag trans m)
  where
    get = lift Mtl.get
    put newState = lift (Mtl.put newState)
    state step = lift (Mtl.state step)

-- | 'Mtl.MonadWriter' passes through the tagged transformer to the inner
-- monad @m@: 'Mtl.writer' and 'Mtl.tell' are lifted with 'lift', while
-- 'Mtl.listen' and 'Mtl.pass' are lifted with 'Lift.liftListen' and
-- 'Lift.liftPass'.
instance
    ( Mtl.MonadWriter w m
    , Lift.LiftListen trans
    , Lift.LiftPass trans
    , Monad (trans m)
    ) => Mtl.MonadWriter w (TaggedTrans tag trans m)
  where
    writer result = lift (Mtl.writer result)
    tell output = lift (Mtl.tell output)
    listen action = Lift.liftListen Mtl.listen action
    pass action = Lift.liftPass Mtl.pass action

-- | 'Mtl.MonadError' passes through the tagged transformer to the inner
-- monad @m@: 'Mtl.throwError' is lifted with 'lift' and 'Mtl.catchError' is
-- lifted with 'Lift.liftCatch'.
instance
    ( Mtl.MonadError e m
    , Lift.LiftCatch trans
    , Monad (trans m)
    ) => Mtl.MonadError e (TaggedTrans tag trans m)
  where
    throwError err = lift (Mtl.throwError err)
    catchError action handler = Lift.liftCatch Mtl.catchError action handler

-- | The fully quantified type at which 'hoist' is used for a transformer
-- @trans@. It is a polytype, and is passed to 'coerce' by visible type
-- application in the 'MFunctor' instance for 'TaggedTrans'.
type Hoist trans =
  forall m n b . Monad m =>
    (forall a . m a -> n a) -> trans m b -> trans n b

-- | 'MFunctor' for a tagged transformer: 'hoist' of the underlying @trans@,
-- coerced from @'Hoist' trans@ to @'Hoist' ('TaggedTrans' tag trans)@. Both
-- polytypes are supplied to 'coerce' by visible type application.
--
-- NB: Do not use GeneralizedNewtypeDeriving to create this instance, as it
-- will trigger GHC Trac #11837 on GHC 8.0.1 and older.
instance MFunctor trans => MFunctor (TaggedTrans tag trans) where
  hoist =
    coerce
      @(Hoist trans)
      @(Hoist (TaggedTrans tag trans))
      hoist

-- | The fully quantified type at which 'embed' is used for a transformer
-- @trans@. It is a polytype, and is passed to 'coerce' by visible type
-- application in the 'MMonad' instance for 'TaggedTrans'.
type Embed trans =
  forall n m b . Monad n =>
    (forall a . m a -> trans n a) -> trans m b -> trans n b

-- | 'MMonad' for a tagged transformer: 'embed' of the underlying @trans@,
-- coerced from @'Embed' trans@ to @'Embed' ('TaggedTrans' tag trans)@. Both
-- polytypes are supplied to 'coerce' by visible type application.
--
-- NB: Do not use GeneralizedNewtypeDeriving to create this instance, as it
-- will trigger GHC Trac #11837 on GHC 8.0.1 and older.
instance MMonad trans => MMonad (TaggedTrans tag trans) where
  embed =
    coerce
      @(Embed                  trans)
      @(Embed (TaggedTrans tag trans))
      embed