{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# lANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
module Data.Cache.Trace
 ( CacheEvent(..), CacheTrace
 , MonadTrace(..), MonadAtomicRefTraced(..), atomicModifyRefTracedM'_
 , Tracable(..)
 ) where

import           Control.Applicative
import           Control.Monad
import qualified Control.Monad.Fail as Fail
import           Control.Monad.Fix
import           Control.Monad.Ref
import           Control.Monad.Trans
import qualified Data.DList as DList
import           Data.Functor.Identity
import           Prelude hiding (lookup)

-- | A single cache event, as collected in a 'CacheTrace'.
--
-- Both constructors carry the same three fields. The key and tracking
-- fields are strict; the value field is lazy, so constructing an event does
-- not by itself force the cached value.
data CacheEvent k t v
 -- | Eviction event (the code that emits it is not in this module).
 = CacheEvict
   { _ceKey      :: !k  -- ^ Key of the affected entry.
   , _ceTracking :: !t  -- ^ Caller-defined tracking data (meaning not fixed here — TODO confirm).
   , _ceValue    :: v   -- ^ Value associated with the key (lazy).
   }
 -- | Add event, presumably an insertion (the emitting code is not in this module).
 | CacheAdd
   { _ceKey      :: !k  -- ^ Key of the affected entry.
   , _ceTracking :: !t  -- ^ Caller-defined tracking data (meaning not fixed here — TODO confirm).
   , _ceValue    :: v   -- ^ Value associated with the key (lazy).
   }
 deriving (Read, Show, Eq, Ord)

-- | An ordered sequence of 'CacheEvent's.
--
-- Stored as a difference list so that appending traces (e.g. when it is used
-- as the monoidal trace type @w@ of a traced computation — presumably its
-- main use; verify against callers) is cheap. Use 'DList.toList' to inspect it.
type CacheTrace k t v = DList.DList (CacheEvent k t v)

-- | A computation in the monad @m@ producing an @a@, where the type-level
-- flag @trc@ decides whether trace data of type @w@ is collected.
--
-- Both instances are newtypes, so the untraced variant has the same runtime
-- representation as a plain @m a@.
data family Tracable (trc :: Bool) :: * -> (* -> *) -> * -> *
-- | Untraced: just @m a@. The trace type @w@ is phantom; nothing is stored.
newtype instance Tracable 'False w m a = UntracedT { runUntracedT :: m a }
-- | Traced: writer-style, returning the accumulated trace alongside the result.
newtype instance Tracable 'True  w m a = TracedT { runTracedT :: m (a, w) }

-- | Emission of a single trace value, selected by the type-level flag @trc@.
class MonadTrace (trc :: Bool) where
  -- | Record @w@ in the trace. Only 'Applicative' is needed on @m@ because
  -- the action performs no effects of its own. In the untraced variant the
  -- argument is ignored.
  trace :: Applicative m => w -> Tracable trc w m ()

-- | Traced: the action returns @()@ paired with the given trace value.
instance MonadTrace 'True where
  {-# INLINE trace #-}
  trace = TracedT . pure . ((),)

-- | Untraced: the trace value is discarded without being evaluated.
instance MonadTrace 'False where
  {-# INLINE trace #-}
  trace = const (UntracedT (pure ()))

-- | Atomic updates of a mutable reference that also produce trace data of
-- type @w@. The type-level flag @trc@ selects whether that trace data is kept
-- ('True', returned through 'TracedT') or dropped ('False', 'UntracedT').
class MonadAtomicRef m => MonadAtomicRefTraced trc w m where
  -- |Atomically mutate the contents of a reference with trace data output.
  -- The update function returns @(new contents, trace, result)@.
  atomicModifyRefTraced  :: Ref m a -> (a -> (a, w, b)) -> Tracable trc w m b
  -- |Strict version of 'atomicModifyRefTraced'. This forces both the value
  -- stored in the reference and the value returned, but not the trace.
  atomicModifyRefTraced' :: Ref m a -> (a -> (a, w, b)) -> Tracable trc w m b
  -- | Strict update whose function is itself a 'Tracable' computation, so the
  -- trace can be built monadically; intended to keep tracing close to zero
  -- cost when it is disabled.
  --
  -- The update function returns @(new contents, result)@ and is polymorphic in
  -- its monad @m'@, so it cannot perform effects of @m@; the instances in this
  -- module run it in 'Identity' inside the atomic update.
  atomicModifyRefTracedM' :: Ref m a -> (forall m' . (Monad m') => a -> Tracable trc w m' (a, b)) -> Tracable trc w m b

-- | Variant of 'atomicModifyRefTracedM'' for updates with no extra result:
-- the callback only computes the new contents (possibly emitting trace data)
-- and the whole action returns @()@.
--
-- The quantified 'Functor' constraint lets the new contents be paired with
-- @()@ inside whatever monad @m'@ the callback is run in.
atomicModifyRefTracedM'_ :: (MonadAtomicRefTraced trc w m, forall mg . (Monad mg) => Functor (Tracable trc w mg))
                         => Ref m a
                         -> (forall m' . (Monad m', Functor (Tracable trc w m')) => a -> Tracable trc w m' a)
                         -> Tracable trc w m ()
atomicModifyRefTracedM'_ ref update =
  atomicModifyRefTracedM' ref (\old -> (\new -> (new, ())) <$> update old)

-- | Traced instance: each operation performs the atomic update, then emits the
-- trace data produced by the update function via 'trace'.
instance (MonadAtomicRef m, Monoid w) => MonadAtomicRefTraced 'True w m where
  atomicModifyRefTraced r f = do
    (w, b) <- lift $ atomicModifyRef r ((\(a, w, b) -> (a, (w, b))) . f)
    trace w
    pure b
  {-# INLINE atomicModifyRefTraced #-}
  atomicModifyRefTraced' r f = do
    -- The new contents are forced when the returned pair is demanded (by the
    -- pattern match below); the result is forced explicitly afterwards. The
    -- trace @w@ is intentionally left lazy.
    (w, b) <- lift $ atomicModifyRef r $
      \x -> let (a, w, b) = f x
             in (a, a `seq` (w, b))
    trace w
    b `seq` pure b
  {-# INLINE atomicModifyRefTraced' #-}
  atomicModifyRefTracedM' r f = do
    -- The callback is run purely in 'Identity'. 'atomicModifyRef'' forces
    -- only the returned pair @(b, w)@ to WHNF, not @b@ itself, so @b@ is
    -- forced explicitly below. This matches the untraced instance, where
    -- 'atomicModifyRef'' forces @b@ directly.
    (b, w) <- lift $ atomicModifyRef' r $
      (\((a, b), w) -> (a, (b, w))) . runIdentity . runTracedT . f
    trace w
    b `seq` pure b
  {-# INLINE atomicModifyRefTracedM' #-}

-- | Untraced instance: each operation is a plain atomic update in @m@, and the
-- trace component produced by the update function is dropped unevaluated.
-- No 'Monoid' constraint on @w@ is needed because nothing is accumulated;
-- wrapping with 'UntracedT' directly (rather than 'lift') avoids requiring one.
instance MonadAtomicRef m => MonadAtomicRefTraced 'False w m where
  atomicModifyRefTraced r f = UntracedT $ atomicModifyRef r ((\(a, _, b) -> (a, b)) . f)
  {-# INLINE atomicModifyRefTraced #-}
  atomicModifyRefTraced' r f = UntracedT $ atomicModifyRef' r ((\(a, _, b) -> (a, b)) . f)
  {-# INLINE atomicModifyRefTraced' #-}
  -- The callback is run purely in 'Identity'.
  atomicModifyRefTracedM' r f = UntracedT $ atomicModifyRef' r (runIdentity . runUntracedT . f)
  {-# INLINE atomicModifyRefTracedM' #-}

-- | Maps over the result, leaving the trace untouched. The result/trace pair
-- is matched lazily.
instance Functor m => Functor (Tracable 'True w m) where
  fmap f m = TracedT (onResult <$> runTracedT m)
    where onResult ~(a, w') = (f a, w')
  {-# INLINE fmap #-}

-- | Maps over the result of the wrapped computation.
instance Functor m => Functor (Tracable 'False w m) where
  fmap f = UntracedT . fmap f . runUntracedT
  {-# INLINE fmap #-}

-- | Runs the function computation, then the argument computation, and
-- combines their traces with 'mappend' in that order. 'pure' contributes an
-- empty trace. Result/trace pairs are matched lazily.
instance (Monoid w, Monad m) => Applicative (Tracable 'True w m) where
  pure = TracedT . pure . (, mempty)
  {-# INLINE pure #-}
  TracedT mf <*> TracedT mv = TracedT $
    mf >>= \ ~(f, w1) ->
      mv >>= \ ~(v, w2) ->
        pure (f v, w1 `mappend` w2)
  {-# INLINE (<*>) #-}

-- | Delegates directly to the underlying 'Applicative'. The untraced wrapper
-- adds no effects of its own, so only @Applicative m@ is required. This
-- matches the @Functor m@ context of the corresponding 'Functor' instance.
-- The 'Monad' and 'Alternative' instances remain valid, since they imply it.
instance (Applicative m) => Applicative (Tracable 'False w m) where
  pure a = UntracedT $ pure a
  {-# INLINE pure #-}
  mf <*> mv = UntracedT $ runUntracedT mf <*> runUntracedT mv
  {-# INLINE (<*>) #-}

-- | Choice delegated to the underlying 'MonadPlus'. Each branch carries its
-- own trace; how the branches combine is determined by @m@'s 'mplus'.
instance (Monoid w, MonadPlus m) => Alternative (Tracable 'True w m) where
  empty = TracedT mzero
  {-# INLINE empty #-}
  TracedT l <|> TracedT r = TracedT (mplus l r)
  {-# INLINE (<|>) #-}

-- | Choice delegated to the underlying 'MonadPlus'. The untraced variant never
-- touches @w@, so no 'Monoid' constraint is required.
instance (MonadPlus m) => Alternative (Tracable 'False w m) where
  empty   = UntracedT $ mzero
  {-# INLINE empty #-}
  m <|> n = UntracedT $ runUntracedT m `mplus` runUntracedT n
  {-# INLINE (<|>) #-}

-- | Sequences computations and combines their traces with 'mappend' (earlier
-- trace first). Result/trace pairs are matched lazily.
instance (Monoid w, Monad m) => Monad (Tracable 'True w m) where
  TracedT m >>= k = TracedT $
    m >>= \ ~(a, w1) ->
      runTracedT (k a) >>= \ ~(b, w2) ->
        return (b, w1 `mappend` w2)
  {-# INLINE (>>=) #-}

-- | Binding is exactly the underlying monad's bind.
instance (Monad m) => Monad (Tracable 'False w m) where
  UntracedT m >>= k = UntracedT (m >>= runUntracedT . k)
  {-# INLINE (>>=) #-}

-- | Failure is delegated to the underlying monad; no trace is produced.
instance (Monoid w, Fail.MonadFail m) => Fail.MonadFail (Tracable 'True w m) where
  fail = TracedT . Fail.fail
  {-# INLINE fail #-}

-- | Failure is delegated to the underlying monad.
instance (Fail.MonadFail m) => Fail.MonadFail (Tracable 'False w m) where
  fail = UntracedT . Fail.fail
  {-# INLINE fail #-}

-- | 'mzero' and 'mplus' delegate to the underlying 'MonadPlus'.
instance (Monoid w, MonadPlus m) => MonadPlus (Tracable 'True w m) where
  mzero = TracedT mzero
  {-# INLINE mzero #-}
  mplus (TracedT l) (TracedT r) = TracedT (mplus l r)
  {-# INLINE mplus #-}

-- | 'mzero' and 'mplus' delegate to the underlying 'MonadPlus'.
--
-- NOTE(review): the @Monoid w@ constraint is not used by these methods.
-- Dropping it also requires the untraced 'Alternative' instance (this class's
-- superclass) not to demand @Monoid w@.
instance (Monoid w, MonadPlus m) => MonadPlus (Tracable 'False w m) where
  mzero = UntracedT mzero
  {-# INLINE mzero #-}
  mplus (UntracedT l) (UntracedT r) = UntracedT (mplus l r)
  {-# INLINE mplus #-}

-- | Value recursion through the underlying monad. The fixed point is taken
-- over the result/trace pair, and only the result (taken lazily via 'fst') is
-- fed back into the body.
instance (Monoid w, MonadFix m) => MonadFix (Tracable 'True w m) where
  mfix f = TracedT $ mfix (runTracedT . f . fst)
  {-# INLINE mfix #-}

-- | Value recursion delegated directly to the underlying monad.
instance (MonadFix m) => MonadFix (Tracable 'False w m) where
  mfix f = UntracedT (mfix (runUntracedT . f))
  {-# INLINE mfix #-}

-- | Lifting pairs the underlying result with an empty trace.
instance (Monoid w) => MonadTrans (Tracable 'True w) where
  lift m = TracedT ((, mempty) <$> m)
  {-# INLINE lift #-}

-- | Lifting is just the newtype constructor. The untraced variant never
-- touches @w@, so no 'Monoid' constraint is required.
instance MonadTrans (Tracable 'False w) where
  lift = UntracedT
  {-# INLINE lift #-}

-- | Lifts an 'IO' action into the underlying monad and attaches an empty
-- trace.
instance (Monoid w, MonadIO m) => MonadIO (Tracable 'True w m) where
  liftIO = TracedT . fmap (, mempty) . liftIO
  {-# INLINE liftIO #-}

-- | Lifts an 'IO' action into the underlying monad. Wrapping with 'UntracedT'
-- directly (instead of 'lift') means no 'Monoid' constraint on the unused
-- trace type @w@ is required.
instance MonadIO m => MonadIO (Tracable 'False w m) where
  liftIO = UntracedT . liftIO
  {-# INLINE liftIO #-}