{-# LANGUAGE CPP                   #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
#if __GLASGOW_HASKELL__ >= 805
{-# LANGUAGE QuantifiedConstraints #-}
#endif
{-# LANGUAGE UndecidableInstances  #-}
-- | 'ScopeT' is a scope whose body is a term of @t f@ rather than of @f@.
-- Every position in the body holds either a bound variable ('B') or a free
-- subterm of type @f a@ ('F').
--
-- The exported combinators carry a @T@ suffix (for example 'abstractT' and
-- 'instantiateT'); 'lowerScopeT' converts to an ordinary 'Scope'.
module Bound.ScopeT (
    ScopeT (..),
    (>>>>=),
    -- * Abstraction
    abstractT, abstract1T, abstractTEither,
    -- ** With names
    abstractTName, abstract1TName,
    -- * Instantiation
    instantiateT, instantiate1T, instantiateTEither,
    -- * Conversion to and from the flattened form
    fromScopeT,
    toScopeT,
    -- * Utilities
    lowerScopeT,
    splatT,
    bindingsT,
    mapBoundT,
    mapScopeT,
    foldMapBoundT,
    foldMapScopeT,
    traverseBoundT_,
    traverseScopeT_,
    traverseBoundT,
    traverseScopeT,
    bitransverseScopeT,
    ) where
import Bound                (Bound (..), Scope (..), Var (..))
import Bound.Name           (Name (..))
import Control.DeepSeq      (NFData (..))
import Control.Monad.Module (Module (..))
import Data.Bifoldable      (bifoldMap, bitraverse_)
import Data.Bifunctor       (bimap)
import Data.Bitraversable   (Bitraversable (..))
import Data.Foldable        (traverse_)
import Data.Functor.Classes
import Data.Hashable        (Hashable (..))
import Data.Hashable.Lifted (Hashable1 (..), hashWithSalt1)
-- | A scope whose body is a term of @t f@. Each position in the body holds
-- either a bound variable @'B' b@ or a free subterm @'F' fa@ with
-- @fa :: f a@.
--
-- Compare the plain 'Scope', whose body is @g ('Var' b (g a))@ (see
-- 'lowerScopeT'): here the outer layer is @t f@ while the free subterms stay
-- in @f@.
newtype ScopeT b t f a = ScopeT { unscopeT :: t f (Var b (f a)) }
-- | Maps over the free variables only. Each one sits inside the @f a@
-- carried by an 'F', so the function is pushed through @t f@, then 'Var',
-- then @f@.
instance (Functor (t f), Functor f) => Functor (ScopeT b t f) where
    fmap g = ScopeT . fmap (fmap (fmap g)) . unscopeT
-- | Folds over the free variables only; bound variables ('B') contribute
-- nothing.
instance (Foldable (t f), Foldable f) => Foldable (ScopeT b t f) where
    foldMap g = foldMap (foldMap (foldMap g)) . unscopeT
    foldr g z = foldr overVar z . unscopeT
      where
        -- fold one 'Var' (only an 'F' has contents), threading the accumulator
        overVar v acc = foldr overF acc v
        -- fold the free subterm carried by an 'F'
        overF fa acc  = foldr g acc fa
-- | Traverses the free variables only, through @t f@, 'Var' and @f@ in turn.
instance (Traversable (t f), Traversable f) => Traversable (ScopeT b t f) where
    traverse g (ScopeT body) = fmap ScopeT (traverse (traverse (traverse g)) body)
-- | Substitute for the free variables of a 'ScopeT', leaving bound variables
-- untouched: every free subterm @'F' fa@ becomes @'F' (fa >>= k)@.
--
-- This is the standalone form of the 'Bound' method for 'ScopeT'; the
-- 'Bound' instance itself is only compiled behind the QuantifiedConstraints
-- CPP guard.
--
-- NOTE(review): no fixity declaration is given for this operator, so it gets
-- the Haskell default of infixl 9; parenthesise mixed expressions.
(>>>>=) :: (Monad f, Functor (t f)) => ScopeT b t f a -> (a -> f c) -> ScopeT b t f c
s >>>>= k = ScopeT (fmap (fmap (k =<<)) (unscopeT s))
{-# INLINE (>>>>=) #-}
#if __GLASGOW_HASKELL__ >= 805
-- | Needs QuantifiedConstraints: the context promises that @t f@ is a
-- 'Functor' for every 'Functor' @f@, which is what '(>>>>=)' requires.
instance (forall f. Functor f => Functor (t f)) => Bound (ScopeT n t) where
    s >>>= k = s >>>>= k
#endif
-- | 'Module' instance over @f@; substitution is '(>>>>=)'.
instance (Monad f, Functor (t f)) => Module (ScopeT b t f) f where
    s >>== k = s >>>>= k
-- | Hashes the 'fromScopeT' form of the scope rather than the raw body.
--
-- NOTE(review): the method body does not itself use @Hashable1 f@; it may be
-- required for a superclass of 'Hashable1' - confirm before removing it.
instance (Hashable b, Bound t, Monad f, Hashable1 f, Hashable1 (t f)) => Hashable1 (ScopeT b t f) where
    liftHashWithSalt h salt scope = liftHashWithSalt hashVar salt (fromScopeT scope)
      where
        -- hash one 'Var' using @h@ for free variables
        hashVar = liftHashWithSalt h
    {-# INLINE liftHashWithSalt #-}
-- | Hashes the 'fromScopeT' form of the scope via 'hashWithSalt1'.
instance (Hashable b, Bound t, Monad f, Hashable1 f, Hashable1 (t f), Hashable a) => Hashable (ScopeT b t f a) where
    hashWithSalt salt scope = hashWithSalt1 salt $ fromScopeT scope
    {-# INLINE hashWithSalt #-}
-- | Delegates 'rnf' to the underlying 'unscopeT' body.
instance NFData (t f (Var b (f a))) => NFData (ScopeT b t f a) where
  rnf (ScopeT body) = rnf body
-- | Delegates to the 'Eq1' instance via 'eq1'.
instance (Monad f, Bound t, Eq b, Eq1 (t f), Eq1 f, Eq a) => Eq (ScopeT b t f a) where
    l == r = eq1 l r

-- | Delegates to the 'Ord1' instance via 'compare1'.
instance (Monad f, Bound t, Ord b, Ord1 (t f), Ord1 f, Ord a) => Ord (ScopeT b t f a) where
    compare l r = compare1 l r

-- | Delegates to the 'Show1' instance via 'showsPrec1'.
instance (Show b, Show1 (t f), Show1 f, Show a) => Show (ScopeT b t f a) where
    showsPrec d x = showsPrec1 d x

-- | Delegates to the 'Read1' instance via 'readsPrec1'.
instance (Read b, Read1 (t f), Read1 f, Read a) => Read (ScopeT b t f a) where
    readsPrec d = readsPrec1 d
-- | Compares the 'fromScopeT' form of both sides rather than the raw
-- 'unscopeT' bodies (presumably so that different splittings of the same
-- free subterm compare equal - confirm).
instance (Monad f, Bound t, Eq b, Eq1 (t f), Eq1 f) => Eq1 (ScopeT b t f) where
  liftEq eq l r = liftEq eqVar (fromScopeT l) (fromScopeT r)
    where
      -- equality on one 'Var', using @eq@ for free variables
      eqVar = liftEq eq
-- | Orders the 'fromScopeT' form of both sides, consistent with the 'Eq1'
-- instance.
instance (Monad f, Bound t, Ord b, Ord1 (t f), Ord1 f) => Ord1 (ScopeT b t f) where
  liftCompare cmp l r = liftCompare cmpVar (fromScopeT l) (fromScopeT r)
    where
      -- comparison on one 'Var', using @cmp@ for free variables
      cmpVar = liftCompare cmp
-- | Shows a scope as @ScopeT body@, where the body is shown through the
-- 'Show1' instances of @t f@, 'Var' and @f@.
instance (Show b, Show1 (t f), Show1 f) => Show1 (ScopeT b t f) where
    liftShowsPrec sp sl d (ScopeT x) = showsUnaryWith showBody "ScopeT" d x
      where
        -- showers for one free subterm @f a@ and for a list of them
        showFA     = liftShowsPrec sp sl
        showFAList = liftShowList sp sl
        -- shower for the whole @t f (Var b (f a))@ body
        showBody   = liftShowsPrec (liftShowsPrec showFA showFAList)
                                   (liftShowList showFA showFAList)
-- | Reads the @ScopeT body@ form produced by the 'Show1' instance.
instance (Read b, Read1 (t f), Read1 f) => Read1 (ScopeT b t f) where
    liftReadsPrec rp rl = readsData (readsUnaryWith readBody "ScopeT" ScopeT)
      where
        -- readers for one free subterm @f a@ and for a list of them
        readFA     = liftReadsPrec rp rl
        readFAList = liftReadList rp rl
        -- reader for the whole @t f (Var b (f a))@ body
        readBody   = liftReadsPrec (liftReadsPrec readFA readFAList)
                                   (liftReadList readFA readFAList)
-- | Capture the variables selected by @f@. A variable @y@ with
-- @f y = Just z@ becomes the bound variable @'B' z@; every other variable
-- stays free as @'F' (return y)@.
abstractT :: (Functor (t f), Monad f) => (a -> Maybe b) -> t f a -> ScopeT b t f a
abstractT f e = ScopeT (fmap classify e)
  where
    classify y = maybe (F (return y)) B (f y)
{-# INLINE abstractT #-}
-- | Abstract over every occurrence of the single variable @a@, binding it
-- as @()@.
abstract1T :: (Functor (t f), Monad f, Eq a) => a -> t f a -> ScopeT () t f a
abstract1T a = abstractT matches
  where
    matches b
        | a == b    = Just ()
        | otherwise = Nothing
{-# INLINE abstract1T #-}
-- | Split every variable with @f@: a @Left z@ becomes the bound variable
-- @'B' z@, and a @Right y@ becomes the free subterm @'F' (return y)@. Unlike
-- 'abstractT', the free variables may change type.
abstractTEither :: (Functor (t f),  Monad f) => (a -> Either b c) -> t f a -> ScopeT b t f c
abstractTEither f e = ScopeT (fmap (either B (F . return) . f) e)
{-# INLINE abstractTEither #-}
-- | Like 'abstractT', but each bound variable remembers the original variable
-- it replaced, as @'B' ('Name' a b)@.
abstractTName :: (Functor (t f), Monad f) => (a -> Maybe b) -> t f a -> ScopeT (Name a b) t f a
abstractTName f t = ScopeT (fmap classify t)
  where
    classify a = maybe (F (return a)) (B . Name a) (f a)
{-# INLINE abstractTName #-}
-- | Named variant of 'abstract1T': occurrences of @a@ are bound as
-- @'Name' a ()@.
abstract1TName :: (Functor (t f), Monad f, Eq a) => a -> t f a -> ScopeT (Name a ()) t f a
abstract1TName a = abstractTName matches
  where
    matches b
        | a == b    = Just ()
        | otherwise = Nothing
{-# INLINE abstract1TName #-}
-- | Enter a scope: replace each bound variable @b@ with @k b@ and splice every
-- free subterm back in unchanged, using the 'Bound' substitution of @t@.
instantiateT :: (Bound t, Monad f) => (b -> f a) -> ScopeT b t f a -> t f a
instantiateT k (ScopeT e) = e >>>= resolve
  where
    resolve (B b)  = k b
    resolve (F fa) = fa
{-# INLINE instantiateT #-}
-- | Enter a scope, replacing every bound variable with the same term @e@.
instantiate1T :: (Bound t, Monad f) => f a -> ScopeT b t f a -> t f a
instantiate1T e = instantiateT (\_ -> e)
{-# INLINE instantiate1T #-}
-- | Enter a scope with a single handler for both kinds of variable. Bound
-- variables are passed as @Left b@; each free variable @a@ inside a free
-- subterm is passed as @Right a@.
instantiateTEither :: (Bound t, Monad f) => (Either b a -> f c) -> ScopeT b t f a -> t f c
instantiateTEither f (ScopeT e) = e >>>= resolve
  where
    resolve (B b)  = f (Left b)
    resolve (F fa) = fa >>= (f . Right)
{-# INLINE instantiateTEither #-}
-- | Flatten a 'ScopeT' into a plain @t f@ term over @'Var' b a@. Each
-- @'F' fa@ becomes @fmap 'F' fa@, each @'B' b@ becomes @return ('B' b)@, and
-- the results are spliced in with the 'Bound' substitution of @t@.
--
-- Converts in the opposite direction to 'toScopeT'. It is also what the
-- 'Eq1', 'Ord1' and 'Hashable1' instances compare and hash, hence the
-- INLINE pragma, matching the other combinators in this module.
fromScopeT :: (Bound t, Monad f) => ScopeT b t f a -> t f (Var b a)
fromScopeT (ScopeT s) = s >>>= \v -> case v of
    F e -> fmap F e
    B b -> return (B b)
{-# INLINE fromScopeT #-}
-- | Build a 'ScopeT' from a term over @'Var' b a@. Bound variables stay 'B',
-- and each free variable @a@ is wrapped as @'F' (return a)@.
--
-- Converts in the opposite direction to 'fromScopeT'.
toScopeT :: (Functor (t f), Monad f) => t f (Var b a) -> ScopeT b t f a
toScopeT e = ScopeT (fmap (fmap return) e)
{-# INLINE toScopeT #-}
-- | Convert to an ordinary 'Scope' over a target functor @g@.
--
-- @f@ is mapped over every free subterm (inside each 'F'), and @tf@ converts
-- the outer @t f@ body. Bound variables are kept as they are.
lowerScopeT
    :: (Functor (t f), Functor f)
    => (forall x. t f x -> g x)   -- ^ converts the outer @t f@ layer into @g@
    -> (forall x. f x -> g x)     -- ^ converts each free @f@ subterm into @g@
    -> ScopeT b t f a -> Scope b g a
lowerScopeT tf f (ScopeT x) = Scope (tf (fmap (fmap f) x))
{-# INLINE lowerScopeT #-}
-- | Eliminate a scope completely. Free variables are substituted with @f@
-- (inside their free subterm) and bound variables with @unbind@.
splatT :: (Bound t, Monad f) => (a -> f c) -> (b -> f c) -> ScopeT b t f a -> t f c
splatT f unbind (ScopeT e) = e >>>= resolve
  where
    resolve (B b)  = unbind b
    resolve (F fa) = fa >>= f
{-# INLINE splatT #-}
-- | List the bound variables in the body, in 'foldr' order and with
-- duplicates kept.
bindingsT :: Foldable (t f) => ScopeT b t f a -> [b]
bindingsT (ScopeT s) = foldr collect [] s
  where
    collect (B v) acc = v : acc
    collect (F _) acc = acc
{-# INLINE bindingsT #-}
-- | Rename the bound variables with @f@; free subterms are untouched.
mapBoundT :: Functor (t f) => (b -> b') -> ScopeT b t f a -> ScopeT b' t f a
mapBoundT f (ScopeT s) = ScopeT (fmap rename s)
  where
    rename v = case v of
        B b -> B (f b)
        F a -> F a
{-# INLINE mapBoundT #-}
-- | Map bound variables with @f@ and free variables with @g@ in one pass.
mapScopeT
    :: (Functor (t f), Functor f)
    => (b -> d) -> (a -> c)
    -> ScopeT b t f a -> ScopeT d t f c
mapScopeT f g s = ScopeT (fmap (bimap f (fmap g)) (unscopeT s))
{-# INLINE mapScopeT #-}
-- | Fold over the bound variables only; free subterms contribute 'mempty'.
foldMapBoundT :: (Foldable (t f), Monoid r) => (b -> r) -> ScopeT b t f a -> r
foldMapBoundT f (ScopeT s) = foldMap visit s
  where
    visit v = case v of
        B b -> f b
        F _ -> mempty
{-# INLINE foldMapBoundT #-}
-- | Fold bound variables with @f@ and free variables with @g@.
foldMapScopeT
    :: (Foldable f, Foldable (t f), Monoid r)
    => (b -> r) -> (a -> r)
    -> ScopeT b t f a -> r
foldMapScopeT f g s = foldMap step (unscopeT s)
  where
    -- one 'Var': @f@ on a bound name, @g@ over the free subterm
    step = bifoldMap f (foldMap g)
{-# INLINE foldMapScopeT #-}
-- | Run @f@ on every bound variable for its effects, discarding results.
traverseBoundT_ :: (Applicative g, Foldable (t f)) => (b -> g d) -> ScopeT b t f a -> g ()
traverseBoundT_ f (ScopeT s) = traverse_ visit s
  where
    visit v = case v of
        B b -> () <$ f b
        F _ -> pure ()
{-# INLINE traverseBoundT_ #-}
-- | Run @f@ on bound variables and @g@ on free variables for their effects,
-- discarding results.
traverseScopeT_
    :: (Applicative g, Foldable f, Foldable (t f))
    => (b -> g d) -> (a -> g c)
    -> ScopeT b t f a -> g ()
traverseScopeT_ f g s = traverse_ step (unscopeT s)
  where
    -- one 'Var': @f@ on a bound name, @g@ over the free subterm
    step = bitraverse_ f (traverse_ g)
{-# INLINE traverseScopeT_ #-}
-- | Effectfully rename the bound variables; free subterms are kept as they are.
traverseBoundT
    :: (Applicative g, Traversable (t f))
    => (b -> g c) -> ScopeT b t f a -> g (ScopeT c t f a)
traverseBoundT f (ScopeT s) = fmap ScopeT (traverse visit s)
  where
    visit v = case v of
        B b -> fmap B (f b)
        F a -> pure (F a)
{-# INLINE traverseBoundT #-}
-- | Effectfully map bound variables with @f@ and free variables with @g@.
traverseScopeT
    :: (Applicative g, Traversable f, Traversable (t f))
    => (b -> g d) -> (a -> g c)
    -> ScopeT b t f a -> g (ScopeT d t f c)
traverseScopeT f g s = fmap ScopeT (traverse step (unscopeT s))
  where
    -- one 'Var': @f@ on a bound name, @g@ over the free subterm
    step = bitraverse f (traverse g)
{-# INLINE traverseScopeT #-}
-- | Traverse a 'ScopeT' while changing both of its layers: the outer @t s@
-- body becomes @t' s'@ and every free @s a@ subterm becomes @s' a'@. Bound
-- variables are carried over unchanged.
bitransverseScopeT
    :: Applicative f
    => (forall x x'. (x -> f x') -> t s x -> f (t' s' x'))  -- ^ traversal of the outer layer
    -> (forall x x'. (x -> f x') -> s x -> f (s' x'))       -- ^ traversal of each free subterm
    -> (a -> f a')                                          -- ^ action on free variables
    -> ScopeT b t s a
    -> f (ScopeT b t' s' a')
bitransverseScopeT tauT tauS f = fmap ScopeT . tauT visitVar . unscopeT
  where
    -- one 'Var': traverse the free subterm with @tauS f@, keep bound names
    visitVar = traverse (tauS f)
{-# INLINE bitransverseScopeT #-}