{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 800
{-# LANGUAGE UndecidableSuperClasses #-}
#endif
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
module Control.Monad.Constrained( module Control.Applicative.Constrained
, Monad(..), return, (>>=), (=<<), (>>), (<<)
, (>=>), (<=<)
, Kleisli(..)
, MonadZero(..), mzero, MonadPlus(..), mplus
, MonadFail(..)
, mapM, mapM_, forM, forM_, sequence, sequence_
, guard, when, unless
, forever, void
, filterM
) where
import Control.Applicative.Constrained
import Data.Foldable.Constrained
import Data.Traversable.Constrained
import Data.Tagged
import Prelude hiding (
id, const, fst, snd, (.), ($)
, Functor(..), Applicative(..), Monad(..), (=<<)
, uncurry, curry, filter
, mapM, mapM_, sequence, sequence_
)
import qualified Control.Category.Hask as Hask
import Control.Arrow.Constrained
-- | A monad in the category @k@: an endofunctor @m@ on @k@ (through the
--   @Applicative m k k@ superclass) whose doubly wrapped values can be
--   flattened with 'join'.
--
--   The superclass context additionally requires that the unit object of
--   @k@, wrapped once and twice in @m@, is itself an object of @k@.
class ( Applicative m k k
      , Object k (m (UnitObject k))
      , Object k (m (m (UnitObject k)))
      ) => Monad m k where
  -- | Collapse one layer of @m@.
  join :: (Object k a, Object k (m a), Object k (m (m a)))
       => k (m (m a)) (m a)
-- | Inject a plain value into @m@. Only available for monads on ordinary
--   Haskell functions; it is 'pure' used at the category @(->)@.
return :: Monad m (->) => a -> m a
return x = pure x
infixr 1 =<<
-- | Extend a morphism @k a (m b)@ so that it accepts an already wrapped
--   input: map it over the @m@ layer, then flatten the resulting double
--   layer with 'join'.
(=<<) :: ( Monad m k, Object k a, Object k b
         , Object k (m a), Object k (m b), Object k (m (m b)) )
      => k a (m b) -> k (m a) (m b)
(=<<) step = join <<< fmap step
infixl 1 >>=
-- | Flipped '=<<' for categories whose morphisms can be applied to values
--   (the 'Function' constraint). The continuation is extended with '=<<',
--   and the result is applied to the monadic value with the constrained
--   application operator.
(>>=) :: ( Function f, Monad m f, Object f a, Object f b
         , Object f (m a), Object f (m b), Object f (m (m b)) )
      => m a -> f a (m b) -> m b
ma >>= cont = (cont =<<) $ ma
infixr 1 <<
-- | @(<<) next@ is a morphism from @m a@ to @m b@. It replaces every value
--   of the input with the fixed action @next@ (using the well-pointed
--   'const') and flattens the result with 'join'.
(<<) :: ( Monad m k, WellPointed k
        , Object k a, Object k b, Object k (m a), ObjectPoint k (m b)
        , Object k (m (m b))
        ) => m b -> k (m a) (m b)
(<<) next = join <<< fmap (const next)
infixl 1 >>
-- | @(>>) before@ is a morphism from @m b@ to @m b@. It works in five steps:
--
--   * attach a unit to the input and swap it to the front;
--   * replace that unit with @before@, lifted as a global element;
--   * combine the resulting pair of actions with 'fzip';
--   * keep only the second component.
--
--   NOTE(review): 'fzip' decides which action's effects happen first.
--   Presumably it is @before@ first, matching the argument order; confirm.
(>>) :: ( WellPointed k, Monad m k
        , ObjectPair k b (UnitObject k), ObjectPair k (m b) (UnitObject k)
        , ObjectPair k (UnitObject k) (m b), ObjectPair k b a
        , ObjectPair k a b, Object k (m (a,b)), ObjectPair k (m a) (m b)
        , ObjectPoint k (m a)
        ) => m a -> k (m b) (m b)
(>>) before = fmap snd <<< fzip <<< first (globalElement before)
                       <<< swap <<< attachUnit
-- | Every ordinary Haskell monad is a constrained monad on @(->)@.
--   'join' simply delegates to the standard one.
instance (Hask.Applicative m, Hask.Monad m) => Monad m (->) where
  join nested = Hask.join nested
-- | Monads equipped with a zero element, 'fmzero', given as a morphism
--   out of the unit object of @k@.
class (Monad m k) => MonadZero m k where
  fmzero :: (Object k a, Object k (m a)) => k (UnitObject k) (m a)
-- | The zero element for monads on @(->)@: 'fmzero' applied to the unit
--   value @()@. The direct application is what fixes the category to
--   @(->)@ here.
mzero :: (MonadZero m (->)) => m a
mzero = fmzero ()
-- | Monads with a binary combining operation 'fmplus', given as a morphism
--   from a pair of actions to a single action.
class (MonadZero m k) => MonadPlus m k where
  fmplus :: (ObjectPair k (m a) (m a)) => (m a, m a) `k` m a
-- | Curried form of 'fmplus' for monads on @(->)@: the two actions are
--   paired and handed to 'fmplus'.
mplus :: (MonadPlus m (->)) => m a -> m a -> m a
mplus x y = fmplus (x, y)
-- | The zero of an ordinary 'Hask.MonadPlus'. The unit argument is ignored.
instance (Hask.MonadPlus m, Hask.Applicative m) => MonadZero m (->) where
  fmzero _ = Hask.mzero
-- | The standard 'Hask.mplus', taking its two actions as a pair. The pair
--   is matched irrefutably, so it is only taken apart on demand.
instance (Hask.MonadPlus m, Hask.Applicative m) => MonadPlus m (->) where
  fmplus ~(l, r) = Hask.mplus l r
-- | Monads with a 'fail' morphism from a 'String' message to an action.
--
--   NOTE(review): base 4.13 and later (GHC 8.8+) export a @MonadFail@ class
--   with a @fail@ method from the Prelude. The Prelude import of this module
--   does not hide them; confirm that these names do not become ambiguous on
--   newer compilers.
class (MonadPlus m k) => MonadFail m k where
  fail :: (Object k (m a)) => String `k` m a
-- | Delegates to 'Hask.fail'.
--
--   NOTE(review): if 'Hask.fail' is the base version on GHC 8.8+, it needs
--   a standard @MonadFail m@ instance, which this context does not request.
--   Confirm.
instance (Hask.MonadPlus m, Hask.Applicative m) => MonadFail m (->) where
  fail msg = Hask.fail msg
infixr 1 >=>, <=<
-- | Left-to-right Kleisli composition. It applies @f@, maps @g@ over the
--   wrapped result, and flattens the double layer with 'join'.
(>=>) :: ( Monad m k, Object k a, Object k b, Object k c
         , Object k (m b), Object k (m c), Object k (m (m c)))
      => a `k` m b -> b `k` m c -> a `k` m c
f >=> g = join <<< fmap g <<< f
-- | Right-to-left Kleisli composition: '>=>' with its arguments swapped.
(<=<) :: ( Monad m k, Object k a, Object k b, Object k c
         , Object k (m b), Object k (m c), Object k (m (m c)))
      => b `k` m c -> a `k` m b -> a `k` m c
f <=< g = g >=> f
-- | Kleisli morphisms of the monad @m@ over the base category @k@. A
--   morphism from @a@ to @b@ is a @k@-morphism from @a@ to @m b@.
newtype Kleisli m k a b = Kleisli
  { runKleisli :: a `k` m b
  }
-- | Kleisli arrows form a category.
--
--   * The identity is 'pure'.
--   * Composition maps the later arrow over the result of the earlier one
--     and flattens the result with 'join'.
--   * An object must be an object of @k@ bare, and also wrapped in one and
--     in two layers of @m@.
instance (Monad m k) => Category (Kleisli m k) where
  type Object (Kleisli m k) o = (Object k o, Object k (m o), Object k (m (m o)))
  id = Kleisli pure
  Kleisli g . Kleisli f = Kleisli (join <<< fmap g <<< f)
-- | Products in the Kleisli category are those of the base category @k@.
--   Every structural morphism is the corresponding @k@-morphism followed
--   by 'pure'.
instance (Monad m k, Cartesian k) => Cartesian (Kleisli m k) where
  type PairObjects (Kleisli m k) b c
      = ( ObjectPair k b c
        , ObjectPair k (m b) c, ObjectPair k b (m c), ObjectPair k (m b) (m c) )
  type UnitObject (Kleisli m k) = UnitObject k
  swap       = Kleisli (pure . swap)
  attachUnit = Kleisli (pure . attachUnit)
  detachUnit = Kleisli (pure . detachUnit)
  regroup    = Kleisli (pure . regroup)
  regroup'   = Kleisli (pure . regroup')
-- | Sums in the Kleisli category are those of the base category @k@.
--   Every structural morphism is the corresponding @k@-morphism followed
--   by 'pure'.
instance ( Monad m k, CoCartesian k
         , Object k (m (ZeroObject k)), Object k (m (m (ZeroObject k)))
         ) => CoCartesian (Kleisli m k) where
  type SumObjects (Kleisli m k) b c
      = ( ObjectSum k b c
        , ObjectSum k (m b) c, ObjectSum k b (m c), ObjectSum k (m b) (m c) )
  type ZeroObject (Kleisli m k) = ZeroObject k
  coSwap       = Kleisli (pure . coSwap)
  attachZero   = Kleisli (pure . attachZero)
  detachZero   = Kleisli (pure . detachZero)
  coRegroup    = Kleisli (pure . coRegroup)
  coRegroup'   = Kleisli (pure . coRegroup')
  maybeAsSum   = Kleisli (pure . maybeAsSum)
  maybeFromSum = Kleisli (pure . maybeFromSum)
  boolAsSum    = Kleisli (pure . boolAsSum)
  boolFromSum  = Kleisli (pure . boolFromSum)
