{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE PatternSynonyms            #-}
{-# LANGUAGE ViewPatterns               #-}

module Control.Monad.Weighted
  (WeightedT
  ,runWeightedT
  ,pattern WeightedT
  ,Weighted
  ,runWeighted
  ,pattern Weighted
  ,execWeightedT
  ,evalWeightedT
  ,execWeighted
  ,evalWeighted)
  where

import           Control.Applicative
import           Control.Monad.Identity
import           Control.Monad.State.Strict
import           Control.Monad.Writer.Class
import           Data.Coerce
import           Data.Semiring

-- | A 'WriterT'-like monad transformer whose log is a weight combined with
-- the multiplication of a 'Semiring' instead of the 'mappend' of a 'Monoid'.
--
-- The weight is threaded through a 'StateT' as an accumulator; this is what
-- is meant to keep it from leaking space the way 'WriterT' does. The raw
-- constructor is not exported: build values with the 'WeightedT' pattern and
-- take them apart with 'runWeightedT'.
newtype WeightedT s m a
  = WeightedT_ (StateT s m a)
  deriving (Functor, Applicative, Monad, MonadTrans)

-- | Run a weighted computation with the accumulated weight starting at
-- 'one', returning the result together with the final weight.
runWeightedT :: Semiring s => WeightedT s m a -> m (a, s)
-- Written as a lambda (no arguments left of '=') so that the INLINE pragma
-- still applies to unsaturated uses, exactly as with a point-free definition.
runWeightedT = \(WeightedT_ m) -> runStateT m one
{-# INLINE runWeightedT #-}

-- | Bidirectional pattern standing in for the hidden constructor.
--
-- Matching runs the computation with 'runWeightedT' and binds the resulting
-- @m (a, s)@. Building wraps an @m (a, s)@: the weight it produces is
-- multiplied onto the right of whatever weight has been accumulated so far.
pattern WeightedT :: (Functor m, Semiring s) => m (a, s) -> WeightedT s m a
pattern WeightedT result <- (runWeightedT -> result) where
  WeightedT action =
    WeightedT_ (StateT (\acc -> fmap (\(x, w) -> (x, acc <.> w)) action))

-- | 'WeightedT' specialised to the 'Identity' monad, i.e. a weighted
-- computation with no underlying effects.
type Weighted s = WeightedT s Identity

-- | Bidirectional pattern for pure weighted computations.
--
-- Matching runs the computation with 'runWeighted' and binds the resulting
-- @(a, s)@ pair. Building wraps a value and its weight: the weight is
-- multiplied onto the right of the weight accumulated so far, which is the
-- same order the 'WeightedT' pattern uses. The order matters for semirings
-- whose multiplication is not commutative: sequencing a weight @p@ and then a
-- weight @q@ must accumulate @p '<.>' q@, not @q '<.>' p@.
pattern Weighted :: Semiring s => (a, s) -> Weighted s a
pattern Weighted x <- (runWeighted -> x) where
  Weighted (y,p) = WeightedT_ (StateT (\s -> Identity (y, s<.>p)))

-- | Run a pure weighted computation with the accumulated weight starting at
-- 'one', returning the result together with the final weight.
runWeighted :: Semiring s => Weighted s a -> (a, s)
runWeighted = runIdentity . runWeightedT
{-# INLINE runWeighted #-}

-- | Writer interface whose output is the weight, wrapped in 'Mul' so that
-- the writer's monoid is the semiring's multiplication.
--
-- 'listen' and 'pass' are scoped to the action they are given: that action is
-- run with a fresh weight of 'one', and only afterwards is its (possibly
-- transformed) weight multiplied onto the weight accumulated before it.
-- Reading or rewriting the whole accumulated state instead would let weights
-- written /before/ the action show up in 'listen', and would let 'pass' (and
-- therefore 'censor') alter or discard them.
instance (Semiring s, Monad m) => MonadWriter (Mul s) (WeightedT s m) where
  writer (x, Mul s) = WeightedT (pure (x, s))
  {-# INLINE writer #-}
  listen (WeightedT_ m) = WeightedT_ (StateT listenS) where
    -- @acc@ is the weight accumulated before the listened-to action.
    listenS acc = do
      (x, w) <- runStateT m one
      pure ((x, Mul w), acc <.> w)
  {-# INLINE listen #-}
  pass (WeightedT_ m) = WeightedT_ (StateT passS) where
    -- @acc@ is the weight accumulated before the action; the returned
    -- function only ever sees the weight @w@ produced by the action itself.
    passS acc = do
      ((x, f), w) <- runStateT m one
      let Mul w' = f (Mul w)
      pure (x, acc <.> w')
  {-# INLINE pass #-}

-- | Run a weighted computation starting from a weight of 'one' and keep only
-- its result, discarding the final weight.
evalWeightedT :: (Monad m, Semiring s) => WeightedT s m a -> m a
evalWeightedT = \(WeightedT_ m) -> evalStateT m one
{-# INLINE evalWeightedT #-}

-- | Run a weighted computation starting from a weight of 'one' and keep only
-- the final weight, discarding the result.
execWeightedT :: (Monad m, Semiring s) => WeightedT s m a -> m s
execWeightedT = \(WeightedT_ m) -> execStateT m one
{-# INLINE execWeightedT #-}

-- | Run a pure weighted computation starting from a weight of 'one' and keep
-- only its result, discarding the final weight.
evalWeighted :: Semiring s => Weighted s a -> a
evalWeighted = \(WeightedT_ m) -> evalState m one
{-# INLINE evalWeighted #-}

-- | Run a pure weighted computation starting from a weight of 'one' and keep
-- only the final weight, discarding the result.
execWeighted :: Semiring s => Weighted s a -> s
execWeighted = \(WeightedT_ m) -> execState m one
{-# INLINE execWeighted #-}

-- | Alternatives are delegated to the underlying functor: each side is run
-- from a weight of 'one', the two @m (a, s)@ results are combined with the
-- '<|>' of @m@, and the outcome is re-wrapped with the 'WeightedT' pattern.
instance (Alternative m, Monad m, Semiring s) => Alternative (WeightedT s m) where
  empty = WeightedT empty
  -- Uses 'runWeightedT' directly instead of matching on the 'WeightedT'
  -- pattern synonym, so the definition is a single total clause and needs no
  -- unreachable @undefined@ fallback to satisfy the exhaustiveness checker.
  lhs <|> rhs = WeightedT (runWeightedT lhs <|> runWeightedT rhs)