{-# LANGUAGE CPP #-}
module Control.Subcategory.Bind (CBind(..), CMonad, creturn, (-<<)) where
import Control.Subcategory.Functor
import Control.Subcategory.Pointed

import           Control.Monad                   (join)
import qualified Control.Monad.ST.Lazy           as LST
import qualified Control.Monad.ST.Strict         as SST
import           Data.Coerce                     (coerce)
import           Data.Functor.Identity           (Identity)
import qualified Data.Functor.Product            as SOP
import           Data.Hashable                   (Hashable)
import qualified Data.HashMap.Strict             as HM
import qualified Data.HashSet                    as HS
import qualified Data.IntMap                     as IM
import qualified Data.IntSet                     as IS
import           Data.List.NonEmpty              (NonEmpty)
import qualified Data.Map                        as Map
import           Data.MonoTraversable
#if !MIN_VERSION_base(4,16,0)
import qualified Data.Semigroup                  as Sem
#endif
import qualified Data.Sequence                   as Seq
import qualified Data.Set                        as Set
import qualified Data.Tree                       as Tree
import           GHC.Conc                        (STM)
import           Text.ParserCombinators.ReadP    (ReadP)
import           Text.ParserCombinators.ReadPrec (ReadPrec)

-- | Constrained analogue of the bind of 'Monad', for functors whose element
-- types are restricted by 'Dom'.
--
-- Instances must define at least one of '>>-' or 'cjoin'. Each default is
-- written in terms of the other, so an instance defining neither would
-- diverge at run time. The @MINIMAL@ pragma turns that mistake into a
-- compile-time warning.
class CFunctor m => CBind m where
  -- | Constrained bind: feed the contents of @m a@ to the continuation.
  (>>-) :: (Dom m a, Dom m b) => m a -> (a -> m b) -> m b
  -- The default goes through 'cmap' and 'cjoin', so it additionally needs
  -- the nested container @m b@ to be an admissible element of @m@.
  default (>>-) :: (Dom m a, Dom m b, Dom m (m b)) => m a -> (a -> m b) -> m b
  m >>- f = cjoin (cmap f m)

  -- | Flatten one level of nesting. By default this is @(>>- id)@.
  cjoin :: (Dom m (m a), Dom m a) => m (m a) -> m a
  cjoin = (>>- id)

  {-# MINIMAL (>>-) | cjoin #-}

-- | Any 'Monad' wrapped in 'WrapFunctor' is a 'CBind'. Both operations
-- delegate to the underlying monad through zero-cost coercions.
instance (Monad m) => CBind (WrapFunctor m) where
  (>>-) :: forall a b.
           WrapFunctor m a
        -> (a -> WrapFunctor m b) -> WrapFunctor m b
  (>>-) = coerce underlyingBind
    where
      -- The wrapped monad's bind, at exactly the types being coerced.
      underlyingBind :: m a -> (a -> m b) -> m b
      underlyingBind = (>>=)

  cjoin :: forall a. WrapFunctor m (WrapFunctor m a) -> WrapFunctor m a
  cjoin (WrapFunctor outer) = WrapFunctor (join (coerce <$> outer))

-- | Lists: bind concatenates the per-element results in order, and 'cjoin'
-- flattens a list of lists.
instance CBind [] where
  xs >>- f = [y | x <- xs, y <- f x]
  cjoin xss = [x | xs <- xss, x <- xs]

-- | 'IO' bind delegates to the monadic '>>='.
instance CBind IO where
  action >>- k = action >>= k

-- | 'STM' bind delegates to the monadic '>>='.
instance CBind STM where
  txn >>- k = txn >>= k

-- | Strict 'SST.ST' bind delegates to the monadic '>>='.
instance CBind (SST.ST s) where
  st >>- k = st >>= k

-- | Lazy 'LST.ST' bind delegates to the monadic '>>='.
instance CBind (LST.ST s) where
  st >>- k = st >>= k

-- | 'Identity' bind delegates to the monadic '>>='.
instance CBind Identity where
  ident >>- k = ident >>= k

-- | 'Either' bind delegates to the monadic '>>=', so a 'Left' short-circuits.
instance CBind (Either a) where
  e >>- k = e >>= k

-- | Rose-tree bind delegates to the monadic '>>=' of 'Tree.Tree'.
instance CBind Tree.Tree where
  tree >>- k = tree >>= k

-- | 'Maybe' bind delegates to the monadic '>>=', so 'Nothing' short-circuits.
instance CBind Maybe where
  mx >>- k = mx >>= k

-- | 'IM.IntMap' bind: key @k@ survives only if the map produced from its
-- value also contains @k@, and the result holds that inner value.
instance CBind IM.IntMap where
  m >>- f = IM.mapMaybeWithKey pickSameKey m
    where
      pickSameKey k v = IM.lookup k (f v)

-- | 'Map.Map' bind: key @k@ survives only if the map produced from its value
-- also contains @k@, and the result holds that inner value.
instance Ord k => CBind (Map.Map k) where
  m >>- f = Map.mapMaybeWithKey (\k v -> Map.lookup k (f v)) m

-- | 'HM.HashMap' bind: key @k@ survives only if the map produced from its
-- value also contains @k@, and the result holds that inner value.
instance (Hashable k, Eq k) => CBind (HM.HashMap k) where
  m >>- f = HM.mapMaybeWithKey sameKeyIn m
    where
      sameKeyIn k v = HM.lookup k (f v)

-- | 'Set.Set' bind combines, via the 'Set.Set' 'Monoid', the sets produced
-- for every element.
instance CBind Set.Set where
  s >>- f = foldMap f s
  {-# INLINE (>>-) #-}
  cjoin sets = sets >>- id
  {-# INLINE cjoin #-}

-- | The monomorphic 'IS.IntSet', viewed through 'WrapMono'. Bind folds over
-- the set's elements and combines the per-element results with 'ofoldMap'.
instance CBind (WrapMono IS.IntSet) where
  (>>-) = withMonoCoercible (\set k -> ofoldMap k set)
  {-# INLINE (>>-) #-}

-- | 'NonEmpty' bind delegates to the monadic '>>='.
instance CBind NonEmpty where
  ne >>- k = ne >>= k
  {-# INLINE (>>-) #-}

-- | 'Seq.Seq' bind delegates to the monadic '>>='.
instance CBind Seq.Seq where
  sq >>- k = sq >>= k
  {-# INLINE (>>-) #-}
#if !MIN_VERSION_base(4,16,0)
-- | 'Sem.Option' bind delegates to the monadic '>>='.
-- Only compiled against base older than 4.16.
instance CBind Sem.Option where
  opt >>- k = opt >>= k
  {-# INLINE (>>-) #-}
#endif

-- | Functions from a fixed argument type: bind delegates to the function
-- monad's '>>='.
instance CBind ((->) a) where
  g >>- k = g >>= k
  {-# INLINE (>>-) #-}

-- | 'HS.HashSet' bind combines, via the 'HS.HashSet' 'Monoid', the sets
-- produced for every element.
instance CBind HS.HashSet where
  hs >>- f = foldMap f hs
  {-# INLINE (>>-) #-}
  cjoin nested = nested >>- id
  {-# INLINE cjoin #-}

-- | 'ReadP' bind delegates to the parser's monadic '>>='.
instance CBind ReadP where
  p >>- k = p >>= k
  {-# INLINE (>>-) #-}

-- | 'ReadPrec' bind delegates to the parser's monadic '>>='.
instance CBind ReadPrec where
  p >>- k = p >>= k
  {-# INLINE (>>-) #-}

-- | Writer-style pairs: bind combines the two annotations with '<>', outer
-- annotation first, and keeps the continuation's payload.
instance Semigroup w => CBind ((,) w) where
  (m, a) >>- f = (m <> fst result, snd result)
    where
      -- A lazy binding: @f a@ is forced only when a component is demanded.
      result = f a
  {-# INLINE (>>-) #-}
  cjoin (outer, (inner, x)) = (outer <> inner, x)
  {-# INLINE cjoin #-}

infixl 1 >>-
infixr 1 -<<

-- | '>>-' with its arguments flipped, mirroring '=<<'.
(-<<) :: (Dom m b, Dom m a, CBind m) => (a -> m b) -> m a -> m b
f -<< m = m >>- f
{-# INLINE (-<<) #-}

-- | Products bind componentwise. Each side is bound with the matching
-- component of the continuation's result.
instance (CBind m, CBind n) => CBind (SOP.Product m n) where
  (SOP.Pair l r) >>- f =
    SOP.Pair
      (l >>- \x -> case f x of SOP.Pair left _ -> left)
      (r >>- \x -> case f x of SOP.Pair _ right -> right)
  {-# INLINE (>>-) #-}

-- | Constraint synonym for types that are both 'CBind' and 'CPointed'. It has
-- no methods, and the catch-all instance makes every such type a member.
class    (CBind m, CPointed m) => CMonad m
instance (CBind m, CPointed m) => CMonad m

-- | Inject a value into a 'CMonad'. This is 'cpure', named after 'return'.
creturn :: (Dom m a, CMonad m) => a -> m a
creturn x = cpure x
{-# INLINE creturn #-}