-- | Currying for Kleisli arrows over a base category @a@ whose morphisms
--   can be applied as functions ('Function') and lifted from @(->)@ with
--   'arr'. Kleisli arrows themselves, bare and wrapped in @m@, must be
--   objects of @a@ so that a curried arrow can return one.
instance ( Monad m a, Arrow a (->), Function a ) => Curry (Kleisli m a) where
  type MorphObjects (Kleisli m a) c d
      = ( Object a (Kleisli m a c d), Object a (m (Kleisli m a c d))
        , Object a (a c (m d))
        , ObjectMorphism a c d, ObjectMorphism a c (m d), ObjectMorphism a c (m (m d)) )
  -- Curry in the base category, wrap each partially applied result in the
  -- Kleisli constructor (lifted into a with arr), and return it with pure.
  curry (Kleisli fUnc) = Kleisli $ pure . arr Kleisli . curry fUnc
  -- For an input pair (b,c): apply fCur to b, then inside m unwrap the
  -- resulting Kleisli arrow and apply it to c. The resulting nested m is
  -- flattened with join. The outer lambda is lifted into a with arr.
  uncurry (Kleisli fCur) = Kleisli . arr $
           \(b,c) -> join . fmap (arr $ ($c) . runKleisli) . fCur $ b
-- | Any morphism that the base category @k@ can lift with 'arr' becomes a
--   Kleisli arrow by returning its result with 'pure'.
instance (Monad m k, Arrow k q, Cartesian k) => EnhancedCat (Kleisli m k) q where
  arr f = Kleisli (pure . arr f)
