{-# LANGUAGE GADTs, NoMonomorphismRestriction #-}

module Control.RMonad.AsMonad (AsMonad, embed, unEmbed) where

import Control.Monad
import Control.RMonad hiding ((>>=), return, fail, mzero, mplus)
import qualified Control.RMonad as RM ((>>=), return, fail, mzero, mplus)

-- |Lift a computation of the restricted monad @m@ into 'AsMonad', so it
-- can be combined through the standard 'Monad' interface (and 'MonadPlus'
-- when @m@ is an 'RMonadPlus').  'unEmbed' converts the result back to @m@.
embed :: (RMonad m, Suitable m a) => m a -> AsMonad m a
embed restricted = Embed restricted

-- |A restricted monad @m@ presented as an ordinary monad.
--
-- Operations are recorded as constructors of this tree rather than being
-- performed in @m@ directly.  Only 'Embed' demands a 'Suitable' constraint,
-- which is what allows the unconstrained 'Monad' / 'MonadPlus' instances in
-- this module; 'unEmbed' interprets the tree back into @m@.
data AsMonad m a where
   -- | An existing computation in @m@, packaged together with the
   -- 'Suitable' evidence for its element type.
   Embed
      :: (RMonad m, Suitable m a)
      => m a
      -> AsMonad m a
   -- | A pure value; no 'Suitable' constraint is needed to build it.
   Return
      :: RMonad m
      => a
      -> AsMonad m a
   -- | Sequencing; the intermediate result type is hidden inside the node.
   Bind
      :: RMonad m
      => AsMonad m a
      -> (a -> AsMonad m b)
      -> AsMonad m b
   -- | Failure carrying a message.
   Fail
      :: RMonad m
      => String
      -> AsMonad m a
   -- | The zero element, available only for an 'RMonadPlus'.
   MZero
      :: RMonadPlus m
      => AsMonad m a
   -- | Choice between two computations, available only for an 'RMonadPlus'.
   MPlus
      :: RMonadPlus m
      => AsMonad m a
      -> AsMonad m a
      -> AsMonad m a

-- |Each monadic operation only builds the matching 'AsMonad' node; the
-- restricted monad is not touched until 'unEmbed' interprets the tree.
--
-- NOTE(review): no 'Functor'/'Applicative' instance for 'AsMonad' is
-- visible in this module, and 'fail' is defined as a 'Monad' method, so
-- this appears to target a pre-AMP base (GHC < 7.10) — confirm the
-- intended compiler before building on a newer toolchain.
instance RMonad m => Monad (AsMonad m) where
   return x = Return x
   m >>= k = Bind m k
   fail msg = Fail msg

-- |'mzero' and 'mplus' are recorded as 'MZero' / 'MPlus' nodes and only
-- mapped onto the restricted monad's own operations by 'unEmbed'.
--
-- NOTE(review): with base >= 4.8 'MonadPlus' also requires an
-- 'Alternative' instance, which is not visible in this module — confirm
-- the intended compiler.
instance RMonadPlus m => MonadPlus (AsMonad m) where
   mzero = MZero
   mplus lhs rhs = MPlus lhs rhs

-- |Unwrap an 'AsMonad' value into the enclosed restricted monad.
--
-- Leaf constructors map directly onto the restricted monad's operations.
-- A 'Bind' is delegated to 'RM.>>=' only when its left operand is an
-- 'Embed', since that is the one constructor carrying 'Suitable' evidence
-- for its element type.  Every other left operand is rewritten first:
--
-- * @return a >>= f@ becomes @f a@;
--
-- * a nested bind is reassociated to the right, so the innermost left
--   operand is examined next;
--
-- * @fail s >>= f@ and @mzero >>= f@ short-circuit without running @f@;
--
-- * @mplus m1 m2 >>= f@ is distributed over both branches.
unEmbed :: Suitable m a => AsMonad m a -> m a
unEmbed (Embed m) = m
unEmbed (Return a) = RM.return a
-- 'fail' in the Monad instance builds 'Fail' (including failed pattern
-- matches in do-notation), so it must be interpreted here; without this
-- equation such values hit a non-exhaustive pattern error at runtime.
unEmbed (Fail s) = RM.fail s
unEmbed MZero = RM.mzero
unEmbed (MPlus m1 m2) = RM.mplus (unEmbed m1) (unEmbed m2)
unEmbed (Bind (Embed m) f) = (RM.>>=) m (\a -> unEmbed (f a))
unEmbed (Bind (Return a) f) = unEmbed (f a)
unEmbed (Bind (Bind m f) g) = unEmbed (Bind m (\x -> Bind (f x) g))
-- A failure on the left propagates; the continuation is never applied.
unEmbed (Bind (Fail s) _) = RM.fail s
unEmbed (Bind MZero _) = unEmbed MZero
-- NOTE(review): distributing the continuation over both branches assumes
-- the restricted monad satisfies left distribution
-- (mplus a b >>= k == mplus (a >>= k) (b >>= k)); confirm for any
-- RMonadPlus whose mplus is left-biased/left-catch.
unEmbed (Bind (MPlus m1 m2) f) = unEmbed (MPlus (Bind m1 f) (Bind m2 f))