{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}

{-|
Module     : Control.Monad.Trans.Choice.Covariant
Copyright  : (c) Eamon Olive, 2020
             (c) Louis Hyde,  2020
License    : AGPL-3
Maintainer : ejolive97@gmail.com
Stability  : experimental

-}
module Control.Monad.Trans.Choice.Covariant
  ( ChoiceT ( )
  , runChoiceT
  , mapChoiceT
  , runBacktrackableChoiceT
  ) where

-- Internal imports

import Control.Monad.Class.Choice
  ( MonadChoice (choose)
  )

-- External imports

import Control.Arrow
  ( second
  )
import Control.Monad
  ( (>=>)
  )
import Control.Monad.IO.Class
  ( MonadIO
    ( liftIO
    )
  )
import Control.Monad.Reader.Class
  ( MonadReader
    ( ask
    , local
    )
  )
import Control.Monad.RWS.Class
  ( MonadRWS
  )
import Control.Monad.State.Class
  ( MonadState
    ( state
    )
  )
import Control.Monad.Trans
  ( MonadTrans
    ( lift
    )
  )
import Control.Monad.Writer.Class
  ( MonadWriter
    ( writer
    , listen
    , pass
    )
  )

-- | A computation over a base monad @m@ that can additionally request
-- choices described by the functor-like type @f@. The structure is a
-- tree of pending steps; how choices are resolved is decided only when
-- it is interpreted (see 'runChoiceT' and 'runBacktrackableChoiceT').
data ChoiceT f m a where
  -- | A finished computation carrying its result.
  FixedT :: a -> ChoiceT f m a
  -- | Run an action in the base monad, then continue with its result.
  LiftThenT :: m b -> (b -> ChoiceT f m a) -> ChoiceT f m a
  -- | Request a selection from the given options, then continue with
  -- whatever value was selected.
  ChooseThenT :: f b -> (b -> ChoiceT f m a) -> ChoiceT f m a

-- | Mapping rewrites the final result only: the function is pushed through
-- every continuation until it reaches a 'FixedT' leaf. Lifted actions and
-- pending choices are left untouched.
instance Functor (ChoiceT f m) where
  fmap g choice =
    case choice of
      FixedT result -> FixedT (g result)
      LiftThenT act cont -> LiftThenT act (\r -> fmap g (cont r))
      ChooseThenT opts cont -> ChooseThenT opts (\r -> fmap g (cont r))

-- | Effects of the left operand are sequenced before those of the right
-- operand.
instance Applicative (ChoiceT f m) where
  -- Equivalent to @LiftThenT (pure value) FixedT@, without the extra node.
  pure = FixedT

  -- A finished function is simply mapped over the argument computation.
  FixedT g <*> argChoice = g <$> argChoice
  -- Shortcut: when the argument is already finished, feed it to whatever
  -- function the left computation eventually produces.
  funChoice <*> FixedT arg = ($ arg) <$> funChoice
  -- The remaining cases keep the head step and defer '<*>' into the
  -- continuation. Justification, reading @LiftThenT act k@ as
  -- @lift act >>= k@ and using associativity of '>>=':
  --
  -- > (a >>= k) >>= (<$> c) = a >>= (\x -> k x >>= (<$> c))
  -- > (a >>= k) <*> c       = a >>= (\x -> k x <*> c)
  LiftThenT act cont <*> argChoice =
    LiftThenT act (\r -> cont r <*> argChoice)
  ChooseThenT opts cont <*> argChoice =
    ChooseThenT opts (\r -> cont r <*> argChoice)

-- | Binding substitutes the given function at the 'FixedT' leaves: for the
-- two step constructors the head step is kept and the bind is composed
-- onto its continuation.
instance Monad (ChoiceT f m) where
  choice >>= k =
    case choice of
      FixedT result -> k result
      LiftThenT act cont -> LiftThenT act (cont >=> k)
      ChooseThenT opts cont -> ChooseThenT opts (cont >=> k)

-- | A choice is recorded as a single 'ChooseThenT' node whose continuation
-- finishes immediately with the selected value.
instance MonadChoice f (ChoiceT f m) where
  choose candidates = ChooseThenT candidates FixedT

-- | A base-monad action is recorded as a single 'LiftThenT' node whose
-- continuation finishes immediately with the action's result.
instance MonadTrans (ChoiceT f) where
  lift act = LiftThenT act FixedT

-- | Reader operations are delegated to the base monad.
instance MonadReader r m => MonadReader r (ChoiceT f m) where
  ask = lift ask
  -- The environment change is applied to every lifted action in the
  -- structure by way of 'mapChoiceT'.
  local adjust choice = mapChoiceT (local adjust) choice

