{-# LANGUAGE ViewPatterns #-}
module LambdaCube.STLC.Substitution
  ( substituteValue
  , substituteNormalInNormal
  ) where

import           LambdaCube.STLC.Ast
import           LambdaCube.STLC.Lifter

-- | Substitute a value for a variable index inside a term.
--
-- Thin wrapper over 'substDefValue' that pairs the value with an initial
-- shift amount of @0@. The remaining arguments (the target variable index,
-- then the term to rewrite) are passed through unchanged.
substituteValue :: LCValue -> Int -> LCTerm -> LCTerm
substituteValue v = substDefValue (v, 0)

-- | Substitute a normal term for a variable index inside another normal term.
--
-- Thin wrapper over 'substDefNormalInNormal' that pairs the substituted term
-- with an initial shift amount of @0@. The remaining arguments (the target
-- variable index, then the normal term to rewrite) are passed through
-- unchanged.
substituteNormalInNormal :: LCNormalTerm -> Int -> LCNormalTerm -> LCNormalTerm
substituteNormalInNormal v = substDefNormalInNormal (v, 0)

-- | Substitute a value for variable index @x@ in a term.
--
-- The first argument pairs the value with the amount its variable indices
-- are shifted by (via 'shiftValue') when it is inserted. Both that amount and
-- the target index grow by one for every 'LCLam' crossed on the way down.
substDefValue :: (LCValue, Int) -> Int -> LCTerm -> LCTerm
substDefValue = go
  where
    go :: (LCValue, Int) -> Int -> LCTerm -> LCTerm
    go dv x e@(LCVar y)
      | y == x    = shiftValue dv  -- the target index: insert the shifted value
      | y < x     = e              -- index below the target: left untouched
      | otherwise = LCVar (y - 1)  -- index above the target: decremented by one
    go (v, s) x (LCLam t b) = LCLam t (go (v, s + 1) (x + 1) b)
    go dv     x (LCApp f a) = LCApp (go dv x f) (go dv x a)

-- | Substitute a normal term for variable index @x@ in a normal term.
--
-- The first argument pairs the substituted term with its pending shift
-- amount. Under an 'LCNormLam' both the shift amount and the target index are
-- incremented; an 'LCNormNeut' is delegated to 'substDefNormalInNeutral'.
substDefNormalInNormal :: (LCNormalTerm, Int) -> Int -> LCNormalTerm -> LCNormalTerm
substDefNormalInNormal = go
  where
    go :: (LCNormalTerm, Int) -> Int -> LCNormalTerm -> LCNormalTerm
    go (v, s) x (LCNormLam t b) = LCNormLam t (go (v, s + 1) (x + 1) b)
    go dv     x (LCNormNeut nt) = substDefNormalInNeutral dv x nt

-- | Substitute a normal term for variable index @x@ in a neutral term.
--
-- The result is a normal term rather than a neutral one: replacing the head
-- variable of an application can yield an 'LCNormLam', in which case the
-- application is resolved immediately by substituting the (already
-- rewritten) argument into the lambda body.
substDefNormalInNeutral :: (LCNormalTerm, Int) -> Int -> LCNeutralTerm -> LCNormalTerm
substDefNormalInNeutral dv x = go
  where
    go :: LCNeutralTerm -> LCNormalTerm
    go e@(LCNeutVar y)
      | y == x    = shiftNormal dv                  -- the target index
      | y < x     = LCNormNeut e                    -- below the target: untouched
      | otherwise = LCNormNeut (LCNeutVar (y - 1))  -- above: decremented by one
    go (LCNeutApp f a) =
      case go f of
        -- The rewritten head is a lambda: substitute the argument for
        -- index 0 of its body instead of rebuilding an application.
        LCNormLam _ b -> substituteNormalInNormal a' 0 b
        -- The rewritten head is still neutral: rebuild the application.
        LCNormNeut nt -> LCNormNeut (LCNeutApp nt a')
      where
        -- The argument with the same substitution applied.
        a' :: LCNormalTerm
        a' = substDefNormalInNormal dv x a

-- | Shift the variable indices of a term by the paired amount.
--
-- Equivalent to 'shiftMin' with a cutoff of @0@.
shift :: (LCTerm, Int) -> LCTerm
shift = shiftMin 0

-- | Shift variable indices of a term by @s@, leaving indices below a cutoff
-- untouched.
--
-- The cutoff starts at the first argument and is incremented under every
-- 'LCLam'. A variable whose index is below the current cutoff is kept as is;
-- any other variable has @s@ added to its index.
shiftMin :: Int -> (LCTerm, Int) -> LCTerm
shiftMin n' (v, s) = go n' v
  where
    go :: Int -> LCTerm -> LCTerm
    go n (LCVar x)   = LCVar (if x < n then x else x + s)
    go n (LCLam t b) = LCLam t (go (n + 1) b)
    go n (LCApp f a) = LCApp (go n f) (go n a)

-- | Convert a value to a term with 'liftLCValue' and shift the result by the
-- paired amount using 'shift'.
shiftValue :: (LCValue, Int) -> LCTerm
shiftValue (v, s) = shift (liftLCValue v, s)

-- | Shift the variable indices of a normal term by the paired amount.
--
-- Equivalent to 'shiftNormalMin' with a cutoff of @0@.
shiftNormal :: (LCNormalTerm, Int) -> LCNormalTerm
shiftNormal = shiftNormalMin 0

-- | Shift variable indices of a normal term by @s@, leaving indices below a
-- cutoff untouched.
--
-- The cutoff starts at the first argument and is incremented under every
-- 'LCNormLam'. An 'LCNormNeut' is handed to 'shiftNeutralMin' with the
-- current cutoff and the same shift amount.
shiftNormalMin :: Int -> (LCNormalTerm, Int) -> LCNormalTerm
shiftNormalMin n' (v, s) = go n' v
  where
    go :: Int -> LCNormalTerm -> LCNormalTerm
    go n (LCNormLam t b) = LCNormLam t (go (n + 1) b)
    go n (LCNormNeut nt) = LCNormNeut (shiftNeutralMin n (nt, s))

-- | Shift variable indices of a neutral term by @s@, leaving indices below
-- the cutoff @n@ untouched.
--
-- The cutoff is constant along the spine of the neutral term. Each
-- application argument is shifted with 'shiftNormalMin' using the same
-- cutoff and shift amount.
shiftNeutralMin :: Int -> (LCNeutralTerm, Int) -> LCNeutralTerm
shiftNeutralMin n (v, s) = go v
  where
    go :: LCNeutralTerm -> LCNeutralTerm
    go (LCNeutVar x)   = LCNeutVar (if x < n then x else x + s)
    go (LCNeutApp f a) = LCNeutApp (go f) (shiftNormalMin n (a, s))