{-# LANGUAGE UndecidableInstances #-}
module Algebra.Monad.Base (
  module Algebra.Classes,module Algebra.Applicative,module Algebra.Core,
  module Algebra.Traversable,module Algebra.Lens,
  
  -- * Monad utilities
  Kleisli(..),i'Kleisli,
  (=<<),joinMap,(<=<),(>=>),(>>),(<*=),only,return,
  foldlM,foldrM,findM,while,until,
  bind2,bind3,(>>>=),(>>>>=),

  -- * Monadic Lenses
  Action,Action',
  
  -- * Instance utilities
  Compose'(..),i'Compose'
  ) where

import Algebra.Classes
import Algebra.Applicative
import Algebra.Core hiding (flip)
import Algebra.Traversable
import Algebra.Lens
import qualified Control.Monad.Fix as Fix

type Action s t a b = forall m. Monad m => LensLike m s t a b
type Action' a b = Action b b a a

instance MonadIO IO where liftIO = id
instance (MonadIO m,MonadTrans t,Monad (t m)) => MonadIO (t m) where
  liftIO = lift . liftIO

-- MonadFix instances
instance MonadFix Id where mfix = cfix
instance MonadFix ((->) b) where mfix = cfix
instance MonadFix [] where mfix f = fix (f . head)
instance MonadFix (Either e) where mfix f = fix (f . either undefined id)
instance MonadFix IO where mfix = Fix.mfix

instance (Traversable g,Monad f,Monad g) => Monad (f:.:g) where
  join = Compose .map join.join.map sequence.getCompose.map getCompose
instance (MonadFix f,Traversable g,Monad g) => MonadFix (f:.:g) where
  mfix f = Compose $ mfix (map join . traverse (getCompose . f))
instance Monad m => MonadTrans ((:.:) m) where
  lift = Compose . pure
instance Monad m => ConcreteMonad ((:.:) m) where
  generalize = i'Compose %%~ map (pure.yb i'Id)

instance MonadFix m => Monad (Backwards m) where
  join (Backwards ma) = Backwards$mfixing (\a -> liftA2 (,) (forwards a) ma)
instance MonadFuture m (Backwards m) where
  future = Backwards
instance MonadFix m => MonadFix (Backwards m) where
  mfix f = by i'Backwards $ mfix (yb i'Backwards.f)
instance MonadTrans Backwards where
  lift = Backwards
instance ConcreteMonad Backwards where
  generalize = i'Backwards %%~ pure.yb i'Id

newtype Kleisli m a b = Kleisli { runKleisli :: a -> m b }
instance Functor f => Functor (Kleisli f a) where
  map f (Kleisli k) = Kleisli (map2 f k)
instance Contravariant f => Contravariant (Kleisli f a) where
  collect f = Kleisli (\a -> project (($a) . runKleisli) f)
instance Monad m => Category (Kleisli m) where
  id = Kleisli pure
  Kleisli f . Kleisli g = Kleisli (\a -> g a >>= f)
instance Monad m => Choice (Kleisli m) where
  Kleisli f <|> Kleisli g = Kleisli (f <|> g)
instance Monad m => Split (Kleisli m) where
  Kleisli f <#> Kleisli g = Kleisli (\(a,c) -> (,)<$>f a<*>g c)
instance Isomorphic (a -> m b) (c -> m' d) (Kleisli m a b) (Kleisli m' c d) where
  i'_ = iso Kleisli runKleisli

cfix :: Contravariant c => (a -> c a) -> c a
cfix = map fix . collect

mfixing :: MonadFix f => (b -> f (a, b)) -> f a
mfixing f = fst<$>mfix (\ ~(_,b) -> f b )

i'Kleisli :: Iso (Kleisli m a b) (Kleisli m' c d) (a -> m b) (c -> m' d)
i'Kleisli = i'_ 

folding :: (Foldable t,Monoid w) => Iso' (a -> c) w -> (b -> a -> c) -> a -> t b -> c  
folding i f e t = yb i (foldMap (by i . f) t) e
foldlM :: (Foldable t,Monad m) => (a -> b -> m a) -> a -> t b -> m a
foldlM = folding (i'Kleisli.i'Endo) . flip
foldrM :: (Foldable t,Monad m) => (b -> a -> m a) -> t b -> a -> m a
foldrM = flip . folding (i'Kleisli.i'Endo.i'Dual)
findM :: (Foldable t,Monad m) => (a -> m (Maybe b)) -> t a -> m (Maybe b)
findM f = foldr fun (return Nothing)
  where fun a b = maybe b (return . Just) =<< f a

while :: Monad m => m Bool -> m ()
while e = fix (\w -> e >>= bool w unit)
until :: Monad m => m (Maybe a) -> m a
until e = fix (\w -> e >>= maybe w return)

bind2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c
bind2 f a b = join (f<$>a<*>b)
(>>>=) :: Monad m => (m a,m b) -> (a -> b -> m c) -> m c
(a,b) >>>= f = bind2 f a b
bind3 :: Monad m => (a -> b -> c -> m d) -> m a -> m b -> m c -> m d
bind3 f a b c = join (f<$>a<*>b<*>c)
(>>>>=) :: Monad m => (m a,m b,m c) -> (a -> b -> c -> m d) -> m d
(a,b,c) >>>>= f = bind3 f a b c

infixr 2 =<<
infixl 1 <*=,>>
(>>) :: Applicative f => f a -> f b -> f b
(>>) = (*>)
(=<<) :: Monad m => (a -> m b) -> m a -> m b
(=<<) = flip (>>=)
(<=<) :: Monad m => (b -> m c) -> (a -> m b) -> (a -> m c)
f <=< g = \a -> g a >>= f
(>=>) :: Monad m => (a -> m b) -> (b -> m c) -> (a -> m c)
(>=>) = flip (<=<)
(<*=) :: Monad m => m a -> (a -> m b) -> m a
a <*= f = a >>= ((>>)<$>f<*>return)
only :: (Monoid (m ()),Monad m) => (a -> Bool) -> m a -> m a
only p m = m <*= guard . p
return :: Unit f => a -> f a
return = pure

joinMap :: Monad m => (a -> m b) -> m a -> m b
joinMap = (=<<)

newtype Compose' f g a = Compose' ((g:.:f) a)
                       deriving (Semigroup,Monoid,Unit,Functor,Applicative,Monad,MonadFix,Foldable,Traversable)
i'Compose' :: Iso (Compose' f g a) (Compose' h i b) (g (f a)) (i (h b))
i'Compose' = i'Compose.iso Compose' (\(Compose' c) -> c)
instance Monad m => MonadTrans (Compose' m) where
  lift = by i'Compose' . map pure
instance Monad m => ConcreteMonad (Compose' m) where
  generalize = i'Compose' %%~ pure . yb i'Id