-- | State operations are delegated to the base monad as a lifted action.
instance MonadState s m => MonadState s (ChoiceT f m) where
  state step = lift (state step)

-- | Writer operations are delegated to the base monad, one lifted action at
-- a time.
instance MonadWriter w m => MonadWriter w (ChoiceT f m) where
  writer = lift . writer

  -- Each lifted action is run under the base monad's 'listen' (so its
  -- output is still emitted as usual) and the collected output is prepended
  -- to whatever the rest of the computation reports.
  listen (FixedT value) = FixedT (value, mempty)
  listen (LiftThenT action next) =
    LiftThenT (listen action) $ \(result, output) ->
      fmap (second $ mappend output) $ listen $ next result
  listen (ChooseThenT options next) = ChooseThenT options $ listen . next

  -- The output of every lifted action is captured and withheld from the
  -- base monad; once the final @(value, f)@ pair is reached, the
  -- accumulated output is transformed with @f@ and written in one go.
  pass = go mempty
    where
      go :: w -> ChoiceT f m (a, w -> w) -> ChoiceT f m a
      go acc (FixedT (value, f)) = writer (value, f acc)
      go acc (LiftThenT action next) =
        LiftThenT (captured action) $ \(result, output) ->
          go (mappend acc output) $ next result
      go acc (ChooseThenT options next) = ChooseThenT options $ go acc . next

      -- Run an action, returning its output alongside its result while
      -- suppressing that output in the base monad. A bare 'listen' would
      -- still emit the raw output, so it would show up a second time next
      -- to the transformed output written at the end.
      captured :: m b -> m (b, w)
      captured action =
        pass $ fmap (\observed -> (observed, const mempty)) $ listen action

-- | No methods are defined here; the instance is declared without a body
-- and only requires the base monad to be a 'MonadRWS' itself.
instance MonadRWS r w s m => MonadRWS  r w s (ChoiceT f m)

-- | 'IO' actions are first lifted into the base monad and then recorded as
-- a lifted step.
instance MonadIO m => MonadIO (ChoiceT f m) where
  liftIO io = lift (liftIO io)

-- | Change the base monad by applying the given transformation to every
-- lifted action in the structure. Results and pending choices are kept
-- as they are.
mapChoiceT :: (forall x. m x -> n x) -> ChoiceT f m a -> ChoiceT f n a
mapChoiceT transform choice =
  case choice of
    FixedT result -> FixedT result
    LiftThenT act cont ->
      LiftThenT (transform act) (\r -> mapChoiceT transform (cont r))
    ChooseThenT opts cont ->
      ChooseThenT opts (\r -> mapChoiceT transform (cont r))

-- | Interpret a 'ChoiceT' in its base monad. Lifted actions are run
-- directly, and each pending choice is resolved by the supplied selection
-- function before continuing with the selected value.
runChoiceT :: Monad m => (forall x. f x -> m x) -> ChoiceT f m a -> m a
runChoiceT chooser choice =
  case choice of
    FixedT result -> pure result
    LiftThenT act cont -> do
      outcome <- act
      runChoiceT chooser (cont outcome)
    ChooseThenT opts cont -> do
      selected <- chooser opts
      runChoiceT chooser (cont selected)

-- | A variant of 'runChoiceT' in which the selection function may answer
-- 'Nothing' to signal that it wants to go back and possibly select a
-- different option for the previous choice.
--
-- When the remainder of the computation after a choice reports 'Nothing',
-- the selection function is asked again about that same choice.
runBacktrackableChoiceT
  :: forall f m a. Monad m
  => (forall x. f x -> m (Maybe x))
  -- ^ The selection function. If it yields 'Nothing',
  -- 'runBacktrackableChoiceT' backtracks to the previous choice. If it
  -- yields a 'Just', the run proceeds with that value as in 'runChoiceT'.
  -> ChoiceT f m a
  -- ^ The choice structure to run.
  -> m (Maybe a)
  -- ^ The resulting computation. If the selection function backtracks on
  -- the first choice, the result of the computation is 'Nothing'.
runBacktrackableChoiceT chooser choice =
  case choice of
    FixedT firstValue -> return (Just firstValue)
    LiftThenT action next -> do
      outcome <- action
      runBacktrackableChoiceT chooser (next outcome)
    ChooseThenT options next ->
      chooser options >>= maybe (return Nothing) (attempt . next)
  where
    -- Run the branch picked for the current choice; if it backtracks,
    -- start over from the current choice so the selection function is
    -- consulted again.
    attempt :: ChoiceT f m a -> m (Maybe a)
    attempt branch =
      runBacktrackableChoiceT chooser branch
        >>= maybe (runBacktrackableChoiceT chooser choice) (return . Just)