{-# LANGUAGE RankNTypes #-}

module R ( Rs (..), HasRs (..)
         , maxLens
         , rG, rE
         ) where

import           A
import           Control.Monad.State.Strict (StateT, runState)
import           Data.Bifunctor             (second)
import           Data.Functor               (($>))
import qualified Data.IntMap                as IM
import           Lens.Micro                 (Lens')
import           Lens.Micro.Mtl             (modifying, use, (.=))
import           Nm
import           Ty.Clone
import           U

-- | State threaded through the renamer.
--
-- Fields are strict so that forcing an 'Rs' also forces the counter and the
-- map, instead of accumulating update thunks across many 'freshen' calls.
data Rs = Rs
    { max_  :: !Int
      -- ^ Unique counter: 'freshen' allocates @max_ + 1@ and stores it back.
    , bound :: !(IM.IntMap Int)
      -- ^ Maps a binder's original unique to the fresh unique replacing it.
    }

-- | States that embed renamer state ('Rs'), letting the renaming functions
-- run in any @StateT s m@ whose state @s@ contains an 'Rs'.
class HasRs s where
    -- | Lens focusing on the embedded renamer state.
    rename :: Lens' s Rs

-- | 'Rs' is its own renamer state, so 'rename' is the identity lens.
instance HasRs Rs where
    rename = id

-- | Lens onto the unique counter 'max_'.
maxLens :: Lens' Rs Int
maxLens f s = (\x -> s { max_ = x }) <$> f (max_ s)

-- | Lens onto the renaming map 'bound' (original unique to fresh unique).
boundLens :: Lens' Rs (IM.IntMap Int)
boundLens f s = (\x -> s { bound = x }) <$> f (bound s)

-- | Map a unique to the unique it has been renamed to.
--
-- Follows chains in the 'bound' map until it reaches a unique with no
-- entry. A unique with no entry is returned unchanged.
--
-- The map must not contain cycles. If it does, this fails with an error
-- rather than looping forever.
replaceUnique :: (Monad m, HasRs s) => U -> StateT s m U
replaceUnique (U i) = do
    -- The state is not modified while chasing, so read the map once.
    renames <- use (rename.boundLens)
    pure (U (resolve renames i))
  where
    -- A cycle-free chain has at most @IM.size renames@ successful lookups.
    -- Finding yet another entry once the fuel is spent proves a cycle.
    resolve renames = go (IM.size renames)
      where
        go fuel k = case IM.lookup k renames of
            Nothing -> k
            Just k'
                | fuel <= 0 -> error "R.replaceUnique: cycle in the renames map"
                | otherwise -> go (fuel - 1) k'

-- | Rewrite the unique of a variable occurrence via 'replaceUnique'.
-- The name text and label are kept as they are.
replaceVar :: (Monad m, HasRs s) => Nm a -> StateT s m (Nm a)
replaceVar (Nm n u l) = (\u' -> Nm n u' l) <$> replaceUnique u

-- | Run an action in a nested scope.
--
-- Afterwards the 'bound' renaming map is restored to its value from before
-- the action. 'max_' is left as the action set it, so uniques allocated
-- inside the scope are not handed out again.
doLocal :: (HasRs s, Monad m) => StateT s m a -> StateT s m a
doLocal act = do
    saved  <- use (rename.boundLens)
    result <- act
    rename.boundLens .= saved
    pure result

-- | Give a binder a fresh unique and return the name carrying it.
--
-- Bumps the counter in 'max_' and records the old-to-new mapping in
-- 'bound', so that 'replaceVar' rewrites later occurrences of the name.
freshen :: (HasRs s, Monad m) => Nm a -> StateT s m (Nm a)
freshen (Nm t (U i) l) = do
    nU <- (+1) <$> use (rename.maxLens)
    -- Force the new counter before storing it. Otherwise each call would
    -- store an unevaluated (+1) referring to the previous one.
    nU `seq` (rename.maxLens .= nU)
    modifying (rename.boundLens) (IM.insert i nU) $> Nm t (U nU) l

-- | Rename an expression so that its binders get globally unique names.
--
-- The 'Int' seeds the unique counter, and 'freshen' allocates above it.
-- The renaming map starts empty. The result pairs the renamed expression
-- with the final counter value.
rG :: Int -> E a -> (E a, Int)
rG i e = second max_ (runState (rE e) (Rs i IM.empty))

{-# INLINABLE liftR #-}
-- | Rename a type by cloning it with 'cloneT'.
--
-- The clone is seeded from the shared unique counter, and the counter
-- 'cloneT' returns is written back. 'rE' uses this for the type in an
-- 'Ann' node.
liftR :: (HasRs s, Monad m) => T a -> StateT s m (T a)
liftR t = do
    seed <- use (rename.maxLens)
    -- The third component of cloneT's result is not needed here.
    let (seed', t', _) = cloneT seed t
    rename.maxLens .= seed'
    pure t'

{-# INLINABLE rE #-}
-- | Rename an expression.
--
-- Every binder gets a fresh unique (via 'freshen'), and every variable
-- occurrence is rewritten (via 'replaceVar') to the unique of the binder it
-- refers to. Lambda binders are scoped with 'doLocal', so their mapping is
-- dropped after the body while the unique counter keeps advancing.
--
-- NOTE(review): assumes these equations cover every constructor of 'E' and
-- of 'Idiom' (only 'AShLit' is handled). Confirm with -Wincomplete-patterns.
rE :: (HasRs s, Monad m) => E a -> StateT s m (E a)
rE (Lam l n e) = doLocal $ do
    n' <- freshen n
    Lam l n' <$> rE e
rE (Let l b e)           = rLetLike Let l b e
rE (Def l b e)           = rLetLike Def l b e
rE (LLet l b e)          = rLetLike LLet l b e
rE e@Builtin{}           = pure e
rE e@FLit{}              = pure e
rE e@ILit{}              = pure e
rE e@BLit{}              = pure e
rE (ALit l es)           = ALit l <$> traverse rE es
rE (Tup l es)            = Tup l <$> traverse rE es
rE (EApp l e e')         = EApp l <$> rE e <*> rE e'
rE (Cond l p e e')       = Cond l <$> rE p <*> rE e <*> rE e'
rE (Var l n)             = Var l <$> replaceVar n
rE (Ann l e t)           = Ann l <$> rE e <*> liftR t
rE (Id l (AShLit is es)) = Id l . AShLit is <$> traverse rE es

{-# INLINABLE rLetLike #-}
-- | Shared case for the let-like binders 'Let', 'Def' and 'LLet'.
--
-- The bound expression is renamed before the binder is freshened. Uses of
-- the name inside it therefore resolve to any outer binding (a
-- non-recursive binding). The body is then renamed with the fresh name in
-- scope.
--
-- NOTE(review): unlike 'Lam', these binders are not wrapped in 'doLocal',
-- so the new mapping stays in 'bound' after the body. Confirm whether a
-- same-named variable outside the body is meant to see it.
rLetLike :: (HasRs s, Monad m)
         => (a -> (Nm a, E a) -> E a -> E a)
         -> a -> (Nm a, E a) -> E a -> StateT s m (E a)
rLetLike con l (n, eB) e = do
    eB' <- rE eB
    n'  <- freshen n
    con l (n', eB') <$> rE e