module LambdaCube.STLC.Substitution where

import           LambdaCube.STLC.Ast
import           LambdaCube.STLC.Lifter

-- | @substituteValue n v e@ replaces every occurrence of the de Bruijn
-- variable @n@ in @e@ with the term obtained from @v@ via 'liftLCValue'.
--
-- The target index is incremented each time the traversal goes under an
-- 'LCLam', so it keeps referring to the same binder.
--
-- NOTE(review): no index shifting is done. The lifted @v@ is inserted
-- unchanged under binders, and every variable other than the target keeps
-- its index. That is only correct if @v@ lifts to a closed term and @e@ has
-- no free variables above the target index; confirm against callers.
substituteValue :: Int -> LCValue -> LCTerm -> LCTerm
substituteValue n v = go n
  where
    go :: Int -> LCTerm -> LCTerm
    go m e@(LCVar l)
      | m == l    = liftLCValue v
      | otherwise = e
    go m (LCLam t b) = LCLam t (go (m + 1) b)
    go m (LCApp f a) = LCApp (go m f) (go m a)

-- | @substituteNormalInNormal n nv e@ substitutes the normal term @nv@ for
-- the de Bruijn variable @n@ in the normal term @e@.
--
-- Under an 'LCNormLam' the target index is incremented. At a neutral term
-- the work is delegated to 'substituteNormalInNeutral', which returns a
-- normal term.
--
-- NOTE(review): @nv@ is not shifted when it is pushed under binders, and
-- other variables are not renumbered. Presumably callers only substitute
-- closed terms; confirm.
substituteNormalInNormal :: Int -> LCNormalTerm -> LCNormalTerm -> LCNormalTerm
substituteNormalInNormal n nv = go n
  where
    go :: Int -> LCNormalTerm -> LCNormalTerm
    go m (LCNormLam t b)   = LCNormLam t (go (m + 1) b)
    go m (LCNormNeut neut) = substituteNormalInNeutral m nv neut

-- | @substituteNormalInNeutral n nv e@ substitutes the normal term @nv@ for
-- the de Bruijn variable @n@ in the neutral term @e@.
--
-- The result is a normal term rather than a neutral one. If the head of an
-- application becomes a lambda after substitution, the application is a
-- redex. It is contracted immediately by substituting the (already
-- substituted) argument into the lambda body with
-- 'substituteNormalInNormal' at index 0 (hereditary substitution).
--
-- NOTE(review): no index shifting or renumbering is performed. This
-- presumably relies on the substituted terms being closed; confirm.
substituteNormalInNeutral :: Int -> LCNormalTerm -> LCNeutralTerm -> LCNormalTerm
substituteNormalInNeutral n nv = go n
  where
    go :: Int -> LCNeutralTerm -> LCNormalTerm
    go m e@(LCNeutVar l)
      | m == l    = nv
      | otherwise = LCNormNeut e
    go m (LCNeutApp f a) =
      case go m f of
        -- The head became a lambda: contract the new redex right away so
        -- the result stays in normal form.
        LCNormLam _ b   -> substituteNormalInNormal 0 a' b
        -- The head is still neutral: rebuild the application.
        LCNormNeut neut -> LCNormNeut (LCNeutApp neut a')
      where
        -- The argument with the substitution applied; used by both branches.
        a' :: LCNormalTerm
        a' = substituteNormalInNormal m nv a