module Zabt.Internal where
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map as Map
class Eq v => Freshen v where
freshen :: v -> v
instance Freshen Int where
freshen n = n + 1
freshPred :: Freshen v => (v -> Bool) -> (v -> v)
freshPred used = go where
go v = if used v then go (freshen v) else v
freshSet :: (Ord v, Freshen v) => Set v -> (v -> v)
freshSet used = freshPred (flip Set.member used)
data Nameless v f x
= F v
| B Int
| Abst v x
| Branch (f x)
deriving (Show, Functor)
instance (Eq v, Eq x, Eq (f x)) => Eq (Nameless v f x) where
F va == F vb = va == vb
B ixa == B ixb = ixa == ixb
Abst _ ta == Abst _ tb = ta == tb
Branch fa == Branch fb = fa == fb
instance (Ord v, Ord x, Ord (f x)) => Ord (Nameless v f x) where
F va `compare` F vb = va `compare` vb
B ixa `compare` B ixb = ixa `compare` ixb
Abst _ ta `compare` Abst _ tb = ta `compare` tb
Branch fa `compare` Branch fb = fa `compare` fb
data Term v f
= Term
{ free :: Set v
, open :: Nameless v f (Term v f)
}
close :: (Ord v, Foldable f) => Nameless v f (Term v f) -> Term v f
close nls = case nls of
F v -> Term { free = Set.singleton v, open = nls }
B i -> Term { free = Set.empty, open = nls }
Abst v nls' -> nls' { open = Abst v nls' }
Branch f -> Term { free = foldMap free f, open = Branch f }
instance (Show v, Show (Nameless v f (Term v f))) => Show (Term v f) where
showsPrec p t = showsPrec p (open t)
deriving instance (Eq v, Eq (f (Term v f))) => Eq (Term v f)
deriving instance (Ord v, Ord (f (Term v f))) => Ord (Term v f)
data View v f x
= VVar v
| VLam v x
| VPat (f x)
deriving (Functor, Show, Eq)
pattern Var a <- (unfold -> VVar a) where
Var a = fold (VVar a)
pattern Lam v t <- (unfold -> VLam v t) where
Lam v t = fold (VLam v t)
pattern Pat f <- (unfold -> VPat f) where
Pat f = fold (VPat f)
fold :: (Functor f, Foldable f, Ord v) => View v f (Term v f) -> Term v f
fold v = case v of
VVar v -> _var v
VLam v t -> let tm = _abstr v t in tm { open = Abst v tm }
VPat f -> Term { free = foldMap free f, open = Branch f }
unfold :: (Functor f, Foldable f, Ord v, Freshen v) => Term v f -> View v f (Term v f)
unfold t@(Term { free = fvs, open = nls }) =
case nls of
F v -> VVar v
B _idx -> error "naked bound variable, invariant broken!"
Abst v t' ->
let v' = freshSet fvs v in
VLam v' (_subst (1, _var v') t')
_var :: v -> Term v f
_var v = Term { free = Set.singleton v, open = F v }
_abstr :: (Functor f, Ord v) => v -> Term v f -> Term v f
_abstr name = go 1 where
go idx t@(Term { free = fvs, open = nls })
| Set.member name fvs =
Term { free = Set.delete name fvs
, open = case nls of
F v | v == name -> B idx
| otherwise -> F v
B i -> nls
Abst v t' -> open (go (idx + 1) t')
Branch f -> Branch (fmap (go idx) f)
}
| otherwise = t
_subst :: (Foldable f, Functor f, Ord v) => (Int, Term v f) -> (Term v f -> Term v f)
_subst (idx, value) = go idx where
go idx t@(Term { free = fvs, open = nls }) =
case nls of
F v -> t
B idx' | idx == idx' -> value
| otherwise -> t
Abst v t' ->
let res = go (idx + 1) t' in
res { open = Abst v res }
Branch f ->
let res = fmap (go idx) f in
Term { free = foldMap free res, open = Branch res }
freeVars :: Term v f -> Set v
freeVars = free
subst :: (Foldable f, Functor f, Ord v) => Map v (Term v f) -> (Term v f -> Term v f)
subst ss = go where
loose = Set.fromList (Map.keys ss)
go t
| Set.null (Set.intersection loose (free t)) = t
| otherwise =
case open t of
F v -> case Map.lookup v ss of
Nothing -> t
Just value -> value
B _ -> t
Abst v t' ->
let res = go t' in
res { open = Abst v res }
Branch f ->
let res = fmap go f in
Term { free = foldMap free res, open = Branch res }
subst1 (v, value) = subst (Map.singleton v value)