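-- |
-- The identity monad transformer.
--
-- 'IdentityT' wraps an underlying monad without adding any effects of its
-- own. This makes it useful as a trivial transformer argument to functions
-- that are parameterized by a monad transformer.
--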
module Foundation.Monad.Identity
    ( IdentityT
    , runIdentityT
    ) where

import Basement.Compat.Base hiding (throw)
import Basement.Monad (MonadFailure(..))
import Foundation.Monad.MonadIO
import Foundation.Monad.Exception
import Foundation.Monad.Transformer

-- | Identity monad transformer.
--
-- Wraps a computation @m a@ unchanged; 'runIdentityT' unwraps it again.
newtype IdentityT m a = IdentityT { runIdentityT :: m a }

-- | Mapping delegates directly to the underlying functor.
instance Functor m => Functor (IdentityT m) where
    fmap f (IdentityT m) = IdentityT (fmap f m)
    {-# INLINE fmap #-}

-- | 'pure' and '<*>' delegate to the underlying applicative.
instance Applicative m => Applicative (IdentityT m) where
    pure x = IdentityT (pure x)
    {-# INLINE pure #-}
    -- Unwrap both sides, apply in the underlying functor, re-wrap.
    fab <*> fa = IdentityT (runIdentityT fab <*> runIdentityT fa)
    {-# INLINE (<*>) #-}

-- | Binding delegates to the underlying monad's '>>='.
instance Monad m => Monad (IdentityT m) where
    -- 'return' is deliberately left undefined here. The class default
    -- (@return = pure@) reuses the 'Applicative' instance, which is the
    -- canonical form expected by -Wnoncanonical-monad-instances.
    --
    -- For bind, run the wrapped action, then unwrap each continuation's
    -- result so the whole chain stays in the underlying monad.
    ma >>= mb = IdentityT $ runIdentityT ma >>= runIdentityT . mb
    {-# INLINE (>>=) #-}

-- | Lifting is just wrapping: 'IdentityT' adds no layer of effects.
instance MonadTrans IdentityT where
    lift = IdentityT
    {-# INLINE lift #-}

-- | IO actions are lifted into the underlying monad, then wrapped.
instance MonadIO m => MonadIO (IdentityT m) where
    liftIO f = lift (liftIO f)
    {-# INLINE liftIO #-}

-- | Failures are those of the underlying monad, raised unchanged.
instance MonadFailure m => MonadFailure (IdentityT m) where
    type Failure (IdentityT m) = Failure m
    mFail = IdentityT . mFail
    {-# INLINE mFail #-}

-- | Throwing delegates to the underlying monad's 'throw'.
instance MonadThrow m => MonadThrow (IdentityT m) where
    throw e = IdentityT (throw e)
    {-# INLINE throw #-}

-- | Catching delegates to the underlying monad's 'catch'.
instance MonadCatch m => MonadCatch (IdentityT m) where
    -- The handler returns an 'IdentityT' action, so unwrap its result
    -- before passing it to the underlying 'catch'.
    catch (IdentityT m) c = IdentityT $ m `catch` (runIdentityT . c)
    {-# INLINE catch #-}