-- | Parallel composition of Kleisli arrows. Both components are applied
--   side by side with the base category's '***', with 'pure' on an
--   untouched side. The two resulting actions are then merged with 'fzip'.
instance (Monad m k, Morphism k, Curry k) => Morphism (Kleisli m k) where
  first  (Kleisli f)      = Kleisli (fzip . (f *** pure))
  second (Kleisli f)      = Kleisli (fzip . (pure *** f))
  Kleisli f *** Kleisli g = Kleisli (fzip . (f *** g))
-- | Fan-out merges the two resulting actions with 'fzip'. The projections
--   and the terminal morphism are those of @k@ followed by 'pure'.
instance (Monad m k, PreArrow k, Curry k) => PreArrow (Kleisli m k) where
  Kleisli f &&& Kleisli g = Kleisli (fzip . (f &&& g))
  terminal = Kleisli (pure . terminal)
  fst      = Kleisli (pure . fst)
  snd      = Kleisli (pure . snd)
-- | Distributivity morphisms are inherited from @k@, each followed by 'pure'.
instance (SPDistribute k, Monad m k, PreArrow (Kleisli m k), PreArrChoice (Kleisli m k))
      => SPDistribute (Kleisli m k) where
  distribute     = Kleisli (pure . distribute)
  unDistribute   = Kleisli (pure . unDistribute)
  boolAsSwitch   = Kleisli (pure . boolAsSwitch)
  boolFromSwitch = Kleisli (pure . boolFromSwitch)
