{-# LANGUAGE ExplicitForAll         #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE UndecidableInstances   #-}
{-# LANGUAGE AllowAmbiguousTypes    #-}
{-# LANGUAGE TypeInType             #-}
{-# LANGUAGE RankNTypes             #-}
{-# LANGUAGE BangPatterns #-}

{-# EXT      InlineAll              #-}
-- {-# LANGUAGE Strict #-} -- https://ghc.haskell.org/trac/ghc/ticket/14815#ticket

module Control.Monad.State.Layered where

import Prelude hiding (Monad) -- hiding ((.)) -- https://ghc.haskell.org/trac/ghc/ticket/14001

import Control.Applicative
import Control.Lens.Utils hiding (Getter, Setter, Setter')
import Control.Monad.Branch
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Identity hiding (Monad)
import Control.Monad.IO.Class
import Control.Monad.Primitive
import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Constraint
import Data.Default
import Data.Kind (Type)
import Type.Bool

import qualified Prelude as P
import qualified Control.Monad.State.Strict as S

type M = P.Monad


-------------------
-- === State === --
-------------------

-- === Definition === --

type    State  s     = StateT s Identity
newtype StateT s m a = StateT (S.StateT s m a)
    deriving ( Applicative, Alternative, Functor, M, MonadFail, MonadFix
             , MonadIO, MonadPlus, MonadTrans, MonadThrow, MonadBranch )
makeWrapped ''StateT

type        States  ss = StatesT ss Identity
type family StatesT ss m where
    StatesT '[]       m = m
    StatesT (s ': ss) m = StateT s (StatesT ss m)


-- === State data discovery === --

type        InferStateData    (l :: k) m = InferStateData' k l m
type family InferStateData' k (l :: k) m where
    InferStateData' Type l m = l
    InferStateData' k    l m = StateData l m

type TransStateData l t m = (InferStateData l (t m) ~ InferStateData l m)

type family StateData l m where
    StateData l (StateT s m) = If (MatchedBases l s) s (StateData l m)
    StateData l (t m)        = StateData l m

type family MatchedBases (a :: ka) (b :: kb) :: Bool where
    MatchedBases (a :: k) (b   :: k) = a == b
    MatchedBases (a :: k) (b t :: l) = MatchedBases a b
    MatchedBases (a :: k) (b   :: l) = 'False


-- === Monad === --

-- Definitions

type         Monad  l m = (Getter l m, Setter l m)
class M m => Getter l m where get :: m (InferStateData l m)
class M m => Setter l m where put :: InferStateData l m -> m ()

-- Instancess

instance                       M m                                                      => Getter (l :: Type) (StateT l m) where get   = wrap   S.get    ; {-# INLINE get #-}
instance                       M m                                                      => Setter (l :: Type) (StateT l m) where put a = wrap $ S.put a  ; {-# INLINE put #-}
instance {-# OVERLAPPABLE #-}  Getter l m                                               => Getter (l :: Type) (StateT s m) where get   = lift $ get @l   ; {-# INLINE get #-}
instance {-# OVERLAPPABLE #-}  Setter l m                                               => Setter (l :: Type) (StateT s m) where put a = lift $ put @l a ; {-# INLINE put #-}
instance {-# OVERLAPPABLE #-} (M m, Getter__ ok l (StateT s m), ok ~ MatchedBases l s)  => Getter (l :: k)    (StateT s m) where get   = get__  @ok @l   ; {-# INLINE get #-}
instance {-# OVERLAPPABLE #-} (M m, Setter__ ok l (StateT s m), ok ~ MatchedBases l s)  => Setter (l :: k)    (StateT s m) where put a = put__  @ok @l a ; {-# INLINE put #-}
instance {-# OVERLAPPABLE #-} (M (t m), MonadTrans t, Getter l m, TransStateData l t m) => Getter (l :: k)    (t m)        where get   = lift $ get @l   ; {-# INLINE get #-}
instance {-# OVERLAPPABLE #-} (M (t m), MonadTrans t, Setter l m, TransStateData l t m) => Setter (l :: k)    (t m)        where put a = lift $ put @l a ; {-# INLINE put #-}

-- Helpers

class M m => Getter__ (ok :: Bool) l m where get__ :: m (InferStateData l m)
class M m => Setter__ (ok :: Bool) l m where put__ :: InferStateData l m -> m ()

instance (M m, InferStateData l (StateT s m) ~ s)    => Getter__ 'True  l (StateT s m) where get__   = get @s          ; {-# INLINE get__ #-}
instance (M m, InferStateData l (StateT s m) ~ s)    => Setter__ 'True  l (StateT s m) where put__ a = put @s a        ; {-# INLINE put__ #-}
instance (Getter l m, TransStateData l (StateT s) m) => Getter__ 'False l (StateT s m) where get__   = lift $ get @l   ; {-# INLINE get__ #-}
instance (Setter l m, TransStateData l (StateT s) m) => Setter__ 'False l (StateT s m) where put__ a = lift $ put @l a ; {-# INLINE put__ #-}

-- Replicators

type MonadStates  ss m = (Getters ss m, Setters ss m)
type Getters ss m = Monads__ Getter ss m
type Setters ss m = Monads__ Setter ss m
type family Monads__ p ss m :: Constraint where
    Monads__ p (s ': ss) m = (p s m, Monads__ p ss m)
    Monads__ p '[]       m = ()


-- === Accessing === --

gets ::  l m s a. (Getter l m, s ~ InferStateData l m)
     => Lens' s a -> m a
gets l = view l <$> get @l ; {-# INLINE gets #-}



-- === Top state accessing === --

type family TopStateData m where
    TopStateData (StateT s m) = s
    TopStateData (t m)        = TopStateData m

type Monad'  m = (Getter' m, Setter' m)
type Getter' m = Getter (TopStateData m) m
type Setter' m = Setter (TopStateData m) m

get' ::  m. Getter' m => m (TopStateData m)
put' ::  m. Setter' m => TopStateData m -> m ()
get' = get @(TopStateData m) ; {-# INLINE get' #-}
put' = put @(TopStateData m) ; {-# INLINE put' #-}

gets' ::  m s a. (Getter' m, s ~ TopStateData m) => Lens' s a -> m a
gets' l = view l <$> get' ; {-# INLINE gets' #-}


-- === Construction & running === --

stateT :: (s -> m (a,s)) -> StateT s m a
stateT = StateT . S.StateT ; {-# INLINE stateT #-}

runT  ::  s m a.              StateT s m a -> s -> m (a, s)
evalT ::  s m a. S.Monad m => StateT s m a -> s -> m a
execT ::  s m a. S.Monad m => StateT s m a -> s -> m s
runT  = S.runStateT . unwrap ; {-# INLINE runT  #-}
evalT !m !s = let fst (!a, !_) = a in fst <$> runT m s ; {-# INLINE evalT #-}
execT !m !s = let snd (!_, !a) = a in snd <$> runT m s ; {-# INLINE execT #-}

runDefT  ::  s m a.            Default s  => StateT s m a -> m (a, s)
evalDefT ::  s m a. (S.Monad m,Default s) => StateT s m a -> m a
execDefT ::  s m a. (S.Monad m,Default s) => StateT s m a -> m s
runDefT  !s = runT  s def ; {-# INLINE runDefT  #-}
evalDefT !s = evalT s def ; {-# INLINE evalDefT #-}
execDefT !s = execT s def ; {-# INLINE execDefT #-}

run  ::  s a. State s a -> s -> (a, s)
eval ::  s a. State s a -> s -> a
exec ::  s a. State s a -> s -> s
run  = S.runState  . unwrap ; {-# INLINE run  #-}
eval = S.evalState . unwrap ; {-# INLINE eval #-}
exec = S.execState . unwrap ; {-# INLINE exec #-}

runDef  ::  s a. Default s => State s a -> (a, s)
evalDef ::  s a. Default s => State s a -> a
execDef ::  s a. Default s => State s a -> s
runDef  = flip run  def ; {-# INLINE runDef  #-}
evalDef = flip eval def ; {-# INLINE evalDef #-}
execDef = flip exec def ; {-# INLINE execDef #-}


-- === Generic state modification === --

type InferMonadState s l m = (Monad l m, s ~ InferStateData l m)
modifyM  ::  l s m a. InferMonadState s l m => (s -> m (a, s)) -> m a
modifyM_ ::  l s m a. InferMonadState s l m => (s -> m     s)  -> m ()
modify   ::  l s m a. InferMonadState s l m => (s ->   (a, s)) -> m a
modify_  ::  l s m a. InferMonadState s l m => (s ->       s)  -> m ()
modify    = modifyM  @l . fmap return       ; {-# INLINE modify   #-}
modify_   = modifyM_ @l . fmap return       ; {-# INLINE modify_  #-}
modifyM_  = modifyM  @l . (fmap.fmap) ((),) ; {-# INLINE modifyM_ #-}
modifyM f = do (!a,!t) <- f =<< get @l
               a <$ put @l t
{-# INLINE modifyM #-}

sub           ::  l s m a. InferMonadState s l m =>               m a -> m a
with          ::  l s m a. InferMonadState s l m => s          -> m a -> m a
withModified  ::  l s m a. InferMonadState s l m => (s ->   s) -> m a -> m a
withModifiedM ::  l s m a. InferMonadState s l m => (s -> m s) -> m a -> m a
with              = withModified  @l . const       ; {-# INLINE with          #-}
withModified      = withModifiedM @l . fmap return ; {-# INLINE withModified  #-}
withModifiedM f m = sub @l $ modifyM_ @l f >> m    ; {-# INLINE withModifiedM #-}
sub             m = do s <- get @l
                       m <* put @l s
{-#INLINE sub #-}


-- === Top level state modification === --

type TopMonadState s m = (Monad' m, s ~ TopStateData m)
modifyM'  ::  s m a. TopMonadState s m => (s -> m (a, s)) -> m a
modifyM'_ ::  s m a. TopMonadState s m => (s -> m     s)  -> m ()
modify'   ::  s m a. TopMonadState s m => (s ->   (a, s)) -> m a
modify'_  ::  s m a. TopMonadState s m => (s ->       s)  -> m ()
modify'    = modifyM'  . fmap return       ; {-# INLINE modify'   #-}
modify'_   = modifyM'_ . fmap return       ; {-# INLINE modify'_  #-}
modifyM'_  = modifyM'  . (fmap.fmap) ((),) ; {-# INLINE modifyM'_ #-}
modifyM' f = do (!a,!t) <- f =<< get'
                a <$ put' t
{-# INLINE modifyM' #-}

sub'           ::  s m a. TopMonadState s m =>               m a -> m a
with'          ::  s m a. TopMonadState s m => s          -> m a -> m a
withModified'  ::  s m a. TopMonadState s m => (s ->   s) -> m a -> m a
withModifiedM' ::  s m a. TopMonadState s m => (s -> m s) -> m a -> m a
with'              = withModified'  . const       ; {-# INLINE with'          #-}
withModified'      = withModifiedM' . fmap return ; {-# INLINE withModified'  #-}
withModifiedM' f m = sub' $ modifyM'_ f >> m      ; {-# INLINE withModifiedM' #-}
sub'             m = do s <- get'
                        m <* put' s
{-# INLINE sub' #-}


-- === Other modifications === --

mapT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapT f = _Wrapped %~ S.mapStateT f ; {-# INLINE mapT #-}


-- === Instances === --

instance PrimMonad m => PrimMonad (StateT s m) where
  type PrimState (StateT s m) = PrimState m
  primitive = lift . primitive ; {-# INLINE primitive #-}