module LambdaCube.SystemFw_.Substitution where

import           LambdaCube.SystemFw_.Ast
import           LambdaCube.SystemFw_.Lifter

-- | @substituteTypeInType n v t@ replaces every occurrence of the type
-- variable with de Bruijn index @n@ in the type @t@ by the type @v@.
--
-- Each time the traversal enters a type-level lambda ('LCTTLam') the target
-- index is incremented, so the replacement keeps referring to the same
-- binder. Base types and all other type variables are returned unchanged;
-- the kind annotation of a type-level lambda is not traversed.
--
-- NOTE(review): @v@ is inserted under type-level binders without being
-- shifted, and type variables whose index is above the substituted one are
-- not decremented. That is only sound when @v@ has no free type variables;
-- presumably callers only substitute closed types — confirm before using it
-- to normalise under type-level lambdas.
substituteTypeInType :: Int -> LCType -> LCType -> LCType
substituteTypeInType n v = go n
  where
    go :: Int -> LCType -> LCType
    go _ LCBase = LCBase
    go m e@(LCTVar l)
      | m == l    = v
      | otherwise = e
    go m (LCArr a b) = go m a `LCArr` go m b
    -- Entering a binder: the variable we are replacing is now one further out.
    go m (LCTTLam k b) = LCTTLam k $ go (m + 1) b
    go m (LCTTApp f a) = go m f `LCTTApp` go m a

-- | @substituteValue n v e@ replaces every occurrence of the term variable
-- with de Bruijn index @n@ in the term @e@ by the value @v@, converted to a
-- term with 'liftLCValue'.
--
-- The target index is incremented under each term-level lambda ('LCLam');
-- the type annotation carried by a lambda is copied as is and not traversed.
--
-- NOTE(review): the lifted value is not shifted when placed under binders,
-- and variables with an index above @n@ are not decremented, so this
-- assumes the lifted value has no free term variables (presumably values
-- are closed here — confirm against callers).
substituteValue :: Int -> LCValue -> LCTerm -> LCTerm
substituteValue n v = go n
  where
    -- Lift the value once and share it between all replaced occurrences.
    lifted :: LCTerm
    lifted = liftLCValue v

    go :: Int -> LCTerm -> LCTerm
    go m e@(LCVar l)
      | m == l    = lifted
      | otherwise = e
    -- Entering a binder: the variable we are replacing is now one further out.
    go m (LCLam t b) = LCLam t (go (m + 1) b)
    go m (LCApp f a) = LCApp (go m f) (go m a)

-- | @substituteNormalInNormal n v e@ substitutes the normal term @v@ for the
-- variable with de Bruijn index @n@ in the normal term @e@, producing a
-- normal term.
--
-- Lambdas are traversed with the target index incremented (their type
-- annotation is copied unchanged). Neutral terms are handed to
-- 'substituteNormalInNeutral', which also reduces any redex the
-- substitution creates, so the result stays normal.
--
-- NOTE(review): @v@ is passed under 'LCNormLam' binders without being
-- shifted, which is only sound if @v@ has no free variables — confirm the
-- callers' assumptions.
substituteNormalInNormal :: Int -> LCNormalTerm -> LCNormalTerm -> LCNormalTerm
substituteNormalInNormal n v = go n
  where
    go :: Int -> LCNormalTerm -> LCNormalTerm
    -- Entering a binder: the variable we are replacing is now one further out.
    go m (LCNormLam t b) = LCNormLam t $ go (m + 1) b
    go m (LCNormNeut nt) = substituteNormalInNeutral m v nt

-- | @substituteNormalInNeutral n v e@ substitutes the normal term @v@ for the
-- variable with de Bruijn index @n@ in the neutral term @e@, producing a
-- normal term.
--
-- Replacing the head variable of a neutral application by @v@ can put a
-- lambda in function position. That new redex is contracted immediately by
-- substituting the (already substituted) argument into the lambda body at
-- index 0, so the result is again a normal term (hereditary substitution).
-- If the head stays neutral, the application is simply rebuilt.
--
-- NOTE(review): when contracting such a redex, the argument is placed under
-- the lambda body's binders without being shifted, and variables in the body
-- above index 0 are not decremented. That is only sound for closed terms.
-- With open terms a free variable of the argument can be captured; e.g.
-- normalising @\\z. (\\x. \\y. x) z@ this way yields @\\z. \\y. y@. Confirm
-- the callers' de Bruijn assumptions before relying on this under binders.
substituteNormalInNeutral :: Int -> LCNormalTerm -> LCNeutralTerm -> LCNormalTerm
substituteNormalInNeutral n v = go n
  where
    go :: Int -> LCNeutralTerm -> LCNormalTerm
    go m e@(LCNeutVar l)
      | m == l    = v
      | otherwise = LCNormNeut e
    go m (LCNeutApp f a) =
      case go m f of
        -- The head became a lambda: contract the newly created beta-redex.
        LCNormLam _ b -> substituteNormalInNormal 0 a' b
        -- The head is still stuck: rebuild the neutral application.
        LCNormNeut nt -> LCNormNeut $ nt `LCNeutApp` a'
      where
        -- The argument with the substitution applied.
        a' :: LCNormalTerm
        a' = substituteNormalInNormal m v a