-- | Rewriting of binders and bound variables in types and witnesses.
--
--   Exports the 'Rewrite' class, the state it threads ('Sub' and
--   'BindStack'), the helpers that push binders onto a stack ('pushBind',
--   'pushBinds', 'bind0', 'bind0s', 'bind1'), and the helpers that rewrite
--   bound occurrences ('substBound', 'use0', 'use1').
module DDC.Type.Rewrite
( Rewrite(..)
, Sub(..)
, BindStack(..)
, pushBind
, pushBinds
, substBound
, bind1, bind0, bind0s
, use1, use0)
where
import DDC.Core.Exp
import DDC.Type.Compounds
import Data.List
import Data.Set (Set)
import qualified DDC.Type.Sum as Sum
import qualified Data.Set as Set
-- | State threaded through a rewrite.
--
--   Fields with suffix @1@ are the ones consulted by 'bind1' and 'use1'
--   (used for 'TForall' binders and 'TVar' occurrences); fields with
--   suffix @0@ are the ones consulted by 'bind0' and 'use0' (used for
--   'WVar' occurrences).
data Sub n = Sub
  { -- | The bound variable this rewrite is concerned with. 'bind0' tests
    --   each binder it pushes against it to update 'subShadow0'.
    subBound     :: Bound n

    -- | Becomes 'True' (and stays so) once 'bind0' has pushed a binder for
    --   which @namedBoundMatchesBind subBound@ holds.
  , subShadow0   :: Bool

    -- | Names handed to 'pushBind' by 'bind1': a named binder whose name is
    --   in this set is replaced by an anonymous one.
  , subConflict1 :: Set n

    -- | Names handed to 'pushBind' by 'bind0': a named binder whose name is
    --   in this set is replaced by an anonymous one.
  , subConflict0 :: Set n

    -- | Binder stack maintained by 'bind1' and searched by 'use1'.
  , subStack1    :: BindStack n

    -- | Binder stack maintained by 'bind0' and searched by 'use0'.
  , subStack0    :: BindStack n
  }
-- | Record of the binders that have been pushed with 'pushBind'.
data BindStack n = BindStack
  { -- | Binders that are anonymous in the output: binders that were already
    --   'BAnon', plus the original 'BName' of every binder that 'pushBind'
    --   rewrote to a 'BAnon'. Most recently pushed first.
    stackBinds :: [Bind n]

    -- | Every 'BAnon' and 'BName' binder pushed, in the form it has after
    --   any rewriting. Most recently pushed first.
  , stackAll   :: [Bind n]

    -- | How many 'BAnon' binders have been pushed onto 'stackBinds'.
  , stackAnons :: Int

    -- | How many 'BName' binders have been pushed onto 'stackBinds'
    --   (that is, how many named binders were rewritten).
  , stackNamed :: Int
  }
-- | Push a list of binders onto a 'BindStack', left to right, with
--   'pushBind'. Returns the final stack together with the (possibly
--   rewritten) binders, in the same order as the input.
pushBinds
  :: Ord n
  => Set n                    -- ^ Names that must be rewritten.
  -> BindStack n              -- ^ Stack to start from.
  -> [Bind n]                 -- ^ Binders to push.
  -> (BindStack n, [Bind n])
pushBinds fns = mapAccumL (pushBind fns)
-- | Push one binder onto a 'BindStack'.
--
--   * An anonymous binder is pushed onto both lists and bumps the
--     anonymous counter.
--
--   * A named binder whose name is in the given set is returned as an
--     anonymous binder of the same type; its original named form is pushed
--     onto 'stackBinds', the anonymous form onto 'stackAll', and the named
--     counter is bumped.
--
--   * Any other named binder is only recorded in 'stackAll' and returned
--     unchanged.
--
--   * Every remaining form of binder leaves the stack and binder untouched.
pushBind
  :: Ord n
  => Set n                  -- ^ Names that must be rewritten.
  -> BindStack n            -- ^ Current stack.
  -> Bind n                 -- ^ Binder to push.
  -> (BindStack n, Bind n)  -- ^ New stack and the binder to use in its place.