-- | Points of the Kleisli category. A global element is the base global
--   element mapped over 'pureUnit'. The unit point comes from 'kleisliUnit'.
instance (Monad m k, WellPointed k, ObjectPoint k (m (UnitObject k)))
      => WellPointed (Kleisli m k) where
  type PointObject (Kleisli m k) b = (PointObject k b, PointObject k (m b))
  globalElement x = Kleisli (fmap (globalElement x) . pureUnit)
  unit = kleisliUnit
-- | Sum-type routing for Kleisli arrows. The input is inspected with a
--   plain Haskell case, which is lifted into @k@ with 'arr'. The selected
--   component arrow is applied, and its result is re-tagged with 'coFst'
--   or 'coSnd' inside @m@. A side that has no arrow of its own is only
--   tagged and returned with 'pure'.
instance ( Monad m k, Arrow k (->), Function k, PreArrChoice k
         , Object k (m (ZeroObject k)), Object k (m (m (ZeroObject k)))
         ) => MorphChoice (Kleisli m k) where
  left (Kleisli f) = Kleisli . arr $ \case
    Right y -> (pure . coSnd) `inCategoryOf` f $ y
    Left x  -> fmap coFst . f $ x
  right (Kleisli f) = Kleisli . arr $ \case
    Left x  -> (pure . coFst) `inCategoryOf` f $ x
    Right y -> fmap coSnd . f $ y
  Kleisli f +++ Kleisli g = Kleisli . arr $ \case
    Right y -> fmap coSnd . g $ y
    Left x  -> fmap coFst . f $ x
-- | Fan-in is the base category's '|||' on the underlying arrows. The
--   injections and the initial morphism are those of @k@ followed by 'pure'.
instance ( Monad m k, Arrow k (->), Function k, PreArrChoice k
         , Object k (m (ZeroObject k)), Object k (m (m (ZeroObject k)))
         ) => PreArrChoice (Kleisli m k) where
  Kleisli f ||| Kleisli g = Kleisli (f ||| g)
  initial = Kleisli (pure . initial)
  coFst   = Kleisli (pure . coFst)
  coSnd   = Kleisli (pure . coSnd)
-- | The unit point of the Kleisli category over @a@. It is the base
--   category's own 'unit', retagged, since the Kleisli category uses the
--   same unit object as @a@.
kleisliUnit :: forall m a . (Monad m a, WellPointed a)
            => CatTagged (Kleisli m a) (UnitObject a)
