{-# LANGUAGE TypeFamilies, NoMonomorphismRestriction, MultiParamTypeClasses, FlexibleInstances,
             NoImplicitPrelude, FlexibleContexts #-}
{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__>=700
{-# LANGUAGE RebindableSyntax #-}
#endif
-- |This module provides alternatives to the 'Functor', 'Monad' and 'MonadPlus' classes,
-- allowing for constraints on the contained type (a restricted monad).
-- It makes use of associated datatypes (available since GHC 6.8).
--
-- To make your own type instances of these classes, first define
-- the 'Constraints' datatype and the 'Suitable' type class for it. For example,
--
--   @
--     data instance Constraints Set a = Ord a => SetConstraints
--     instance Ord a => Suitable Set a where
--        constraints _ = SetConstraints
--   @
--
-- You need to change @Set@ to your own type, @Ord a@ to your own
-- constraints, and @SetConstraints@ to some distinguished name (this name
-- will not normally be visible to users of your type).
--
-- Next you can make an instance of 'RMonad' and if appropriate 'RMonadPlus'
-- by defining the members in the usual way. When you need to make use of the
-- constraint on the contained type, you will need to get hold of the constraint
-- wrapped up in the 'Constraints' datatype. For example here are the instances
-- for @Set@:
--
--   @
--    instance RMonad Set where
--       return = Set.singleton
--       s >>= f = let res = case constraints res of
--                             SetConstraints -> Set.fold (\a s' -> Set.union (f a) s') Set.empty s
--                 in res
--       fail _ = Set.empty
--   @
--
--   @
--    instance RMonadPlus Set where
--       mzero = Set.empty
--       mplus s1 s2 = let res = case constraints res of
--                                  SetConstraints -> Set.union s1 s2
--                     in res
--   @
--
-- Once you have made your type an instance of 'RMonad', you can
-- use it in two ways.
-- Firstly, import this module directly and use the @RebindableSyntax@ extension
-- so that do-notation is rebound to these classes. With GHC versions before
-- 7.0, use @NoImplicitPrelude@ instead.
-- Secondly, use the wrapper type in "Control.RMonad.AsMonad" which supports
-- the normal 'Monad' operations.
module Control.RMonad (Suitable(..), RFunctor(..), RMonad(..), RMonadPlus(..),
                       (<=<), (=<<), (>=>), ap,
                       filterM, foldM, foldM_, forM, forM_,
                       forever, guard, join,
                       liftM, liftM2, liftM3, liftM4, liftM5,
                       mapAndUnzipM, mapM, mapM_, msum,
                       replicateM, replicateM_, sequence, sequence_,
                       unless, when, zipWithM, zipWithM_
                      ) where

import Prelude hiding (return, fail, (>>=), (>>),
                       (=<<),
                       mapM, mapM_,
                       sequence, sequence_
                      )
import Control.IfThenElse
import qualified Control.Monad as M
import Data.Set (Set)
import qualified Data.Set as Set

import Data.Suitable

-- | Restricted counterpart of the standard functor class: the mapping
-- operation is only available for element types that satisfy the
-- 'Suitable' constraint for the container @f@.
class RFunctor f where
   -- | Restricted counterpart of the standard @fmap@; both the source and
   -- the result element type must be 'Suitable' for @f@.
   fmap :: (Suitable f a, Suitable f b) => (a -> b) -> f a -> f b

infixl 1 >>, >>=

-- | Restricted counterpart of the standard monad class: every operation
-- requires the element types involved to be 'Suitable' for @m@.
class RMonad m where
   -- | Restricted counterpart of the standard @return@.
   return :: Suitable m a => a -> m a

   -- | Restricted counterpart of the standard bind operator.
   (>>=) :: (Suitable m a, Suitable m b) => m a -> (a -> m b) -> m b

   -- | Sequencing that ignores the result of the first action.
   -- The default implementation binds the first action to a continuation
   -- that discards its argument.
   (>>) :: (Suitable m a, Suitable m b) => m a -> m b -> m b
   first >> second = first >>= const second

   -- | Failure with a message. The default implementation calls 'error'
   -- with that message.
   fail :: Suitable m a => String -> m a
   fail msg = error msg

-- | Restricted counterpart of the standard @MonadPlus@ class, layered on
-- top of 'RMonad'. Both operations require the element type to be
-- 'Suitable' for @m@.
class RMonad m => RMonadPlus m where
   -- | Restricted counterpart of the standard @mzero@.
   mzero :: Suitable m a => m a
   -- | Restricted counterpart of the standard @mplus@.
   mplus :: Suitable m a => m a -> m a -> m a

-- | Functions from a fixed argument type @r@: mapping is delegated to the
-- @fmap@ exported by "Control.Monad".
instance RFunctor ((->) r) where
   fmap g reader = M.fmap g reader

-- | Functions from a fixed argument type @r@: 'return' and '>>=' are
-- delegated to the standard monad operations from "Control.Monad".
instance RMonad ((->) r) where
   return = M.return
   (>>=) = (M.>>=)
   -- 'error' is used directly instead of @M.fail@. From base-4.13 onwards
   -- @Control.Monad.fail@ is the @MonadFail@ method, and functions have no
   -- @MonadFail@ instance, so @M.fail@ would not type-check at this type.
   -- On older versions of base the function monad did not override @fail@,
   -- so @M.fail@ resolved to the @Monad@ default, which also calls 'error';
   -- the observable behaviour is therefore unchanged.
   fail = error

-- | 'Maybe': mapping is delegated to the @fmap@ exported by "Control.Monad".
instance RFunctor Maybe where
   fmap g mx = M.fmap g mx

-- | 'Maybe': every method is delegated to the corresponding operation
-- exported by "Control.Monad".
instance RMonad Maybe where
   return x = M.return x
   mx >>= k = mx M.>>= k
   fail msg = M.fail msg

-- | 'Maybe': both methods are delegated to the standard @MonadPlus@
-- operations exported by "Control.Monad".
instance RMonadPlus Maybe where
   mzero = M.mzero
   mplus l r = M.mplus l r

-- | Lists: mapping is delegated to the @fmap@ exported by "Control.Monad".
instance RFunctor [] where
   fmap g xs = M.fmap g xs

-- | Lists: every method is delegated to the corresponding operation
-- exported by "Control.Monad".
instance RMonad [] where
   return x = M.return x
   xs >>= k = xs M.>>= k
   fail msg = M.fail msg

-- | Lists: both methods are delegated to the standard @MonadPlus@
-- operations exported by "Control.Monad".
instance RMonadPlus [] where
   mzero = M.mzero
   mplus xs ys = M.mplus xs ys

-- | 'IO': mapping is delegated to the @fmap@ exported by "Control.Monad".
instance RFunctor IO where
   fmap g action = M.fmap g action

-- | 'IO': every method is delegated to the corresponding operation
-- exported by "Control.Monad".
instance RMonad IO where
   return x = M.return x
   action >>= k = action M.>>= k
   fail msg = M.fail msg

-- | 'Set': mapping is 'Set.map'. The constraints of both the argument set
-- and the result set are unpacked by matching on @SetConstraints@ before
-- 'Set.map' is called.
instance RFunctor Set where
   fmap g xs =
      withConstraintsOf xs (\SetConstraints ->
         withResConstraints (\SetConstraints ->
            Set.map g xs))

-- | 'Set' as a restricted monad: 'return' builds a singleton set, '>>='
-- takes the union of the sets produced for each element, and 'fail'
-- ignores its message and yields the empty set.
instance RMonad Set where
   {-# INLINE return #-}
   return = Set.singleton

   {-# INLINE (>>=) #-}
   -- Matching on @SetConstraints@ brings the constraint on the result
   -- element type into scope for the union. 'Set.unions' over the
   -- ascending element list replaces the former @Set.fold@, which the
   -- containers documentation describes as a compatibility-only alias of
   -- @foldr@ that is slated for deprecation and removal. 'Set.union' is
   -- left-biased, so both formulations keep the element contributed by the
   -- earliest source element and produce the same set.
   s >>= f = withResConstraints $ \SetConstraints -> Set.unions (map f (Set.toList s))

   {-# INLINE fail #-}
   fail _ = Set.empty

-- | 'Set': 'mzero' is the empty set and 'mplus' is set union. The
-- constraint on the element type is unpacked by matching on
-- @SetConstraints@ before 'Set.union' is called.
instance RMonadPlus Set where
   {-# INLINE mzero #-}
   mzero = Set.empty

   {-# INLINE mplus #-}
   mplus left right =
      withResConstraints (\SetConstraints -> left `Set.union` right)

infixr 1 <=<
-- | Right-to-left Kleisli composition: run the second function on the
-- argument, then feed its result to the first.
(<=<) :: (RMonad m, Suitable m a, Suitable m b, Suitable m c) => (b -> m c) -> (a -> m b) -> a -> m c
(outer <=< inner) x = do { y <- inner x ; outer y }

infixr 1 =<<
-- | '>>=' with its arguments swapped.
(=<<) :: (RMonad m, Suitable m a, Suitable m b) => (a -> m b) -> m a -> m b
k =<< action = action >>= k

infixr 1 >=>
-- | Left-to-right Kleisli composition; '<=<' with its arguments swapped.
(>=>) :: (RMonad m, Suitable m a, Suitable m b, Suitable m c) => (a -> m b) -> (b -> m c) -> a -> m c
first >=> second = second <=< first

-- | Apply a function produced by one action to the value produced by
-- another; implemented as 'liftM2' of function application.
ap :: (RMonad m, Suitable m (a -> b), Suitable m a, Suitable m b) => m (a -> b) -> m a -> m b
ap mf mx = liftM2 apply mf mx
   where apply g y = g y

-- | Filter a list with a monadic predicate. The predicate is run on the
-- head before the tail is processed; elements whose predicate yields
-- 'True' are kept.
filterM :: (RMonad m, Suitable m [a], Suitable m Bool) => (a -> m Bool) -> [a] -> m [a]
filterM _ [] = return []
filterM p (y:ys) =
   p y >>= \keep ->
   filterM p ys >>= \kept ->
   return (if keep then y:kept else kept)

-- | Left-to-right monadic fold over a list: the accumulator produced by
-- each step is passed on to the next, and the final accumulator is returned.
foldM :: (RMonad m, Suitable m a) => (a -> b -> m a) -> a -> [b] -> m a
foldM _ acc [] = return acc
foldM step acc (y:ys) = step acc y >>= \acc' -> foldM step acc' ys

-- | Like 'foldM', but the final accumulator is discarded and @()@ is
-- returned instead.
foldM_ :: (RMonad m, Suitable m a, Suitable m ()) => (a -> b -> m a) -> a -> [b] -> m ()
foldM_ step acc ys = do { foldM step acc ys ; return () }

-- | 'mapM' with its arguments swapped.
forM :: (RMonad m, Suitable m b, Suitable m [b]) => [a] -> (a -> m b) -> m [b]
forM items body = mapM body items

-- | 'mapM_' with its arguments swapped.
forM_ :: (RMonad m, Suitable m b, Suitable m ()) => [a] -> (a -> m b) -> m ()
forM_ items body = mapM_ body items

-- | Repeat an action indefinitely: the result is the action sequenced
-- (with '>>') in front of itself, defined as a self-referential binding.
forever :: (RMonad m, Suitable m a, Suitable m b) => m a -> m b
forever action = loop
   where loop = action >> loop

-- | @return ()@ when the condition holds, 'mzero' otherwise.
guard :: (RMonadPlus m, Suitable m ()) => Bool -> m ()
guard ok = case ok of
   True  -> return ()
   False -> mzero

-- | Collapse one level of monadic nesting by binding the outer action and
-- returning the inner action it produces.
join :: (RMonad m, Suitable m a, Suitable m (m a)) => m (m a) -> m a
join outer = do { inner <- outer ; inner }

-- | Run the action and apply a pure function to its result.
liftM :: (RMonad m, Suitable m a1, Suitable m r) => (a1 -> r) -> m a1 -> m r
liftM f ma1 = ma1 >>= \a1 -> return (f a1)

-- | Run two actions from left to right and combine their results with a
-- pure binary function.
liftM2 :: (RMonad m, Suitable m a1, Suitable m a2, Suitable m r) => (a1 -> a2 -> r) -> m a1 -> m a2 -> m r
liftM2 f ma1 ma2 =
   ma1 >>= \a1 ->
   ma2 >>= \a2 ->
   return (f a1 a2)

-- | Run three actions from left to right and combine their results with a
-- pure ternary function.
liftM3 :: (RMonad m, Suitable m a1, Suitable m a2, Suitable m a3, Suitable m r) => (a1 -> a2 -> a3 -> r) -> m a1 -> m a2 -> m a3 -> m r
liftM3 f ma1 ma2 ma3 =
   ma1 >>= \a1 ->
   ma2 >>= \a2 ->
   ma3 >>= \a3 ->
   return (f a1 a2 a3)

-- | Run four actions from left to right and combine their results with a
-- pure four-argument function.
liftM4 :: (RMonad m, Suitable m a1, Suitable m a2, Suitable m a3, Suitable m a4, Suitable m r) => (a1 -> a2 -> a3 -> a4 -> r) -> m a1 -> m a2 -> m a3 -> m a4 -> m r
liftM4 f ma1 ma2 ma3 ma4 =
   ma1 >>= \a1 ->
   ma2 >>= \a2 ->
   ma3 >>= \a3 ->
   ma4 >>= \a4 ->
   return (f a1 a2 a3 a4)

-- | Run five actions from left to right and combine their results with a
-- pure five-argument function.
liftM5 :: (RMonad m, Suitable m a1, Suitable m a2, Suitable m a3, Suitable m a4, Suitable m a5, Suitable m r) => (a1 -> a2 -> a3 -> a4 -> a5 -> r) -> m a1 -> m a2 -> m a3 -> m a4 -> m a5 -> m r
liftM5 f ma1 ma2 ma3 ma4 ma5 =
   ma1 >>= \a1 ->
   ma2 >>= \a2 ->
   ma3 >>= \a3 ->
   ma4 >>= \a4 ->
   ma5 >>= \a5 ->
   return (f a1 a2 a3 a4 a5)

-- | Map a pair-producing action over a list with 'mapM', then split the
-- resulting list of pairs with 'unzip'.
mapAndUnzipM :: (RMonad m, Suitable m (b, c), Suitable m [(b, c)], Suitable m ([b], [c])) => (a -> m (b, c)) -> [a] -> m ([b], [c])
mapAndUnzipM f xs = do { pairs <- mapM f xs ; return (unzip pairs) }

-- | Apply an action-producing function to every list element and
-- 'sequence' the resulting actions, collecting their results.
mapM :: (RMonad m, Suitable m b, Suitable m [b]) => (a -> m b) -> [a] -> m [b]
mapM f xs = sequence actions
   where actions = map f xs

-- | Apply an action-producing function to every list element and run the
-- resulting actions with 'sequence_', discarding their results.
mapM_ :: (RMonad m, Suitable m b, Suitable m ()) => (a -> m b) -> [a] -> m ()
mapM_ f xs = sequence_ actions
   where actions = map f xs

-- | Combine a list of alternatives with 'mplus', nesting to the right and
-- ending in 'mzero'.
msum :: (RMonadPlus m, Suitable m a) => [m a] -> m a
msum [] = mzero
msum (alt:alts) = alt `mplus` msum alts

-- | Run the action the given number of times (via 'replicate' and
-- 'sequence'), collecting the results.
replicateM :: (RMonad m, Suitable m a, Suitable m [a]) => Int -> m a -> m [a]
replicateM count action = sequence $ replicate count action

-- | Run the action the given number of times (via 'replicate' and
-- 'sequence_'), discarding the results.
replicateM_ :: (RMonad m, Suitable m a, Suitable m ()) => Int -> m a -> m ()
replicateM_ count action = sequence_ $ replicate count action

-- | Run a list of actions from left to right and collect their results in
-- a list; the empty list of actions yields @return []@.
sequence :: (RMonad m, Suitable m a, Suitable m [a]) => [m a] -> m [a]
sequence = foldr (liftM2 (:)) (return [])

-- | Run a list of actions from left to right, chaining them with '>>' and
-- finishing with @return ()@.
sequence_ :: (RMonad m, Suitable m a, Suitable m ()) => [m a] -> m ()
sequence_ = foldr (>>) (return ())

-- | Run the action only when the condition is 'False'; otherwise
-- @return ()@.
unless :: (RMonad m, Suitable m ()) => Bool -> m () -> m ()
unless cond action
   | cond      = return ()
   | otherwise = action

-- | Run the action only when the condition is 'True'; otherwise
-- @return ()@.
when :: (RMonad m, Suitable m ()) => Bool -> m () -> m ()
when cond action
   | cond      = action
   | otherwise = return ()

-- | Combine two lists element-wise with an action-producing function
-- ('zipWith') and 'sequence' the resulting actions, collecting their results.
zipWithM :: (RMonad m, Suitable m c, Suitable m [c]) => (a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM f xs ys = sequence $ zipWith f xs ys

-- | Combine two lists element-wise with an action-producing function
-- ('zipWith') and run the resulting actions with 'sequence_', discarding
-- their results.
zipWithM_ :: (RMonad m, Suitable m c, Suitable m ()) => (a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ f xs ys = sequence_ $ zipWith f xs ys