module Language.Fixpoint.Visitor (
  
     Visitor (..)
  
  , defaultVisitor
  
  , trans
  
  , fold
  
  , kvars
  , envKVars
  , mapKVars, mapKVars'
  
  , foldSort, mapSort
  ) where
#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative       (Applicative, (<$>), (<*>))
import           Data.Monoid
import           Data.Traversable          (Traversable, traverse)
#endif
import           Control.Monad.Trans.State (State, modify, runState)
import           Language.Fixpoint.Types
import qualified Data.HashSet as S
import qualified Data.List    as L
-- | A set of callbacks that drives a traversal of 'Expr' / 'Pred' trees
--   (performed by 'visitExpr' and 'visitPred' in this module).
--
--   At every node the traversal
--
--   * derives a new context from the incoming context and the node
--     ('ctxExpr' / 'ctxPred');
--   * adds to the accumulator the value computed for the ORIGINAL
--     (untransformed) node under that new context ('accExpr' / 'accPred');
--   * rewrites the node under the new context ('txExpr' / 'txPred'); and
--   * visits the children of the rewritten node with the new context.
data Visitor acc ctx = Visitor {
 
    ctxExpr :: ctx -> Expr -> ctx
    -- ^ context update on entering an 'Expr' node
  , ctxPred :: ctx -> Pred -> ctx
    -- ^ context update on entering a 'Pred' node
  
  , txExpr  :: ctx -> Expr -> Expr
    -- ^ rewrite of an 'Expr' node, given the already-updated context
  , txPred  :: ctx -> Pred -> Pred
    -- ^ rewrite of a 'Pred' node, given the already-updated context
  
  , accExpr :: ctx -> Expr -> acc
    -- ^ accumulator contribution of an 'Expr' node; computed on the node
    --   as it was before 'txExpr' was applied
  , accPred :: ctx -> Pred -> acc
    -- ^ accumulator contribution of a 'Pred' node; computed on the node
    --   as it was before 'txPred' was applied
  }
-- | The do-nothing visitor. It never changes the context, rewrites every node
--   to itself, and contributes 'mempty' at every node. Start from this and
--   override only the fields you need with record-update syntax.
defaultVisitor :: Monoid acc => Visitor acc ctx
defaultVisitor = Visitor
  { ctxExpr = keepCtx
  , ctxPred = keepCtx
  , txExpr  = keepNode
  , txPred  = keepNode
  , accExpr = noAcc
  , accPred = noAcc
  }
  where
    keepCtx c _  = c
    keepNode _ n = n
    noAcc _ _    = mempty
-- | Run a visitor over @t@ and keep only the accumulated result. The
--   accumulator starts at the given value. Any rewrites produced by
--   'txExpr' / 'txPred' are discarded.
fold :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> a
fold v c a t = acc
  where
    (_, acc) = execVisitM v c a visit t
-- | Run a visitor over @t@ and keep only the rewritten value.
--
--   The accumulator argument is ignored. The traversal always starts from
--   'mempty', and the final accumulator is thrown away.
trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t
trans v c _ z = rewritten
  where
    (rewritten, _) = execVisitM v c mempty visit z
-- | Run the traversal @f v c x@ as a 'State' computation whose initial state
--   is the accumulator @a@. Returns the traversal result paired with the
--   final accumulator.
execVisitM
  :: Visitor a ctx
  -> ctx
  -> a
  -> (Visitor a ctx -> ctx -> t -> State a t)
  -> t
  -> (t, a)
