module Language.Glambda.Shift ( shift, subst ) where
import Language.Glambda.Exp
-- | A value-level witness for the length of a type-level list: 'LZ' stands
-- for the empty list and each 'LS' accounts for one more element at the
-- front of the list.
data Length (xs :: [k]) where
  -- | The empty list has a witness with no 'LS' layers.
  LZ :: Length '[]
  -- | Extend a witness for @rest@ to cover a list with one extra head.
  LS :: Length rest -> Length (item ': rest)
-- | Type-level list append. Right-associative at precedence 5, so
-- @as ++ b ': cs@ groups as @as ++ (b ': cs)@.
infixr 5 ++
type family (front :: [k]) ++ (back :: [k]) :: [k]
-- Appending to the empty list yields the second list unchanged.
type instance '[] ++ back = back
-- Otherwise keep the head and append to the tail.
type instance (item ': items) ++ back = item ': (items ++ back)
-- | Re-type an expression from context @ts2@ into the one-larger context
-- @t ': ts2@, i.e. with an extra entry @t@ placed at the front.
--
-- The constructor structure of the expression is preserved; only the 'Elem'
-- index stored in each 'Var' is rewritten.
shift :: forall ts2 t ty. Exp ts2 ty -> Exp (t ': ts2) ty
shift = weaken LZ
  where
    -- Traverse the expression. @prefix@ witnesses the context entries @ts1@
    -- that sit in front of the insertion point for @t@; it starts empty and
    -- grows by one each time the traversal goes under a 'Lam'.
    weaken :: forall ts1 a.
              Length ts1 -> Exp (ts1 ++ ts2) a -> Exp (ts1 ++ (t ': ts2)) a
    weaken prefix expr = case expr of
      Var ix              -> Var (bump prefix ix)
      IntE n              -> IntE n
      BoolE b             -> BoolE b
      Lam body            -> Lam (weaken (LS prefix) body)
      App fun arg         -> App (weaken prefix fun) (weaken prefix arg)
      Arith lhs op rhs    -> Arith (weaken prefix lhs) op (weaken prefix rhs)
      Cond test yes no    -> Cond (weaken prefix test)
                                  (weaken prefix yes)
                                  (weaken prefix no)
      Fix inner           -> Fix (weaken prefix inner)

    -- Adjust a single index. Indices that fall inside the prefix @ts1@ are
    -- rebuilt unchanged; once the prefix is exhausted the index is wrapped
    -- in one additional 'ES' to step over the newly inserted entry.
    bump :: forall ts1 x.
            Length ts1 -> Elem (ts1 ++ ts2) x -> Elem (ts1 ++ (t ': ts2)) x
    bump prefix ix = case prefix of
      LZ         -> ES ix
      LS shorter -> case ix of
        EZ      -> EZ
        ES rest -> ES (bump shorter rest)
-- | Substitute the first expression for the front context entry of the
-- second: given a replacement of type @s@ in context @ts2@ and a target
-- typed in context @s ': ts2@, produce the target re-typed in @ts2@.
subst :: forall ts2 s t.
         Exp ts2 s -> Exp (s ': ts2) t -> Exp ts2 t
subst replacement = plug LZ
  where
    -- Traverse the target. @prefix@ witnesses the context entries @ts1@ that
    -- sit in front of the entry being substituted; it starts empty and grows
    -- by one each time the traversal goes under a 'Lam'.
    plug :: forall ts1 a.
            Length ts1 -> Exp (ts1 ++ (s ': ts2)) a -> Exp (ts1 ++ ts2) a
    plug prefix expr = case expr of
      Var ix              -> resolve prefix ix
      IntE n              -> IntE n
      BoolE b             -> BoolE b
      Lam body            -> Lam (plug (LS prefix) body)
      App fun arg         -> App (plug prefix fun) (plug prefix arg)
      Arith lhs op rhs    -> Arith (plug prefix lhs) op (plug prefix rhs)
      Cond test yes no    -> Cond (plug prefix test)
                                  (plug prefix yes)
                                  (plug prefix no)
      Fix inner           -> Fix (plug prefix inner)

    -- Decide what a single variable becomes.
    resolve :: forall ts1 a.
               Length ts1 -> Elem (ts1 ++ (s ': ts2)) a -> Exp (ts1 ++ ts2) a
    resolve prefix ix = case prefix of
      -- Empty prefix: the front entry is the one being substituted.
      LZ -> case ix of
        EZ      -> replacement      -- this is the substituted variable
        ES rest -> Var rest         -- any other variable loses one 'ES'
      -- Non-empty prefix: peel one entry off the prefix and the index.
      LS shorter -> case ix of
        EZ      -> Var EZ           -- refers to the head of the prefix; keep
        -- Resolve against the shorter prefix, then 'shift' the result back
        -- into the context that still has the peeled-off entry in front.
        ES rest -> shift (resolve shorter rest)