{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}
{- |
Pure reference implementation for the @ExtRef@ interface.

The implementation uses @unsafeCoerce@ internally, but its effect cannot escape.
-}
module Control.MLens.ExtRef.Pure
    ( Ext, runExt, runExt_
    ) where

import Control.Applicative (Applicative)
import Control.Category
import Control.Category.Product
import Control.Monad.State
import Data.Foldable (toList)
import Data.Sequence
import Prelude hiding ((.), id, splitAt, length)

import Unsafe.Coerce

import Control.MLens.ExtRef
import Data.MLens
import Data.MLens.Ref


-- | One slot of the state: a value of some hidden type @s@, packaged with
-- the function used to refresh it.
--
-- The refresh function takes the stored value and a state of type @x@ and
-- returns, in @m@, the new stored value together with the new state
-- (see 'ap_').
data CC m x = forall s . CC s (s -> x -> m (s, x))

-- | Run a slot's refresh function against the state @st@.
--
-- Returns the state produced by the refresh function, paired with the same
-- slot now holding the refreshed value (and the unchanged refresh function).
ap_ :: Monad m => x -> CC m x -> m (x, CC m x)
ap_ st (CC stored refresh) =
    refresh stored st >>= \(stored', st') -> return (st', CC stored' refresh)

-- | Extract the value stored in a slot, at whatever type the caller asks for.
--
-- This is an unchecked 'unsafeCoerce'. It is only sound when the requested
-- type is the type the slot was created with. Within this file it is used by
-- @getM@ in 'extend_', which reads the slot at the position where that same
-- 'extend_' call appended it.
unsafeData :: CC m x -> a
unsafeData c = case c of
    CC stored _ -> unsafeCoerce stored


-- | The whole state of an 'Ext' computation: a sequence of slots whose
-- refresh functions operate on this same state type. 'extend_' appends one
-- slot per call.
newtype ST n = ST (Seq (CC n (ST n)))

-- | The starting state: no slots at all.
initST :: ST m
initST = ST Data.Sequence.empty

-- | Append a new slot holding @a0@ to the state @x0@.
--
-- Returns a getter and a setter for the new slot, together with the extended
-- state. Both locate the slot by position: it is the entry right after the
-- first @length x0@ slots of whatever state they are given.
--
-- * @rk@ is run by the setter. It receives the value being written and the
--   state prefix before the slot. It returns the value to store and the new
--   prefix.
--
-- * @kr@ is stored as the slot's refresh function. Whenever this slot or an
--   earlier one is written, every slot after the written one is refreshed in
--   order (via 'ap_') against the state rebuilt so far.
--
-- The getter and setter call 'error' if the given state has no slot at that
-- position, i.e. it has no more than @length x0@ slots.
extend_
    :: Monad m
    => (a -> ST m -> m (a, ST m))
    -> (a -> ST m -> m (a, ST m))
    -> a
    -> ST m
    -> ((ST m -> a, a -> ST m -> m (ST m)), ST m)
extend_ rk kr a0 x0
    = ((getM, setM), x0 ||> CC a0 kr)
  where
    -- Read the value held in this slot.
    getM x = case limit x0 x of
        (_, c : _) -> unsafeData c
        _          -> missingSlot "getM"

    -- Write into this slot: run @rk@ on the prefix, store the value it
    -- returns, then refresh every later slot in order.
    setM a x = case limit x0 x of
        (zs, _ : ys) -> do
            (a', re) <- rk a zs
            foldM refreshNext (re ||> CC a' kr) ys
        _ -> missingSlot "setM"

    -- Refresh one later slot against the state built so far, then append it.
    refreshNext st c = do
        (st', c') <- ap_ st c
        return (st' ||> c')

    -- Invariant violation: the state is shorter than when the slot was made.
    missingSlot fn = error $
        "Control.MLens.ExtRef.Pure.extend_." ++ fn
            ++ ": the state has no slot for this reference"

    ST x ||> c = ST (x |> c)

    -- Split a state after the first @length x0@ slots; the remainder is
    -- returned as a list.
    limit (ST x) (ST y) = ST *** toList $ splitAt (length x) y



-- | The reference monad: a state monad over the slot sequence 'ST', running
-- in the underlying monad @m@.
--
-- The parameter @i@ does not occur on the right-hand side. 'runExt' and
-- 'runExt_' require their argument to be polymorphic in @i@ (a rank-2 type).
-- For example, a computation run by 'runExt' cannot return a value whose type
-- mentions @i@.
--
-- 'Applicative' is derived alongside 'Monad' because, on compilers where it
-- is a superclass of 'Monad', the 'Monad' instance cannot exist without it.
newtype Ext i m a = Ext { unExt :: StateT (ST m) m a }
    deriving (Functor, Applicative, Monad)

-- | Lift an action of the underlying monad into 'Ext' using 'StateT''s own
-- 'lift'.
instance MonadTrans (Ext i) where
    lift action = Ext (lift action)

-- | Shared implementation of 'newRef' and 'extRef'.
--
-- Appends a slot for the new reference via 'extend_'.
--
-- * @push@ reads through @r2@ and writes the result through @r1@.
-- * @pull@ reads through @r1@ and writes through @r2@.
--
-- The initial stored value is the result of @pull a0@. @push@ runs when the
-- reference is written, and @pull@ is installed as the slot's refresh
-- function.
extRef_ :: Monad m => MLens (Ext i m) a x -> MLens (Ext i m) a x -> a -> Ext i m (Ref (Ext i m) a)
extRef_ r1 r2 a0 = Ext $ do
    initial <- pull a0
    (getSlot, setSlot) <- state $ extend_ (runStateT . push) (runStateT . pull) initial
    let access c = do
            current <- Ext (gets getSlot)
            let write v = Ext (StateT (liftM ((,) ()) . setSlot v) >> return c)
            return (current, write)
    return (MLens access)
  where
    push v = unExt (getL r2 v >>= \y -> setL r1 y v)
    pull v = unExt (getL r1 v >>= \y -> setL r2 y v)

-- | Creates a reference by calling 'extRef_' with 'unitLens' for both lenses.
instance (Monad m) => NewRef (Ext i m) where
    newRef initial = extRef_ unitLens unitLens initial

-- | Extends the reference @r@ by passing @r . unitLens@ as the first lens of
-- 'extRef_'; the remaining arguments are forwarded unchanged.
instance (Monad m) => ExtRef (Ext i m) where
    extRef r = extRef_ (r . unitLens)

-- | Basic running of the @(Ext i m)@ monad.
--
-- Starts from the empty state ('initST') and discards the final state.
runExt :: Monad m => (forall i . Ext i m a) -> m a
runExt action = unExt action `evalStateT` initST

{- |
Advanced running of the @(Ext i m)@ monad.

The slot state is kept in a reference of the underlying monad @m@, created
with 'newRef' and initialised to the empty state. @mapI@ receives 'lift'
(from @m@ into @Ext i m@) and an unlifting function. The unlifting function
runs an @Ext i m@ action against the stored state and writes the resulting
state back before returning the action's result.

The @Functor@ constraints would not be needed if @Functor@ were a superclass
of @Monad@.
-}
runExt_
    :: forall c m . (Functor m, NewRef m)
    => (forall n . (Monad n, Functor n) => Morph m n -> Morph n m -> c n -> c m)
    -> (forall i . c (Ext i m)) -> m (c m)
runExt_ mapI int = do
    stateRef <- newRef initST
    let unlift action = do
            before <- readRef stateRef
            (result, after) <- runStateT (unExt action) before
            writeRef stateRef after
            return result
    return (mapI lift unlift int)