{-# LANGUAGE RankNTypes #-}

module Ty.Clone ( cloneT ) where


import           A
import           Control.Monad.State.Strict (State, gets, runState)
import qualified Data.IntMap                as IM
import           Lens.Micro                 (Lens')
import           Lens.Micro.Mtl             (modifying, use)
import           Nm
import           U

-- | Mutable state threaded through 'cloneT' (via the 'State' monad).
--
-- The three maps record, for type, shape and index variables respectively,
-- which fresh unique has already been assigned to each original unique.
data TR = TR
    { maxT                      :: Int           -- ^ last unique handed out (starts at the value passed to 'cloneT')
    , boundTV, boundSh, boundIx :: IM.IntMap Int -- ^ original unique -> fresh unique, for type / shape / index variables
    }

-- | The cloning monad: a 'State' computation (from
-- "Control.Monad.State.Strict") over the cloner state 'TR'.
type CM = State TR

-- | Van Laarhoven lens focusing the 'maxT' counter of a 'TR'.
maxTLens :: Lens' TR Int
maxTLens f s = setMaxT <$> f (maxT s)
  where
    setMaxT x = s { maxT = x }

-- | Van Laarhoven lens focusing the type-variable renaming map ('boundTV').
boundTVLens :: Lens' TR (IM.IntMap Int)
boundTVLens f s = setBoundTV <$> f (boundTV s)
  where
    setBoundTV x = s { boundTV = x }

-- | Van Laarhoven lens focusing the shape-variable renaming map ('boundSh').
boundShLens :: Lens' TR (IM.IntMap Int)
boundShLens f s = setBoundSh <$> f (boundSh s)
  where
    setBoundSh x = s { boundSh = x }

-- | Van Laarhoven lens focusing the index-variable renaming map ('boundIx').
boundIxLens :: Lens' TR (IM.IntMap Int)
boundIxLens f s = setBoundIx <$> f (boundIx s)
  where
    setBoundIx x = s { boundIx = x }

-- | Used by 'cloneT': give a name a brand-new unique.
--
-- Bumps the 'maxT' counter, uses the new counter value as the name's unique,
-- and records @original unique -> new unique@ in the map selected by the lens
-- so later occurrences of the same name can reuse it.
freshen :: Lens' TR (IM.IntMap Int) -- ^ which renaming map to extend (type vars, shape vars, etc.)
        -> Nm a -> CM (Nm a)
freshen lens (Nm name (U old) loc) = do
    new <- modifying maxTLens (+ 1) *> gets maxT
    Nm name (U new) loc <$ modifying lens (IM.insert old new)

-- | Rename one occurrence of a variable.
--
-- If the selected map already holds a fresh unique for this name's unique,
-- reuse it (keeping the name's text and annotation). Otherwise allocate one
-- with 'freshen', which also records it for later occurrences.
tryReplaceInT :: Lens' TR (IM.IntMap Int) -> Nm a -> CM (Nm a)
tryReplaceInT lens nm@(Nm txt (U old) loc) = use lens >>= renameWith
  where
    renameWith seen = maybe (freshen lens nm) reuse (IM.lookup old seen)
    reuse new = pure (Nm txt (U new) loc)

-- | Clone a type, giving each type variable ('TVar'), shape variable ('SVar')
-- and index variable ('IVar' / 'IEVar') a fresh unique.
--
-- Every occurrence of the same original variable is mapped to the same fresh
-- unique. Fresh uniques come from a counter that starts at the supplied
-- 'Int' and is incremented once per newly renamed variable.
--
-- Returns the final counter value, the cloned type, and the renaming that
-- was applied to type variables (original unique -> fresh unique).
cloneT :: Int -> T a
              -> (Int, T a, IM.IntMap Int) -- ^ Substitution on type variables, returned so constraints can be propagated/copied
cloneT u ty =
    case runState (cT ty) (TR u IM.empty IM.empty IM.empty) of
        (ty', TR maxU tvs _ _) -> (maxU, ty', tvs)
  where
    -- Clone an index expression; only index variables are renamed.
    cloneIx :: I a -> CM (I a)
    cloneIx ix = case ix of
        Ix{}            -> pure ix
        StaPlus l i0 i1 -> StaPlus l <$> cloneIx i0 <*> cloneIx i1
        StaMul l i0 i1  -> StaMul l <$> cloneIx i0 <*> cloneIx i1
        IVar l n        -> IVar l <$> tryReplaceInT boundIxLens n
        IEVar l n       -> IEVar l <$> tryReplaceInT boundIxLens n

    -- Clone a shape, renaming shape variables and cloning nested indices.
    cloneSh :: Sh a -> CM (Sh a)
    cloneSh sh = case sh of
        Nil         -> pure Nil
        Cons i rest -> Cons <$> cloneIx i <*> cloneSh rest
        SVar n      -> SVar <$> tryReplaceInT boundShLens n
        Rev s       -> Rev <$> cloneSh s
        Cat s0 s1   -> Cat <$> cloneSh s0 <*> cloneSh s1

    -- Clone a type, renaming type variables and recursing into components.
    cT :: T a -> CM (T a)
    cT t = case t of
        F             -> pure F
        I             -> pure I
        B             -> pure B
        Arrow dom cod -> Arrow <$> cT dom <*> cT cod
        Arr sh el     -> Arr <$> cloneSh sh <*> cT el
        TVar n        -> TVar <$> tryReplaceInT boundTVLens n
        P ts          -> P <$> traverse cT ts