module LambdaCube.SystemF.Substitution where

import           LambdaCube.SystemF.Ast
import           LambdaCube.SystemF.Lifter

-- | Substitute the type @v@ for the type variable with index @n@ inside a term.
--
-- Only the type-level parts of the term are rewritten: the annotation of an
-- 'LCLam' and the type argument of an 'LCTApp'. Term variables are returned
-- unchanged.
--
-- The target index is incremented when descending under an 'LCTLam' and is
-- left alone under an 'LCLam', i.e. term binders do not affect type indices.
--
-- NOTE(review): @v@ is inserted as-is at every depth (no index shifting is
-- performed here) -- presumably callers only pass types for which that is
-- safe; confirm against the call sites.
substituteType :: Int -> LCType -> LCTerm -> LCTerm
substituteType n v = go n
  where
    go :: Int -> LCTerm -> LCTerm
    go _ e@(LCVar _)  = e
    go m (LCLam t b)  = LCLam (substituteTypeInType m v t) $ go m b
    go m (LCApp f a)  = go m f `LCApp` go m a
    go m (LCTLam b)   = LCTLam $ go (m + 1) b
    go m (LCTApp f t) = go m f `LCTApp` substituteTypeInType m v t

-- | Substitute the type @v@ for the type variable with index @n@ inside a type.
--
-- 'LCBase' is returned unchanged, an 'LCTVar' is replaced by @v@ exactly when
-- its index equals the current target index, and 'LCArr' is traversed on both
-- sides. The target index is incremented when descending under an 'LCUniv'.
--
-- NOTE(review): @v@ is inserted unchanged under 'LCUniv' binders and other
-- variables are not renumbered -- confirm this matches the intended index
-- discipline.
substituteTypeInType :: Int -> LCType -> LCType -> LCType
substituteTypeInType n v = go n
  where
    go :: Int -> LCType -> LCType
    go _ LCBase = LCBase
    go m e@(LCTVar l)
      | m == l    = v
      | otherwise = e
    go m (LCArr a b) = go m a `LCArr` go m b
    go m (LCUniv a)  = LCUniv $ go (m + 1) a

-- | Substitute the value @v@ for the term variable with index @n@ inside a term.
--
-- A matching 'LCVar' is replaced by @liftLCValue v@; every other variable is
-- returned unchanged. Type annotations and type arguments are not touched.
--
-- The target index is incremented when descending under an 'LCLam' and is
-- left alone under an 'LCTLam', i.e. type binders do not affect term indices.
--
-- NOTE(review): the lifted value is inserted as-is at every depth (no index
-- shifting is performed here) -- confirm callers only pass values for which
-- that is safe.
substituteValue :: Int -> LCValue -> LCTerm -> LCTerm
substituteValue n v = go n
  where
    go :: Int -> LCTerm -> LCTerm
    go m e@(LCVar l)
      | m == l    = liftLCValue v
      | otherwise = e
    go m (LCLam t b)  = LCLam t $ go (m + 1) b
    go m (LCApp f a)  = go m f `LCApp` go m a
    go m (LCTLam b)   = LCTLam $ go m b
    go m (LCTApp f t) = go m f `LCTApp` t

-- | Substitute the normal term @v@ for the term variable with index @n@ inside
-- a normal term, producing a normal term again.
--
-- The target index is incremented under 'LCNormLam' and left alone under
-- 'LCNormTLam'. Neutral sub-terms are delegated to
-- 'substituteNormalInNeutral', which may have to reduce newly created
-- applications.
substituteNormalInNormal :: Int -> LCNormalTerm -> LCNormalTerm -> LCNormalTerm
substituteNormalInNormal n v = go n
  where
    go :: Int -> LCNormalTerm -> LCNormalTerm
    go m (LCNormLam t b)  = LCNormLam t $ go (m + 1) b
    go m (LCNormTLam b)   = LCNormTLam $ go m b
    go m (LCNormNeut nt)  = substituteNormalInNeutral m v nt

