{-# LANGUAGE BangPatterns #-}
module Agda.Utils.Lens
  ( module Agda.Utils.Lens
  , (<&>) 
  ) where
import Control.Applicative ( Const(Const), getConst )
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Functor.Identity
import Agda.Utils.Functor ((<&>))
-- | A lens in the van Laarhoven encoding: given a way to act on the inner
--   part @i@ inside some 'Functor' @f@, it acts on the whole outer value @o@.
--   Choosing @f@ selects the operation (e.g. reading or updating).
type Lens' i o = forall f. Functor f => (i -> f i) -> o -> f o
-- | Shape of a getter: extracts an @i@ from an @o@.
type LensGet i o = o -> i
-- | Shape of a setter: replaces the focused @i@ inside an @o@.
type LensSet i o = i -> o -> o
-- | Shape of a modifier: applies a function to the focused @i@ inside an @o@.
type LensMap i o = (i -> i) -> o -> o
-- | Lens focusing on the first component of a pair; the second component
--   is carried through unchanged.
lFst :: Lens' a (a, b)
lFst f (old, other) = fmap (\ new -> (new, other)) (f old)
-- | Lens focusing on the second component of a pair; the first component
--   is carried through unchanged.
lSnd :: Lens' b (a, b)
lSnd f (other, old) = fmap (\ new -> (other, new)) (f old)
infixl 8 ^.
-- | Read the part of @o@ that the lens focuses on. The lens is run in the
--   'Const' functor, which carries the focused value out untouched.
(^.) :: o -> Lens' i o -> i
(^.) o l = getConst (l Const o)
-- | Overwrite the focused part of @o@ with the given value, ignoring the
--   value previously stored there.
set :: Lens' i o -> LensSet i o
set l new = runIdentity . l (\ _ -> Identity new)
-- | Apply a function to the focused part of @o@. The lens is run in the
--   'Identity' functor, so the rebuilt outer value is returned directly.
over :: Lens' i o -> LensMap i o
over l f = runIdentity . l (\ i -> Identity (f i))
-- | Run a state computation over the focused part @i@ as one over the whole
--   @o@. The inner computation starts from the value read through the lens,
--   and its final inner state is written back into the outer state with 'set'.
focus :: Monad m => Lens' i o -> StateT i m a -> StateT o m a
focus l inner = StateT step
  where
    step outer = do
      (result, part') <- runStateT inner (outer ^. l)
      return (result, set l part' outer)
-- | Read the focused part of the current state. The value is forced to
--   weak head normal form before it is returned.
use :: MonadState o m => Lens' i o -> m i
use l = do
  x <- gets (\ s -> s ^. l)
  x `seq` return x
infix 4 .=
-- | Overwrite the focused part of the state with a fixed value.
(.=) :: MonadState o m => Lens' i o -> i -> m ()
(.=) l new = modify (set l new)
infix 4 %=
-- | Modify the focused part of the state with a pure function.
(%=) :: MonadState o m => Lens' i o -> (i -> i) -> m ()
(%=) l g = modify (over l g)
infix 4 %==
-- | Modify the focused part of the state with a monadic function.
--
--   The state is read once, the lens is run directly in @m@ to rebuild it,
--   and the result is written back with 'put'. Any changes @f@ itself makes
--   to the state are therefore overwritten by that final 'put'.
(%==) :: MonadState o m => Lens' i o -> (i -> m i) -> m ()
l %== f = do
  o  <- get
  o' <- l f o
  put o'
infix 4 %%=
-- | Like '%==', but the monadic update also yields an extra result @r@,
--   which is returned.
--
--   The lens is run in @'WriterT' r m@, so the @r@ produced by @f@ travels
--   out alongside the rebuilt state. As with '%==', the rebuilt state
--   replaces the whole state via 'put'.
(%%=) :: MonadState o m => Lens' i o -> (i -> m (i, r)) -> m r
l %%= f = do
  (newState, result) <- runWriterT . l (WriterT . f) =<< get
  put newState
  return result
-- | Run @k@ with the focused part of the state temporarily modified by @f@.
--   Afterwards, the focused part is reset to the value it had before.
--
--   The restoring write is sequenced after @k@. Any change @k@ made to the
--   focused part is discarded, while changes elsewhere in the state are kept.
--   NOTE(review): if @k@ short-circuits (e.g. via an error in the underlying
--   monad), the restore does not run. Whether the old value survives then
--   depends on the monad stack.
locallyState :: MonadState o m => Lens' i o -> (i -> i) -> m r -> m r
locallyState l f k = do
  saved <- use l
  modify (over l f)
  result <- k
  modify (set l saved)
  return result
-- | Read the focused part of the reader environment.
view :: MonadReader o m => Lens' i o -> m i
view l = asks (\ env -> env ^. l)
-- | Run a computation in an environment whose focused part has been
--   modified by the given function, using 'local'.
locally :: MonadReader o m => Lens' i o -> (i -> i) -> m a -> m a
locally l f = local (over l f)
-- | Like 'locally', but takes the environment-modifying function explicitly
--   rather than using 'local' from 'MonadReader'.
--
--   The first argument receives the whole-environment update built from the
--   lens and the focused-part update @f@.
--
--   The parameter is named @withLocal@ rather than @local@. The module
--   imports "Control.Monad.Reader" unqualified, so a parameter called
--   @local@ would shadow that function and trigger @-Wname-shadowing@.
locally' :: ((o -> o) -> m a -> m a) -> Lens' i o -> (i -> i) -> m a -> m a
locally' withLocal l = withLocal . over l
-- | Lens onto the possibly-absent value stored under a key of a 'Map'.
--
--   Reading yields 'Nothing' when the key is absent. Writing 'Nothing'
--   removes the key, and writing @'Just' v@ inserts or replaces it (via
--   'Map.alter').
key :: Ord k => k -> Lens' (Maybe v) (Map k v)
key k f m = rebuild <$> f (Map.lookup k m)
  where
    rebuild new = Map.alter (const new) k m