kleisliUnit = retag baseUnit
  where
    baseUnit :: CatTagged a (UnitObject a)
    baseUnit = unit
-- | A morphism from 'Bool' that uses 'choose' to select between 'fmzero'
--   and @pure@.
--
--   NOTE(review): presumably 'False' maps to 'fmzero' and 'True' to
--   @pure ()@, as in the standard @guard@; confirm against 'choose'.
guard :: ( MonadPlus m k, Arrow k (->), Function k
         , UnitObject k ~ (), Object k Bool
         ) => Bool `k` m ()
guard = pinCat <<< choose fmzero pure
  where
    -- Identity morphism of k. It is kept from the original formulation,
    -- presumably to pin the category of the composite.
    pinCat = id
-- | Conditionally keep an action.
--
--   * If the flag holds, the morphism is the identity on @m u@.
--   * Otherwise the input is sent to the unit object with 'terminal' and
--     that unit is returned with 'pure'.
when :: ( Monad m k, PreArrow k, u ~ UnitObject k
        , ObjectPair k (m u) u
        ) => Bool -> m u `k` m u
when cond
  | cond      = id
  | otherwise = pure . terminal
-- | The negation of 'when'.
--
--   * If the flag is false, the morphism is the identity on @m u@.
--   * Otherwise the input is sent to the unit object with 'terminal' and
--     that unit is returned with 'pure'.
unless :: ( Monad m k, PreArrow k, u ~ UnitObject k
          , ObjectPair k (m u) u
          ) => Bool -> m u `k` m u
unless cond = case cond of
  True  -> pure . terminal
  False -> id
-- | Monadic filter over a traversable container @c@. It runs in four steps:
--
--   1. Pair each element with its predicate result, by fanning out the
--      predicate and 'pure' and merging with 'fzip'.
--   2. Traverse the container with 'mapM'.
--   3. Keep the pairs whose flag is set, using 'filter' on 'fst'.
--   4. Drop the flags with 'snd'.
filterM :: ( PreArrow k, Monad m k, SumToProduct c k k, EndoTraversable c k
           , ObjectPair k Bool a, Object k (c a), Object k (m (c a))
           , ObjectPair k (Bool, a) (c (Bool, a))
           , ObjectPair k (m Bool) (m a)
           , ObjectPair k (m (Bool, a)) (m (c (Bool, a)))
           , TraversalObject k c (Bool, a)
           ) => a `k` m Bool -> c a `k` m (c a)
filterM keep = fmap (fmap snd . filter fst) . mapM (fzip . (keep &&& pure))
-- | Repeat a monadic action indefinitely. The result is the input action
--   sequenced, through 'fmap' and 'join', with that same result.
--
--   The repeated action is tied into one shared recursive value, @again@.
--   The previous version called the loop function afresh on every
--   iteration, so the action was rebuilt each time. The @forever@ in
--   base uses the same explicit sharing, which prevents a space leak
--   regardless of optimisation level.
--
--   The result type @b@ is unconstrained because the recursion is
--   unconditional. Presumably the loop only ends through the short-circuit
--   behaviour of the monad (an exception, a failure, and so on).
forever :: ( Monad m k, Function k, Arrow k (->), Object k a, Object k b
           , Object k (m a), Object k (m (m a)), ObjectPoint k (m b), Object k (m (m b))
           ) => m a `k` m b
forever = i . arr loop
  where
    loop a =
      let again = (join . fmap (const again)) `inCategoryOf` i $ a
      in again
    -- Identity morphism of k. It fixes the category of the result and,
    -- through inCategoryOf, presumably that of the inner composite too.
    i = id
-- | Discard the values inside @m@. 'terminal' is mapped over the @m@
--   layer, which replaces each value with the unit object of @k@.
void :: ( Monad m k, PreArrow k
        , Object k a, Object k (m a), ObjectPair k a u, u ~ UnitObject k
        ) => m a `k` m (UnitObject k)
void = fmap terminal