{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE PatternSynonyms            #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE ViewPatterns               #-}

module Control.Monad.Weighted
  (WeightedT
  ,runWeightedT
  ,pattern WeightedT
  ,Weighted
  ,runWeighted
  ,pattern Weighted
  ,execWeightedT
  ,evalWeightedT
  ,execWeighted
  ,evalWeighted)
  where


import           Control.Applicative
import           Control.Monad.Identity
import           Control.Monad.State.Strict

import           Control.Monad.Cont.Class
import           Control.Monad.Error.Class
import           Control.Monad.Fail
import           Control.Monad.Reader.Class
import           Control.Monad.Writer.Class

import           Control.Monad.Weighted.Class

import           Data.Coerce
import           Data.Functor.Classes

import           Data.Monoid
import           Data.Semiring

-- | A weighted-computation transformer in the spirit of 'WriterT': every step
-- may contribute a weight, and weights are combined with the 'Semiring'
-- multiplication ('<.>') instead of a 'Monoid' append.
--
-- It is represented as a 'StateT' that threads the running weight, rather
-- than as a 'WriterT', to avoid the space leak that 'WriterT' has. The raw
-- constructor 'WeightedT_' is not exported; values are built and inspected
-- through the 'WeightedT' pattern and 'runWeightedT'.
--
-- The derived 'MonadReader', 'MonadWriter', 'MonadError' and 'MonadCont'
-- instances are inherited from the wrapped 'StateT', so they act on the
-- base monad (e.g. 'tell' does not touch the weight).
newtype WeightedT s m a = WeightedT_ (StateT s m a)
    deriving ( Functor, Applicative, Monad, MonadFail, MonadFix, MonadIO
             , Alternative, MonadPlus, MonadTrans
             , MonadCont, MonadError e, MonadReader r, MonadWriter w
             )

-- | Run a 'WeightedT' computation from the initial weight 'one', returning
-- its result paired with the final accumulated weight.
runWeightedT
    :: Semiring s
    => WeightedT s m a -> m (a, s)
runWeightedT (WeightedT_ inner) = runStateT inner one
{-# INLINE runWeightedT #-}


-- | Bidirectional view of a 'WeightedT' computation as an action in the
-- base monad that yields a result paired with a weight.
--
-- * Matching runs the computation from the weight 'one' (through
--   'runWeightedT'), so it exposes the weight the computation itself
--   contributes.
-- * Building @'WeightedT' y@ gives a computation that runs @y@ and
--   multiplies the weight accumulated so far, on the right, by the weight
--   @y@ returns (@acc '<.>' w@). This is the same order in which 'WriterT'
--   appends its output.
--
-- The updated accumulated weight is forced to weak head normal form before
-- it is handed back. Otherwise a long sequence of weighted steps builds a
-- chain of unevaluated '<.>' applications in the (lazy) state value.
pattern WeightedT :: (Functor m, Semiring s) =>
        m (a, s) -> WeightedT s m a

pattern WeightedT x <- (runWeightedT -> x)
  where WeightedT y
          = WeightedT_ . StateT $ \ acc ->
              flip fmap y $ \ (a, w) ->
                let acc' = acc <.> w
                in acc' `seq` (a, acc')

-- | A 'WeightedT' computation with no underlying effects: the base monad is
-- 'Identity'. The synonym is kept eta-reduced (no result-type parameter) so
-- that @'Weighted' s@ on its own can be used as a type constructor, e.g. as
-- the monad in a type signature or instance head.
type Weighted s = WeightedT s Identity

-- | Bidirectional view of a pure 'Weighted' computation as a result paired
-- with its weight.
--
-- Matching runs the computation from the weight 'one' (through
-- 'runWeighted'). Building is defined through the 'WeightedT' builder.
-- This guarantees that @'Weighted' (a, w)@ and
-- @'WeightedT' ('Identity' (a, w))@ are the same computation: both multiply
-- the accumulated weight on the right by @w@.
--
-- Previously this builder multiplied on the left (@w '<.>' acc@), which
-- disagreed with 'WeightedT' for non-commutative semirings.
pattern Weighted :: Semiring s => (a, s) -> Weighted s a

pattern Weighted x <- (runWeighted -> x)
  where Weighted y
          = WeightedT (Identity y)

-- | Run a pure 'Weighted' computation from the initial weight 'one',
-- returning its result paired with the final accumulated weight.
runWeighted
    :: Semiring s
    => Weighted s a -> (a, s)
runWeighted = runIdentity . runWeightedT

{-# INLINE runWeighted #-}

-- | State operations are forwarded to the base monad @m@ via 'lift'. The
-- accumulated weight is therefore not reachable through 'MonadState'.
instance MonadState s m =>
         MonadState s (WeightedT w m) where
    get = lift get
    put newState = lift (put newState)
    state transition = lift (state transition)

-- | Run a 'WeightedT' computation from the initial weight 'one' and keep
-- only its result, discarding the final weight.
evalWeightedT
    :: (Monad m, Semiring s)
    => WeightedT s m a -> m a
evalWeightedT (WeightedT_ inner) = evalStateT inner one

{-# INLINE evalWeightedT #-}

-- | Run a 'WeightedT' computation from the initial weight 'one' and keep
-- only the final accumulated weight, discarding the result.
execWeightedT
    :: (Monad m, Semiring s)
    => WeightedT s m a -> m s
execWeightedT (WeightedT_ inner) = execStateT inner one

{-# INLINE execWeightedT #-}

-- | Pure counterpart of 'evalWeightedT': the result of a 'Weighted'
-- computation run from the weight 'one'.
evalWeighted
    :: Semiring s
    => Weighted s a -> a
evalWeighted (WeightedT_ inner) = evalState inner one

{-# INLINE evalWeighted #-}

-- | Pure counterpart of 'execWeightedT': the final weight of a 'Weighted'
-- computation run from the weight 'one'.
execWeighted
    :: Semiring s
    => Weighted s a -> s
execWeighted (WeightedT_ inner) = execState inner one

{-# INLINE execWeighted #-}

-- | Folds over the results of the underlying (result, weight) pairs,
-- obtained by running from the weight 'one'. Weights are ignored.
instance (Foldable m, Semiring w) =>
         Foldable (WeightedT w m) where
    foldMap f = foldMap onResult . runWeightedT
      where
        onResult (a, _) = f a

-- | Apply an effectful function to the first component of a pair, carrying
-- the second component through unchanged. The 'Traversable' instance uses
-- it to traverse results while leaving their weights alone.
first_
    :: Applicative f
    => (a -> f b) -> (a, c) -> f (b, c)
first_ f (x, c) = fmap (\ b -> (b, c)) (f x)

-- | Traverses the results of the (result, weight) pairs seen from the weight
-- 'one', then rebuilds the computation with the 'WeightedT' pattern.
instance (Traversable m, Semiring w) =>
         Traversable (WeightedT w m) where
    traverse f = fmap WeightedT . traverse (first_ f) . runWeightedT

-- | Two computations are equal when, run from the weight 'one', their
-- underlying actions are equal pairwise on results (with the supplied
-- equality) and on weights (with '==').
instance (Eq1 m, Eq w, Semiring w) =>
         Eq1 (WeightedT w m) where
    liftEq eqA lhs rhs = liftEq samePair (runWeightedT lhs) (runWeightedT rhs)
      where
        samePair (a, w) (b, v) = eqA a b && w == v

-- | Orders computations by their underlying actions run from the weight
-- 'one'. Results are compared first with the supplied comparison, and
-- weights only break ties.
instance (Ord1 m, Ord w, Semiring w) =>
         Ord1 (WeightedT w m) where
    liftCompare cmpA lhs rhs =
        liftCompare comparePair (runWeightedT lhs) (runWeightedT rhs)
      where
        comparePair (a, w) (b, v) = cmpA a b <> compare w v

-- | Parses the form produced by the 'Show1' instance: the word @WeightedT@
-- followed by an @m@ of (result, weight) pairs. The parsed action is turned
-- back into a computation with the 'WeightedT' builder, which is why
-- 'Functor' @m@ is required.
instance (Read w, Read1 m, Semiring w, Functor m) =>
         Read1 (WeightedT w m) where
    liftReadsPrec readA readAs =
        readsData (readsUnaryWith readInner "WeightedT" WeightedT)
      where
        readInner = liftReadsPrec readPair readPairs
        readPair = liftReadsPrec2 readA readAs readsPrec readList
        readPairs = liftReadList2 readA readAs readsPrec readList

-- | Shows a computation as the word @WeightedT@ applied to its underlying
-- action of (result, weight) pairs, obtained by running from the weight
-- 'one'.
instance (Show w, Show1 m, Semiring w) =>
         Show1 (WeightedT w m) where
    liftShowsPrec showA showAs prec =
        showsUnaryWith showInner "WeightedT" prec . runWeightedT
      where
        showInner = liftShowsPrec showPair showPairs
        showPair = liftShowsPrec2 showA showAs showsPrec showList
        showPairs = liftShowList2 showA showAs showsPrec showList

-- | Equality via the 'Eq1' instance.
instance (Eq w, Eq1 m, Eq a, Semiring w) =>
         Eq (WeightedT w m a) where
    lhs == rhs = eq1 lhs rhs

-- | Ordering via the 'Ord1' instance.
instance (Ord w, Ord1 m, Ord a, Semiring w) =>
         Ord (WeightedT w m a) where
    compare lhs rhs = compare1 lhs rhs

-- | Parsing via the 'Read1' instance.
instance (Read w, Read1 m, Read a, Semiring w, Functor m) =>
         Read (WeightedT w m a) where
    readsPrec prec = readsPrec1 prec

-- | Rendering via the 'Show1' instance.
instance (Show w, Show1 m, Show a, Semiring w) =>
         Show (WeightedT w m a) where
    showsPrec prec = showsPrec1 prec

-- | Weighted operations for 'WeightedT'.
--
-- 'weigh' and 'scale' are modelled on 'listen' and 'pass' of 'WriterT'. The
-- wrapped action is run in isolation from the weight 'one', so 'weigh'
-- reports, and 'scale' rewrites, only the weight that action contributes.
-- That weight is then multiplied onto the weight accumulated before the
-- action, on the right (@acc '<.>' w@), in the same order the 'WeightedT'
-- builder uses.
--
-- Consequently (up to the 'Semiring' laws)
-- @'weigh' ('weighted' (a, w)) == 'weighted' ((a, w), w)@ regardless of the
-- weight already accumulated. The previous implementation returned, and
-- rescaled, the whole accumulated weight instead.
--
-- The new accumulated weight is forced so it does not build up as a chain
-- of thunks.
instance (Semiring w, Monad m) => MonadWeighted w (WeightedT w m) where
    weighted (x,s) = WeightedT (pure (x, s))
    {-# INLINE weighted #-}
    weigh (WeightedT_ m) = WeightedT_ . StateT $ \ acc -> do
        -- Run the action on its own to learn the weight it contributes.
        (a, w) <- runStateT m one
        let acc' = acc <.> w
        acc' `seq` pure ((a, w), acc')
    {-# INLINE weigh #-}
    scale (WeightedT_ m) = WeightedT_ . StateT $ \ acc -> do
        -- Rewrite only the action's own weight before folding it in.
        ((a, f), w) <- runStateT m one
        let acc' = acc <.> f w
        acc' `seq` pure (a, acc')
    {-# INLINE scale #-}