-- | Substitute the normal term @v@ for the term variable with index @n@ inside
-- a neutral term.
--
-- The result is a normal term rather than a neutral one, because replacing
-- the head variable can turn the head of an application into a lambda. When
-- that happens the application is reduced on the spot:
--
-- * 'LCNeutApp' whose substituted head is an 'LCNormLam' substitutes the
--   (already substituted) argument for index 0 of the lambda body;
-- * 'LCNeutTApp' whose substituted head is an 'LCNormTLam' substitutes the
--   type argument for type index 0 of the body.
--
-- A head of the wrong lambda kind (term application on a type lambda or type
-- application on a term lambda) calls 'error'.
substituteNormalInNeutral :: Int -> LCNormalTerm -> LCNeutralTerm -> LCNormalTerm
substituteNormalInNeutral n v = go n
  where
    go :: Int -> LCNeutralTerm -> LCNormalTerm
    go m e@(LCNeutVar l)
      | m == l    = v
      | otherwise = LCNormNeut e
    go m (LCNeutApp f a) =
      case go m f of
        LCNormLam _ b -> substituteNormalInNormal 0 a' b
        LCNormTLam _  -> error "Did you really type check this?"
        LCNormNeut nt -> LCNormNeut $ nt `LCNeutApp` a'
      where
        -- The argument with @v@ already substituted into it.
        a' = substituteNormalInNormal m v a
    go m (LCNeutTApp f t) =
      case go m f of
        LCNormLam _ _ -> error "Did you really type check this?"
        LCNormTLam b  -> substituteTypeInNormal 0 t b
        LCNormNeut nt -> LCNormNeut $ nt `LCNeutTApp` t

-- | Substitute the type @v@ for the type variable with index @n@ inside a
-- normal term, producing a normal term again.
--
-- The annotation of an 'LCNormLam' is rewritten with 'substituteTypeInType'.
-- The target index is incremented under 'LCNormTLam' and left alone under
-- 'LCNormLam'. Neutral sub-terms are delegated to 'substituteTypeInNeutral'.
substituteTypeInNormal :: Int -> LCType -> LCNormalTerm -> LCNormalTerm
substituteTypeInNormal n v = go n
  where
    go :: Int -> LCNormalTerm -> LCNormalTerm
    go m (LCNormLam t b)  = LCNormLam (substituteTypeInType m v t) $ go m b
    go m (LCNormTLam b)   = LCNormTLam $ go (m + 1) b
    go m (LCNormNeut nt)  = substituteTypeInNeutral m v nt

-- | Substitute the type @v@ for the type variable with index @n@ inside a
-- neutral term, returning the result as a normal term.
--
-- Term variables are returned unchanged (wrapped in 'LCNormNeut'). For
-- applications the head is processed first and then inspected:
--
-- * a neutral head is re-applied to the substituted argument / type argument;
-- * an 'LCNormLam' head under 'LCNeutApp' is reduced by substituting the
--   argument for index 0 of the body;
-- * an 'LCNormTLam' head under 'LCNeutTApp' is reduced by substituting the
--   type argument for type index 0 of the body;
-- * a head of the wrong lambda kind calls 'error'.
substituteTypeInNeutral :: Int -> LCType -> LCNeutralTerm -> LCNormalTerm
substituteTypeInNeutral n v = go n
  where
    go :: Int -> LCNeutralTerm -> LCNormalTerm
    go _ e@(LCNeutVar _) = LCNormNeut e
    go m (LCNeutApp f a) =
      case go m f of
        LCNormLam _ b -> substituteNormalInNormal 0 a' b
        LCNormTLam _  -> error "Did you really type check this?"
        LCNormNeut nt -> LCNormNeut $ nt `LCNeutApp` a'
      where
        -- The argument with @v@ already substituted into it.
        a' = substituteTypeInNormal m v a
    go m (LCNeutTApp f t) =
      case go m f of
        LCNormLam _ _ -> error "Did you really type check this?"
        LCNormTLam b  -> substituteTypeInNormal 0 t' b
        LCNormNeut nt -> LCNormNeut $ nt `LCNeutTApp` t'
      where
        -- The type argument with @v@ already substituted into it.
        t' = substituteTypeInType m v t