module Data.Comp.Multi.Variables
    (
     HasVars(..),
     GSubst,
     CxtSubst,
     Subst,
     varsToHoles,
     containsVar,
     variables,
     variableList,
     variables',
     appSubst,
     compSubst,
     getBoundVars,
    (&),
    (|->),
    empty
    ) where
import Data.Comp.Multi.Algebra
import Data.Comp.Multi.Derive
import Data.Comp.Multi.HFoldable
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.Mapping
import Data.Comp.Multi.Ops
import Data.Comp.Multi.Term
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
-- | A substitution in map form: a finite map from variables of type @v@ to
-- replacements of the index-polymorphic type @a@ (wrapped in 'A', so one
-- replacement can be used at any index).
type GSubst v a = Map v (A a)
-- | A substitution whose replacements are 'Cxt' values over signature @f@
-- with holes of type @a@.
type CxtSubst h a f v =  GSubst v (Cxt h f a)
-- | A substitution whose replacements are 'Cxt' values with the 'NoHole'
-- flag and the dummy hole type @K ()@.
type Subst f v = CxtSubst NoHole (K ()) f v
-- | A substitution in functional form: given a variable (wrapped in 'K'),
-- yields its replacement at the requested index, or 'Nothing' when the
-- variable is not substituted.
type SubstFun v a = NatM Maybe (K v) a
-- | Turns a substitution map into a lookup function: a variable is sent to
-- its replacement (instantiated at the requested index), or to 'Nothing'
-- when the map has no entry for it.
substFun :: Ord v => GSubst v a -> SubstFun v a
substFun s (K v) = case Map.lookup v s of
    Just (A t) -> Just t
    Nothing    -> Nothing
-- | Signatures whose nodes may be variables of type @v@ and/or may bind
-- variables of type @v@ in some of their subterms.
class HasVars (f  :: (* -> *) -> * -> *) v where

    -- | @Just x@ if the given node is the variable @x@, 'Nothing' otherwise.
    -- By default no node is a variable.
    isVar :: f a :=> Maybe v
    isVar = const Nothing

    -- | A mapping from the subterms of the given node to the set of
    -- variables the node binds in that subterm. The helpers in this module
    -- call it on a node whose subterms were numbered with 'number', and
    -- look each subterm up with 'lookupNumMap', treating a missing entry as
    -- the empty set. By default nothing is bound.
    bindsVars :: Mapping m a => f a :=> m (Set v)
    bindsVars = const empty
-- Template Haskell: derive 'HasVars' using the 'liftSum' deriver
-- (presumably lifting instances through signature sums; see
-- "Data.Comp.Multi.Derive" to confirm).
$(derive [liftSum] [''HasVars])
-- | Like 'isVar', but a variable that is a member of the given set is
-- reported as 'Nothing'. Callers in this module pass the set of variables
-- bound by enclosing binders, so only unbound variables are returned.
isVar' :: (HasVars f v, Ord v) => Set v -> f a :=> Maybe v
isVar' bound node = case isVar node of
    Just x | not (x `Set.member` bound) -> Just x
    _                                   -> Nothing
-- | Pairs each immediate subterm of a node with the set of variables the
-- node binds in that subterm, as reported by 'bindsVars'. Subterms without
-- an entry are paired with the empty set.
getBoundVars :: forall f a v i . (HasVars f v, HTraversable f) => f a i -> f (a :*: K (Set v)) i
getBoundVars t = hfmap annotate numbered
    where numbered :: f (Numbered a) i
          numbered = number t
          boundMap = bindsVars numbered
          annotate :: Numbered a :-> (a :*: K (Set v))
          annotate (Numbered k sub) = sub :*: K (lookupNumMap Set.empty k boundMap)
-- | Like 'hfmap', but the mapped function also receives the set of
-- variables that the node binds in the subterm it is applied to (empty if
-- 'bindsVars' has no entry for that subterm).
hfmapBoundVars :: forall f a b v i . (HasVars f v, HTraversable f)
                  => (Set v -> a :-> b) -> f a i -> f b i
hfmapBoundVars fun t = hfmap apply numbered
    where numbered :: f (Numbered a) i
          numbered = number t
          boundMap = bindsVars numbered
          apply :: Numbered a :-> b
          apply (Numbered k sub) = fun (lookupNumMap Set.empty k boundMap) sub
-- | Like 'hfoldl', but the step function also receives the set of variables
-- that the node binds in the subterm being folded over (empty if
-- 'bindsVars' has no entry for that subterm).
hfoldlBoundVars :: forall f a b v i . (HasVars f v, HTraversable f)
                  => (b -> Set v ->  a :=> b) -> b -> f a i -> b
hfoldlBoundVars fun z t = hfoldl step z numbered
    where numbered :: f (Numbered a) i
          numbered = number t
          boundMap = bindsVars numbered
          step :: b -> Numbered a :=> b
          step acc (Numbered k sub) = fun acc (lookupNumMap Set.empty k boundMap) sub
-- | A function from an environment of type @a@ to a value of the indexed
-- type @b@. 'varsToHoles' uses it as the algebra carrier so that the set of
-- bound variables can be passed down while folding.
newtype C a b i = C{ unC :: a -> b i }
-- | Converts a term into a context in which every variable node that is not
-- bound by an enclosing binder (according to 'bindsVars') is replaced by a
-- hole carrying that variable. Bound variable nodes are kept as terms.
varsToHoles :: forall f v. (HTraversable f, HasVars f v, Ord v) =>
                Term f :-> Context f (K v)
