{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE PatternSynonyms            #-}
{-# LANGUAGE Strict                     #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE ViewPatterns               #-}

-- | This module provides a monad transformer similar to
-- 'Control.Monad.Writer.Strict.WriterT', implemented using 'StateT', making it
-- tail recursive. (The traditional writer always leaks space: see
-- <https://mail.haskell.org/pipermail/libraries/2013-March/019528.html here>
-- for more information).
-- <https://downloads.haskell.org/~ghc/latest/docs/html/users_guide/glasgow_exts.html#pattern-synonyms Pattern Synonyms>
-- are used to provide the same interface as
-- 'Control.Monad.Writer.Strict.WriterT'. Unfortunately, current GHC warns
-- whenever these patterns are used that there are unmatched patterns: the
-- <https://ghc.haskell.org/trac/ghc/ticket/8779 COMPLETE> pragma should solve
-- this problem in a future version of GHC.
-- A pattern synonym is also provided for a non-transformer version of writer.
-- Again, this is just 'StateT' underneath, but its interface looks as if it
-- were defined like so:
--
-- > newtype Writer w a = Writer { runWriter :: (a, w) }
--
-- The other difference between this monad and
-- 'Control.Monad.Writer.Strict.WriterT' is that it relies on '<.>' from
-- 'Semiring', rather than 'mappend' from 'Monoid'.
module Control.Monad.Weighted
  (
   -- NOTE(review): this export list was reconstructed from the documented
   -- top-level names in this module; confirm nothing else (for example a
   -- re-export of the 'MonadWeighted' class) is expected by callers.
   -- * Transformer
   WeightedT
  ,runWeightedT
  ,pattern WeightedT
  ,execWeightedT
  ,evalWeightedT
   -- * Plain
  ,Weighted
  ,runWeighted
  ,pattern Weighted
  ,execWeighted
  ,evalWeighted
  ) where

import           Control.Applicative

import           Control.Monad.Fail
import           Control.Monad.Identity

import           Control.Monad.Reader.Class
import           Control.Monad.Weighted.Class
import           Control.Monad.Writer.Class
import           Control.Monad.Cont.Class
import           Control.Monad.Error.Class

import           Control.Monad.State.Strict

import           Data.Coerce

import           Data.Functor.Classes

import           Data.Monoid
import           Data.Semiring

-- | A writer-like monad transformer in the spirit of
-- 'Control.Monad.Writer.Strict.WriterT'. It is a newtype over 'StateT', with
-- the state playing the role of the accumulated weight, so that accumulation
-- is tail recursive and is intended to avoid the space leak described in
-- <https://mail.haskell.org/pipermail/libraries/2013-March/019528.html this>
-- email to the Haskell libraries committee.
--
-- Weights are combined with '<.>' from 'Semiring' rather than 'mappend' from
-- 'Monoid'.
--
-- Every instance listed below is obtained by newtype deriving from the
-- corresponding 'StateT' instance. The raw constructor is 'WeightedT_'; the
-- 'WeightedT' pattern synonym provides the writer-style interface.
newtype WeightedT s m a = WeightedT_ (StateT s m a)
    deriving ( Functor
             , Applicative
             , Monad
             , MonadTrans
             , MonadCont
             , MonadError e
             , MonadReader r
             , MonadFix
             , MonadFail
             , MonadIO
             , Alternative
             , MonadPlus
             , MonadWriter w
             )

-- | Run a weighted computation in the underlying monad.
--
-- The wrapped 'StateT' action is run with 'one' as its initial state, and the
-- result is returned together with the final state, i.e. the accumulated
-- weight. The newtype wrapper is removed with 'coerce'.
runWeightedT
    :: Semiring s
    => WeightedT s m a -> m (a, s)
runWeightedT =
    (coerce :: (StateT s m a -> m (a, s)) -> WeightedT s m a -> m (a, s))
        (`runStateT` one)

{-# INLINE runWeightedT #-}

{-# ANN module "HLint: ignore Use second" #-}

-- | Bidirectional pattern synonym that lets the 'StateT'-backed newtype be
-- matched and built as though it had the same shape as
-- 'Control.Monad.Writer.Strict.WriterT'.
--
-- Matching runs the computation with 'runWeightedT'. Building yields a state
-- action that combines the incoming state with the supplied weight using
-- '<.>', with the incoming state as the left operand.
--
-- Unfortunately, GHC currently warns that a function is incomplete wherever
-- this pattern is used. The
-- <https://ghc.haskell.org/trac/ghc/ticket/8779 COMPLETE> pragma is the
-- intended remedy for that warning.
pattern WeightedT :: (Functor m, Semiring s) =>
        m (a, s) -> WeightedT s m a

pattern WeightedT run <- (runWeightedT -> run)
  where WeightedT inner
          = WeightedT_ . StateT $ \ acc ->
                (\ (val, wgt) -> (val, acc <.> wgt)) <$> inner

-- | A type synonym for the plain (non-transformer) version of 'WeightedT',
-- obtained by fixing the underlying monad to 'Identity'. This can be used as
-- if it were defined as:
--
-- > newtype Weighted w a = Weighted { runWeighted :: (a, w) }
type Weighted s = WeightedT s Identity

-- | This pattern gives the newtype wrapper around 'StateT' the same interface
-- as if it was defined like so:
--
-- > newtype Weighted w a = Weighted { runWeighted :: (a, w) }
--
-- Matching runs the computation with 'runWeighted'. Building yields a state
-- action that returns the given value and combines the incoming state with
-- the given weight using '<.>'.
--
-- Unfortunately GHC warns that a function is incomplete wherever this pattern
-- is used. The <https://ghc.haskell.org/trac/ghc/ticket/8779 COMPLETE> pragma
-- is the intended remedy for that warning.
--
-- >>> execWeighted $ traverse (\x -> Weighted ((), x)) [1..5]
-- 120
pattern Weighted :: Semiring s => (a, s) -> Weighted s a

pattern Weighted x <- (runWeighted -> x)
  where Weighted (y, p)
          -- The incoming state is the left operand of '<.>', the same order
          -- the 'WeightedT' builder uses, so both agree even when '<.>' is
          -- not commutative.
          = WeightedT_ (StateT (\ s -> Identity (y, s <.> p)))

-- | Run a weighted computation, returning its result paired with the
-- accumulated weight.
--
-- This is 'runWeightedT' with the 'Identity' wrapper removed from the result
-- by 'coerce'.
--
-- >>> runWeighted $ traverse (\x -> Weighted (show x, x)) [1..5]
-- (["1","2","3","4","5"],120)
runWeighted
    :: Semiring s
    => Weighted s a -> (a, s)
runWeighted =
    (coerce :: (WeightedT s Identity a -> Identity (a, s))
            -> (WeightedT s Identity a -> (a, s)))
        runWeightedT

{-# INLINE runWeighted #-}

-- | State operations are forwarded to the underlying monad @m@ with 'lift'.
-- The state exposed by this instance is the @s@ of the base monad; it is
-- distinct from the weight @w@ that 'WeightedT' carries in its own 'StateT'
-- layer.
instance MonadState s m =>
         MonadState s (WeightedT w m) where
    -- Each method delegates to the same method of the base monad.
    get = lift get
    put = lift . put
    state = lift . state

-- | Run a weighted computation in the underlying monad, and return its result.
--
-- The wrapped 'StateT' action is evaluated with 'one' as its initial state
-- and the final state (the weight) is discarded.
evalWeightedT
    :: (Monad m, Semiring s)
    => WeightedT s m a -> m a
evalWeightedT =
    (coerce :: (StateT s m a -> m a) -> WeightedT s m a -> m a)
        (`evalStateT` one)

{-# INLINE evalWeightedT #-}

-- | Run a weighted computation in the underlying monad, and collect its weight.
--
-- The wrapped 'StateT' action is executed with 'one' as its initial state and
-- only the final state (the weight) is returned.
execWeightedT
    :: (Monad m, Semiring s)
    => WeightedT s m a -> m s
execWeightedT =
    (coerce :: (StateT s m a -> m s) -> WeightedT s m a -> m s)
        (`execStateT` one)

{-# INLINE execWeightedT #-}

-- | Run a weighted computation, and return its result.
--
-- The underlying 'State' action is evaluated with 'one' as its initial state
-- and the final state (the weight) is discarded.
evalWeighted
    :: Semiring s
    => Weighted s a -> a
evalWeighted =
    (coerce :: (State s a -> a) -> Weighted s a -> a) (`evalState` one)

{-# INLINE evalWeighted #-}

-- | Run a weighted computation, and collect its weight.
--
-- The underlying 'State' action is executed with 'one' as its initial state
-- and only the final state (the weight) is returned.
execWeighted
    :: Semiring s
    => Weighted s a -> s
execWeighted =
    (coerce :: (State s a -> s) -> Weighted s a -> s) (`execState` one)

{-# INLINE execWeighted #-}

-- | Folds over the results produced by running the computation with
-- 'runWeightedT'; the weight paired with each result is ignored.
instance (Foldable m, Semiring w) =>
         Foldable (WeightedT w m) where
    foldMap f =
        foldMap
            (\(x,_) ->
                  f x) .
        runWeightedT

    :: Applicative f
    => (a -> f b) -> (a, c) -> f (b, c)
first_ f (x,y) = flip (,) y <$> f x

-- | Runs the computation with 'runWeightedT', traverses the result component
-- of every @(result, weight)@ pair in the underlying container, and rebuilds
-- a 'WeightedT' from the transformed pairs.
instance (Traversable m, Semiring w) =>
         Traversable (WeightedT w m) where
    traverse f x =
        fmap WeightedT (traverse (first_ f) (runWeightedT x))

-- | Two computations are equal when running both with 'runWeightedT' gives
-- structures that the underlying 'Eq1' instance considers equal, comparing
-- results with the supplied predicate and weights with '=='.
instance (Eq1 m, Eq w, Semiring w) =>
         Eq1 (WeightedT w m) where
    liftEq eq x y =
        liftEq
            (\(xx,xy) (yx,yy) ->
                  eq xx yx && xy == yy)
            (runWeightedT x)
            (runWeightedT y)

-- | Computations are ordered by running both with 'runWeightedT' and comparing
-- the outcomes through the underlying 'Ord1' instance. Within each pair the
-- results are compared first (with the supplied function); the weights break
-- ties.
instance (Ord1 m, Ord w, Semiring w) =>
         Ord1 (WeightedT w m) where
    liftCompare cmp x y =
        liftCompare
            (\(xx,xy) (yx,yy) ->
                  cmp xx yx <> compare xy yy)
            (runWeightedT x)
            (runWeightedT y)

-- | Reads the constructor name @WeightedT@ followed by a value of the
-- underlying @m (a, w)@, and wraps it with the 'WeightedT' pattern.
instance (Read w, Read1 m, Semiring w, Functor m) =>
         Read1 (WeightedT w m) where
    liftReadsPrec rp rl =
        readsData $
        readsUnaryWith (liftReadsPrec rp' rl') "WeightedT" WeightedT
      where
        -- Readers for the @(result, weight)@ pair: the result uses the
        -- supplied readers, the weight uses its own 'Read' instance.
        rp' = liftReadsPrec2 rp rl readsPrec readList
        rl' = liftReadList2 rp rl readsPrec readList

-- | Shows the constructor name @WeightedT@ applied to the @m (a, w)@ obtained
-- by running the computation with 'runWeightedT'.
instance (Show w, Show1 m, Semiring w) =>
         Show1 (WeightedT w m) where
    liftShowsPrec sp sl d m =
        showsUnaryWith (liftShowsPrec sp' sl') "WeightedT" d (runWeightedT m)
      where
        -- Printers for the @(result, weight)@ pair: the result uses the
        -- supplied printers, the weight uses its own 'Show' instance.
        sp' = liftShowsPrec2 sp sl showsPrec showList
        sl' = liftShowList2 sp sl showsPrec showList

-- | Equality is the 'Eq1' instance's 'liftEq' with results compared by '=='.
instance (Eq w, Eq1 m, Eq a, Semiring w) =>
         Eq (WeightedT w m a) where
    (==) = liftEq (==)

-- | Ordering is the 'Ord1' instance's 'liftCompare' with results compared by
-- 'compare'.
instance (Ord w, Ord1 m, Ord a, Semiring w) =>
         Ord (WeightedT w m a) where
    compare = liftCompare compare

-- | Reading is the 'Read1' instance's 'liftReadsPrec' with results read by
-- their own 'Read' instance.
instance (Read w, Read1 m, Read a, Semiring w, Functor m) =>
         Read (WeightedT w m a) where
    readsPrec = liftReadsPrec readsPrec readList

-- | Showing is the 'Show1' instance's 'liftShowsPrec' with results shown by
-- their own 'Show' instance.
instance (Show w, Show1 m, Show a, Semiring w) =>
         Show (WeightedT w m a) where
    showsPrec = liftShowsPrec showsPrec showList

-- | 'MonadWeighted' operations implemented directly on the underlying
-- 'StateT' layer.
instance (Semiring w, Monad m) =>
         MonadWeighted w (WeightedT w m) where
    -- Embed a @(result, weight)@ pair using the 'WeightedT' builder.
    weighted (x,s) = WeightedT (pure (x, s))
    {-# INLINE weighted #-}
    -- Run the computation, then pair its result with the state read by 'get'
    -- afterwards.
    weigh (WeightedT_ s) = WeightedT_ ((,) <$> s <*> get)
    {-# INLINE weigh #-}
    -- Run the computation, pass the second component of its result to
    -- 'modify', and return the first component.
    scale (WeightedT_ s) = WeightedT_ (scaleS s)
      where
        scaleS = (=<<) (uncurry (<$) . fmap modify)
    {-# INLINE scale #-}