execVisitM visitor ctx acc0 run = (`runState` acc0) . run visitor ctx
-- | The monad a traversal runs in: a 'State' whose state is the accumulator.
type VisitM acc = State acc
-- | Combine a value into the traversal's accumulator.
--
--   The new value goes on the LEFT of the current state (@new <> old@).
--   With a list accumulator, later contributions therefore come first.
accum :: (Monoid a) => a -> VisitM a ()
accum new = modify (\old -> new `mappend` old)
-- | Infix synonym for 'traverse'. It maps an effectful function over a
--   traversable structure and sequences the effects left to right.
(<$$>) ::  (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
(<$$>) = traverse
-- | Types whose 'Expr' / 'Pred' sub-terms can be traversed by a 'Visitor'.
class Visitable t where
  -- | Traverse a value with the given visitor and starting context,
  --   threading the accumulator through 'VisitM'.
  visit :: Monoid a
        => Visitor a ctx
        -> ctx
        -> t
        -> VisitM a t
-- | Expressions are traversed by 'visitExpr'.
instance Visitable Expr where
  visit v c e = visitExpr v c e

-- | Predicates are traversed by 'visitPred'.
instance Visitable Pred where
  visit v c p = visitPred v c p
-- | A 'Refa' is traversed through the 'Pred' it wraps.
instance Visitable Refa where
  visit v c (Refa p) = fmap Refa (visit v c p)

-- | A 'Reft' keeps its binder unchanged and traverses its 'Refa'.
instance Visitable Reft where
  visit v c (Reft (x, ra)) = do
    ra' <- visit v c ra
    return (Reft (x, ra'))
-- | Visit each element of a list from left to right under the same context
--   and collect the rewritten elements.
visitMany :: (Monoid a, Visitable t) => Visitor a ctx -> ctx -> [t] -> VisitM a [t]
visitMany v c = traverse (visit v c)
-- | Traverse an 'Expr' with a visitor.
--
--   At each node the traversal
--
--   * computes the new context with 'ctxExpr' (from the incoming context
--     and the original node);
--   * adds 'accExpr' of the ORIGINAL node to the accumulator;
--   * rewrites the node with 'txExpr'; and
--   * visits the children of the rewritten node, left to right, under the
--     new context.
--
--   The predicate inside 'EIte' is handed to 'visitPred'. Leaf nodes
--   ('EBot', 'ESym', 'ECon', 'ELit', 'EVar') have no children to visit.
visitExpr :: (Monoid a) => Visitor a ctx -> ctx -> Expr -> VisitM a Expr
visitExpr v = goE
  where
    goP = visitPred v
    goE ctx expr = do
      let ctx' = ctxExpr v ctx expr
      accum (accExpr v ctx' expr)
      descend ctx' (txExpr v ctx' expr)
    descend ctx node = case node of
      EBot         -> return node
      ESym _       -> return node
      ECon _       -> return node
      ELit _ _     -> return node
      EVar _       -> return node
      EApp f es    -> EApp f     <$> traverse (goE ctx) es
      ENeg e       -> ENeg       <$> goE ctx e
      EBin o e1 e2 -> EBin o     <$> goE ctx e1 <*> goE ctx e2
      EIte p e1 e2 -> EIte       <$> goP ctx p  <*> goE ctx e1 <*> goE ctx e2
      ECst e t     -> (`ECst` t) <$> goE ctx e
-- | Traverse a 'Pred' with a visitor.
--
--   At each node the traversal
--
--   * computes the new context with 'ctxPred' (from the incoming context
--     and the original node);
--   * adds 'accPred' of the ORIGINAL node to the accumulator;
--   * rewrites the node with 'txPred'; and
--   * visits the children of the rewritten node, left to right, under the
--     new context.
--
--   Expressions inside 'PBexp' and 'PAtom' are handed to 'visitExpr'.
--   'PKVar' nodes are returned as-is, so their substitutions are not
--   traversed. 'PAll' does not extend the context with its binders; a
--   visitor that needs that must do it in 'ctxPred'.
visitPred :: (Monoid a) => Visitor a ctx -> ctx -> Pred -> VisitM a Pred
visitPred v = goP
  where
    goE = visitExpr v
    goP ctx p0 = do
      let ctx' = ctxPred v ctx p0
      accum (accPred v ctx' p0)
      descend ctx' (txPred v ctx' p0)
    descend ctx node = case node of
      PTrue         -> return node
      PFalse        -> return node
      PTop          -> return node
      PKVar _ _     -> return node
      PAnd ps       -> PAnd     <$> traverse (goP ctx) ps
      POr ps        -> POr      <$> traverse (goP ctx) ps
      PNot p        -> PNot     <$> goP ctx p
      PImp p1 p2    -> PImp     <$> goP ctx p1 <*> goP ctx p2
      PIff p1 p2    -> PIff     <$> goP ctx p1 <*> goP ctx p2
      PBexp e       -> PBexp    <$> goE ctx e
      PAtom r e1 e2 -> PAtom r  <$> goE ctx e1 <*> goE ctx e2
      PAll xts p    -> PAll xts <$> goP ctx p
-- | Replace kvar occurrences using a function of the kvar alone.
--
--   Each occurrence for which @f@ returns @Just p@ is replaced by @p@ (with
--   the occurrence's substitution applied); the others are left unchanged.
--   Use 'mapKVars'' when the replacement must depend on the substitution.
mapKVars :: Visitable t => (KVar -> Maybe Pred) -> t -> t
mapKVars f = mapKVars' (f . fst)
-- | Replace kvar occurrences using a function of the kvar and its
--   substitution.
--
--   Each @PKVar k su@ for which @f (k, su)@ is @Just p@ becomes @subst su p@;
--   every other predicate is left unchanged. The replacement is performed
--   by 'txPred', so the traversal then continues into the replacement.
mapKVars' :: Visitable t => ((KVar, Subst) -> Maybe Pred) -> t -> t
mapKVars' f = trans kvVis () []
  where
    kvVis = defaultVisitor { txPred = const replace }
    replace p@(PKVar k su) = maybe p (subst su) (f (k, su))
    replace p              = p
-- | Every kvar occurrence (from 'PKVar' nodes) found by the visitor
--   traversal of @t@.
--
--   Duplicates are kept. Because the accumulator prepends each new
--   contribution, the list is in reverse of visiting order.
kvars :: Visitable t => t -> [KVar]
kvars = fold collector () []
  where
    collector = defaultVisitor { accPred = \_ p -> kvarOf p }
    kvarOf (PKVar k _) = [k]
    kvarOf _           = []
-- | The distinct kvars that occur in the 'sr_reft' of the bindings returned
--   by @envCs be (senv c)@.
--
--   Duplicates are removed through a 'S.HashSet', so the result order is
--   unspecified.
envKVars :: BindEnv -> SubC a -> [KVar]
envKVars be c = S.toList (S.fromList allKVars)
  where
    allKVars = concatMap (kvars . sr_reft . snd) (envCs be (senv c))
-- | Strict, pre-order left fold over a 'Sort'.
--
--   @f@ is applied to a sort before its component sorts (the argument lists
--   of 'FFunc' and 'FApp'). The components are folded left to right; other
--   sorts have no components.
foldSort :: (a -> Sort -> a) -> a -> Sort -> a
foldSort f = visitS
  where
    visitS acc srt = L.foldl' visitS (f acc srt) (subSorts srt)
    subSorts (FFunc _ ss) = ss
    subSorts (FApp _ ss)  = ss
    subSorts _            = []
-- | Top-down rewrite of a 'Sort'.
--
--   @f@ is applied to a sort first, and the rewrite then recurses into the
--   component sorts of the RESULT (the argument lists of 'FFunc' and
--   'FApp').
mapSort :: (Sort -> Sort) -> Sort -> Sort
mapSort f = visitS
  where
    visitS srt = case f srt of
      FFunc n ss -> FFunc n (map visitS ss)
      FApp c ss  -> FApp c (map visitS ss)
      other      -> other