varsToHoles term = cata alg term `unC` Set.empty
    where alg :: (HTraversable f, HasVars f v, Ord v) => Alg f (C (Set v) (Context f (K v)))
          alg node = C $ \bound ->
            let -- Enter a subterm, adding the variables this node binds there.
                descend :: Set v -> C (Set v) (Context f (K v)) :-> Context f (K v)
                descend newVars (C k) = k (newVars `Set.union` bound)
            in case isVar' bound node of
                 Just x  -> Hole (K x)
                 Nothing -> Term (hfmapBoundVars descend node)
-- | Algebra deciding whether the variable @v@ occurs free. A node counts if
-- it is @v@ itself, or if some subterm contains a free @v@ and the node
-- does not bind @v@ in that subterm.
containsVarAlg :: forall v f . (Ord v, HasVars f v, HTraversable f) => v -> Alg f (K Bool)
containsVarAlg v t = K $ hfoldlBoundVars run local t
    where local = maybe False (v ==) (isVar t)
          -- The index is quantified after the arrows ('K Bool :=> Bool'),
          -- exactly as 'hfoldlBoundVars' expects. A signature of the form
          -- 'Bool -> Set v -> K Bool i -> Bool' only typechecks under deep
          -- subsumption, which GHC >= 9.0 does not perform by default.
          run :: Bool -> Set v -> K Bool :=> Bool
          run acc vars (K b) = acc || (not (v `Set.member` vars) && b)
-- | Checks whether the given variable occurs free in the argument, i.e. at a
-- variable node not shadowed by an enclosing binder. Holes never count as
-- occurrences.
containsVar :: (Ord v, HasVars f v, HTraversable f, HFunctor f)
            => v -> Cxt h f a :=> Bool
containsVar v t = unK (free (containsVarAlg v) (\_ -> K False) t)
-- | The free variables of the argument as a duplicate-free list in ascending
-- order (see 'variables').
variableList :: (HasVars f v, HTraversable f, HFunctor f, Ord v)
             => Cxt h f a :=> [v]
variableList t = Set.toList (variables t)
-- | Algebra collecting free variables. It returns the node's own variable
-- (if it is one), together with the free variables of each subterm, minus
-- the variables the node binds in that subterm.
--
-- The explicit @forall v f@ keeps the implicit quantification order (v, f)
-- and scopes @v@ for the local signature below.
variablesAlg :: forall v f . (Ord v, HasVars f v, HTraversable f) => Alg f (K (Set v))
variablesAlg t = K $ hfoldlBoundVars run local t
    where local = maybe Set.empty Set.singleton (isVar t)
          -- Without a signature, 'run' would be generalised to
          -- @forall i. ... -> K (Set v) i -> ...@. Passing that where
          -- @... -> K (Set v) :=> Set v@ is expected needs deep subsumption,
          -- which GHC >= 9.0 rejects by default.
          run :: Set v -> Set v -> K (Set v) :=> Set v
          run acc bvars (K vars) = acc `Set.union` (vars `Set.difference` bvars)
-- | The set of free variables of the argument, i.e. the variable nodes not
-- shadowed by an enclosing binder. Holes contribute no variables.
variables :: (Ord v, HasVars f v, HTraversable f, HFunctor f)
            => Cxt h f a :=> Set v
variables t = unK (free variablesAlg (\_ -> K Set.empty) t)
-- | The variables of a single 'Const' node: a singleton if the node is a
-- variable according to 'isVar', otherwise empty. Binders are not
-- consulted.
variables' :: (Ord v, HasVars f v, HFoldable f, HFunctor f)
            => Const f :=> Set v
variables' c = maybe Set.empty Set.singleton (isVar c)
-- | Indexed types @a@ to which a substitution can be applied, where the
-- substitution maps variables of type @v@ to replacements of type @t@.
class SubstVars v t a where
    -- | Applies the given substitution function throughout the argument.
    -- Each instance decides which occurrences are substituted.
    substVars :: SubstFun v t -> a :-> a
-- | Applies a substitution given in map form by converting it to a lookup
-- function ('substFun') and delegating to 'substVars'.
appSubst :: (Ord v, SubstVars v t a) => GSubst v t -> a :-> a
appSubst s = substVars (substFun s)
instance (Ord v, HasVars f v, HTraversable f) => SubstVars v (Cxt h f a) (Cxt h f a) where
    -- Walk the context, carrying the variables bound by enclosing binders.
    -- Only variable nodes outside that set are looked up in the
    -- substitution. A hit replaces the whole node; holes are left as is.
    substVars subst = walk Set.empty
      where walk :: Set v -> Cxt h f a :-> Cxt h f a
            walk _     (Hole x)    = Hole x
            walk bound (Term node) =
                maybe (Term $ hfmapBoundVars enter node) id (isVar' bound node >>= subst . K)
              where enter :: Set v -> Cxt h f a :-> Cxt h f a
                    enter newlyBound = walk (bound `Set.union` newlyBound)
-- Lift substitution through any higher-order functor by applying it to
-- every argument position.
instance (SubstVars v t a, HFunctor f) => SubstVars v t (f a) where
    substVars s x = hfmap (substVars s) x
-- | Composes two substitutions. Applying @compSubst s1 s2@ is meant to
-- match applying @s2@ first and then @s1@ (ignoring interactions with
-- binders). It has two parts:
--
-- * every replacement in @s2@ gets @s1@ applied to it, and
-- * variables mapped only by @s1@ keep their @s1@ replacement.
--
-- When both substitutions map a variable, the rewritten @s2@ entry wins,
-- because 'Map.union' is left-biased. Dropping the union would leave
-- variables mapped only by @s1@ unsubstituted.
compSubst :: (Ord v, HasVars f v, HTraversable f)
          => CxtSubst h a f v -> CxtSubst h a f v -> CxtSubst h a f v
compSubst s1 s2 = Map.map (\ (A t) -> A (appSubst s1 t)) s2 `Map.union` s1