pushBind conflicts original@(BindStack rewritten everything nAnon nNamed) bind
  = case bind of
      BName name ty
        | name `Set.member` conflicts
        -> let anon = BAnon ty
           in  ( BindStack (BName name ty : rewritten)
                           (anon : everything)
                           nAnon
                           (nNamed + 1)
               , anon )

        | otherwise
        -> ( BindStack rewritten (BName name ty : everything) nAnon nNamed
           , bind )

      BAnon ty
        -> let anon = BAnon ty
           in  ( BindStack (anon : rewritten)
                           (anon : everything)
                           (nAnon + 1)
                           nNamed
               , anon )

      _ -> (original, bind)
-- | Compare a bound occurrence against the bound being substituted for.
--
--   Returns @Right depth@ when the occurrence is the one being substituted
--   for, where @depth@ is the number of 'stackBinds' entries (anonymous plus
--   rewritten named) on the given stack. Otherwise returns @Left@ with the
--   bound to use in place of the occurrence.
substBound
  :: Ord n
  => BindStack n        -- ^ Binders pushed so far.
  -> Bound n            -- ^ Bound being substituted for.
  -> Bound n            -- ^ Bound occurrence under inspection.
  -> Either
       (Bound n)        --   Occurrence does not match; use this bound instead.
       Int              --   Occurrence matches; this is the stack depth.
