module Control.Subcategory.Bind (CBind(..), CMonad, creturn, (-<<)) where import Control.Subcategory.Functor import Control.Subcategory.Pointed import Control.Monad (join) import qualified Control.Monad.ST.Lazy as LST import qualified Control.Monad.ST.Strict as SST import Data.Coerce (coerce) import Data.Functor.Identity (Identity) import qualified Data.Functor.Product as SOP import Data.Hashable (Hashable) import qualified Data.HashMap.Strict as HM import qualified Data.HashSet as HS import qualified Data.IntMap as IM import qualified Data.IntSet as IS import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as Map import Data.MonoTraversable import qualified Data.Semigroup as Sem import qualified Data.Sequence as Seq import qualified Data.Set as Set import qualified Data.Tree as Tree import GHC.Conc (STM) import Text.ParserCombinators.ReadP (ReadP) import Text.ParserCombinators.ReadPrec (ReadPrec) class CFunctor m => CBind m where (>>-) :: (Dom m a, Dom m b) => m a -> (a -> m b) -> m b default (>>-) :: (Dom m a, Dom m b, Dom m (m b)) => m a -> (a -> m b) -> m b m >>- f = cjoin (cmap f m) cjoin :: (Dom m (m a), Dom m a) => m (m a) -> m a cjoin = (>>- id) instance (Monad m) => CBind (WrapFunctor m) where (>>-) :: forall a b. WrapFunctor m a -> (a -> WrapFunctor m b) -> WrapFunctor m b (>>-) = coerce @(m a -> (a -> m b) -> m b) (>>=) cjoin :: forall a. WrapFunctor m (WrapFunctor m a) -> WrapFunctor m a cjoin (WrapFunctor m) = WrapFunctor $ join (fmap coerce m) instance CBind [] where (>>-) = (>>=) cjoin = concat instance CBind IO where (>>-) = (>>=) instance CBind STM where (>>-) = (>>=) instance CBind (SST.ST s) where (>>-) = (>>=) instance CBind (LST.ST s) where (>>-) = (>>=) instance CBind Identity where (>>-) = (>>=) instance CBind (Either a) where (>>-) = (>>=) instance CBind Tree.Tree where (>>-) = (>>=) instance CBind Maybe where (>>-) = (>>=) instance CBind IM.IntMap where m >>- f = IM.mapMaybeWithKey (\k -> IM.lookup k . f) m instance Ord k => CBind (Map.Map k) where m >>- f = Map.mapMaybeWithKey (\k -> Map.lookup k . f) m instance (Hashable k, Eq k) => CBind (HM.HashMap k) where m >>- f = HM.mapMaybeWithKey (\k -> HM.lookup k . f) m instance CBind Set.Set where (>>-) = flip foldMap {-# INLINE (>>-) #-} cjoin = foldMap id {-# INLINE cjoin #-} instance CBind (WrapMono IS.IntSet) where (>>-) = withMonoCoercible $ flip ofoldMap {-# INLINE (>>-) #-} instance CBind NonEmpty where (>>-) = (>>=) {-# INLINE (>>-) #-} instance CBind Seq.Seq where (>>-) = (>>=) {-# INLINE (>>-) #-} instance CBind Sem.Option where (>>-) = (>>=) {-# INLINE (>>-) #-} instance CBind ((->) a) where (>>-) = (>>=) {-# INLINE (>>-) #-} instance CBind HS.HashSet where (>>-) = flip foldMap {-# INLINE (>>-) #-} cjoin = foldMap id {-# INLINE cjoin #-} instance CBind ReadP where (>>-) = (>>=) {-# INLINE (>>-) #-} instance CBind ReadPrec where (>>-) = (>>=) {-# INLINE (>>-) #-} instance Semigroup w => CBind ((,) w) where (m, a) >>- f = let (w, b) = f a in (m <> w, b) {-# INLINE (>>-) #-} cjoin (w, (m, a)) = (w <> m, a) {-# INLINE cjoin #-} infixl 1 >>- infixr 1 -<< (-<<) :: (Dom m b, Dom m a, CBind m) => (a -> m b) -> m a -> m b (-<<) = flip (>>-) {-# INLINE (-<<) #-} instance (CBind m, CBind n) => CBind (SOP.Product m n) where (SOP.Pair a b) >>- f = SOP.Pair (a >>- fstP . f) (b >>- sndP . f) where fstP (SOP.Pair x _) = x sndP (SOP.Pair _ y) = y {-# INLINE (>>-) #-} class (CBind f, CPointed f) => CMonad f instance (CBind f, CPointed f) => CMonad f creturn :: (Dom m a, CMonad m) => a -> m a creturn = cpure {-# INLINE creturn #-}