substBound (BindStack binds _ dAnon dName) u u'
  -- Named occurrence with the same name as the one substituted for.
  | UName n1 _ <- u
  , UName n2 _ <- u'
  , n1 == n2
  = Right (dAnon + dName)

  -- Indexed occurrence that refers to the one substituted for, once the
  -- anonymous binders pushed so far are taken into account.
  | UIx i1 _ <- u
  , UIx i2 _ <- u'
  , i1 + dAnon == i2
  = Right (dAnon + dName)

  -- Named occurrence whose binder is on 'stackBinds', i.e. was rewritten to
  -- an anonymous binder: refer to it by its position on the stack.
  | UName _ t <- u'
  , Just ix <- findIndex (boundMatchesBind u') binds
  = Left $ UIx ix t

  -- Indexed occurrence that points past the anonymous binders pushed so
  -- far. Raise it by the number of named binders that became anonymous,
  -- and lower it by one when the bound being substituted for is itself an
  -- index.
  | UIx i2 t <- u'
  , i2 > dAnon
  , cutOffset <- case u of
                   UIx{} -> 1
                   _     -> 0
  = Left $ UIx (i2 + dName - cutOffset) t

  -- Anything else is left as it is.
  | otherwise
  = Left u'
-- | Push a binder onto 'subStack1', rewriting it if its name is in
--   'subConflict1'. Returns the updated state and the binder to use.
bind1 :: Ord n => Sub n -> Bind n -> (Sub n, Bind n)
bind1 sub b = (sub { subStack1 = stack' }, pushed)
  where
    (stack', pushed) = pushBind (subConflict1 sub) (subStack1 sub) b
-- | Rewrite a binder with 'rewriteWith', then push the result onto
--   'subStack0', rewriting it if its name is in 'subConflict0'.
--
--   'subShadow0' is additionally set when @namedBoundMatchesBind@ holds for
--   'subBound' and the pushed binder; once set it is never cleared here.
bind0 :: Ord n => Sub n -> Bind n -> (Sub n, Bind n)
bind0 sub b = (sub', pushed)
  where
    rewritten        = rewriteWith sub b
    (stack', pushed) = pushBind (subConflict0 sub) (subStack0 sub) rewritten
    shadowed         = subShadow0 sub
                    || namedBoundMatchesBind (subBound sub) pushed
    sub'             = sub { subStack0  = stack'
                           , subShadow0 = shadowed }
-- | Apply 'bind0' to each binder in turn, left to right, threading the
--   state through. Returns the final state and the resulting binders.
bind0s :: Ord n => Sub n -> [Bind n] -> (Sub n, [Bind n])
bind0s sub bs = mapAccumL bind0 sub bs
-- | Rewrite a bound occurrence against 'subStack1'.
--
--   A named occurrence whose binder is found in the stack's 'stackBinds'
--   becomes an index giving that binder's position; every other occurrence
--   is returned unchanged.
use1 :: Ord n => Sub n -> Bound n -> Bound n
use1 sub u
  = case u of
      UName _ t
        | Just ix <- findIndex (boundMatchesBind u) (stackBinds (subStack1 sub))
        -> UIx ix t
      _ -> u
-- | Rewrite a bound occurrence against 'subStack0'.
--
--   A named occurrence whose binder is found in the stack's 'stackBinds'
--   becomes an index giving that binder's position, with its attached type
--   rewritten by 'rewriteWith'. Every other occurrence is passed to
--   'rewriteWith' as a whole.
use0 :: Ord n => Sub n -> Bound n -> Bound n
use0 sub u
  = case u of
      UName _ t
        | Just ix <- findIndex (boundMatchesBind u) (stackBinds (subStack0 sub))
        -> UIx ix (rewriteWith sub t)
      _ -> rewriteWith sub u
-- | Structures, parameterised by a name type, that can be rewritten under
--   a 'Sub'.
class Rewrite (c :: * -> *) where
  -- | Rewrite the structure using the given state.
  rewriteWith
    :: Ord n
    => Sub n
    -> c n
    -> c n
-- | Rewrites the type attached to a binder; the binder itself is kept.
instance Rewrite Bind where
  rewriteWith sub bb = replaceTypeOfBind ty' bb
    where
      ty' = rewriteWith sub (typeOfBind bb)
-- | Rewrites the type attached to a bound occurrence; the occurrence itself
--   is kept.
instance Rewrite Bound where
  rewriteWith sub uu = replaceTypeOfBound ty' uu
    where
      ty' = rewriteWith sub (typeOfBound uu)
-- | Rewrites the optional payload of 'LetLazy'; 'LetStrict' carries nothing
--   to rewrite.
instance Rewrite LetMode where
  rewriteWith _   LetStrict          = LetStrict
  rewriteWith sub (LetLazy (Just t)) = LetLazy (Just (rewriteWith sub t))
  rewriteWith _   (LetLazy Nothing)  = LetLazy Nothing
-- | Rewrites the single argument carried by each form of cast.
instance Rewrite Cast where
  rewriteWith sub (CastWeakenEffect eff)  = CastWeakenEffect  (rewriteWith sub eff)
  rewriteWith sub (CastWeakenClosure clo) = CastWeakenClosure (rewriteWith sub clo)
  rewriteWith sub (CastPurify w)          = CastPurify        (rewriteWith sub w)
  rewriteWith sub (CastForget w)          = CastForget        (rewriteWith sub w)
-- | Rewrites a type: variables go through 'use1', constructors are kept,
--   and applications and sums are rewritten component-wise. The binder of a
--   'TForall' is pushed with 'bind1' and the body is rewritten under the
--   resulting state.
instance Rewrite Type where
  rewriteWith sub tt
    = case tt of
        TCon{}      -> tt
        TVar u      -> TVar (use1 sub u)
        TApp t1 t2  -> TApp (rewriteWith sub t1) (rewriteWith sub t2)
        TSum ts     -> TSum (rewriteWith sub ts)
        TForall b body
          -> let (subUnder, b') = bind1 sub b
             in  TForall b' (rewriteWith subUnder body)
-- | Rewrites every element of the sum and rebuilds it with the kind of the
--   original sum.
instance Rewrite TypeSum where
  rewriteWith sub ts
    = Sum.fromList (Sum.kindOfSum ts)
                   [ rewriteWith sub t | t <- Sum.toList ts ]
-- | Rewrites a witness: variables go through 'use0', constructors are kept,
--   and applications, joins and embedded types are rewritten component-wise.
instance Rewrite Witness where
  rewriteWith sub ww
    = case ww of
        WCon{}      -> ww
        WVar u      -> WVar  (use0 sub u)
        WType t     -> WType (rewriteWith sub t)
        WApp w1 w2  -> WApp  (rewriteWith sub w1) (rewriteWith sub w2)
        WJoin w1 w2 -> WJoin (rewriteWith sub w1) (rewriteWith sub w2)