module DDC.Core.Transform.LiftX
(LiftX(..))
where
import DDC.Core.Exp
-- | Things containing value-level de Bruijn indices that can be lifted
--   (shifted upwards) by some number of levels.
class LiftX (c :: * -> *) where
  -- | Lift every index that is free relative to the given depth.
  --   Indices smaller than the depth refer to binders passed on the way
  --   down and are left alone.
  liftAtDepthX
    :: forall n. Ord n
    => Int      -- ^ How many levels to lift by.
    -> Int      -- ^ Current depth (number of enclosing anonymous binders).
    -> c n      -- ^ Thing to lift.
    -> c n

  -- | Lift all free indices, treating the thing as being at depth zero.
  liftX
    :: forall n. Ord n
    => Int      -- ^ How many levels to lift by.
    -> c n      -- ^ Thing to lift.
    -> c n
  liftX k = liftAtDepthX k 0
-- | Only de Bruijn indices are affected; named and primitive bounds are
--   returned unchanged. An index at or above the current depth is free
--   here and gets shifted up by the lift amount.
instance LiftX Bound where
  liftAtDepthX n d u
    = case u of
        UIx i t | i >= d -> UIx (i + n) t
        _                -> u
-- | Lift free value-level indices in an expression, tracking how many
--   anonymous binders have been passed on the way down.
instance LiftX (Exp a) where
  liftAtDepthX n d xx
    = case xx of
        XVar a u       -> XVar a (down u)
        XCon{}         -> xx
        XApp a x1 x2   -> XApp a (down x1) (down x2)

        -- Depth is left unchanged under XLAM: its binder is not counted
        -- as a value-level index here.
        XLAM a b x     -> XLAM a b (down x)

        -- Only an anonymous binder introduces a new de Bruijn index, so
        -- a named binder must not bump the depth (consistent with how
        -- let-bindings and case alternatives count their binders).
        XLam a b x     -> XLam a b (liftAtDepthX n (d + countBAnons [b]) x)

        -- The bindings report how many anonymous binders they put in
        -- scope for the body.
        XLet a lets x
          -> let (lets', levels) = liftAtDepthXLets n d lets
             in  XLet a lets' (liftAtDepthX n (d + levels) x)

        XCase a x alts -> XCase a (down x) (map down alts)
        XCast a cc x   -> XCast a cc (down x)
        XType{}        -> xx
        XWitness{}     -> xx
    where
      -- Lift at the current depth. The explicit polymorphic signature
      -- lets this one helper be used on bounds, expressions and
      -- alternatives even when the monomorphism restriction is enabled.
      down :: (LiftX c, Ord m) => c m -> c m
      down = liftAtDepthX n d
-- | Lift free indices in a case alternative. The anonymous binders of a
--   data pattern are in scope in the alternative body, so the depth is
--   increased by their number; a default pattern binds nothing.
instance LiftX (Alt a) where
  liftAtDepthX n d (AAlt pat body)
    = AAlt pat (liftAtDepthX n (d + boundHere pat) body)
    where
      boundHere PDefault     = 0
      boundHere (PData _ bs) = countBAnons bs
-- | Lift free value-level indices in a group of let-bindings.
--
--   Returns the lifted bindings together with the number of anonymous
--   binders they bring into scope, which the caller adds to the depth
--   when lifting the let body.
liftAtDepthXLets
  :: forall a n. Ord n
  => Int              -- ^ How many levels to lift by.
  -> Int              -- ^ Current depth.
  -> Lets a n         -- ^ Bindings to lift.
  -> (Lets a n, Int)
liftAtDepthXLets n d lts
  = case lts of
      -- Non-recursive let: the binder scopes over the body only, not
      -- over its own right-hand side, so the RHS is lifted at the
      -- current depth.
      LLet m b x
        -> (LLet m b (liftAtDepthX n d x), countBAnons [b])

      -- Recursive group: every binder is in scope in every RHS.
      LRec bs
        -> let inc = countBAnons (map fst bs)
               bs' = [ (b, liftAtDepthX n (d + inc) e) | (b, e) <- bs ]
           in  (LRec bs', inc)

      -- Bindings are returned unchanged; only the anonymous binders in
      -- @bs@ are counted for the body.
      LLetRegion _b bs -> (lts, countBAnons bs)
      LWithRegion _    -> (lts, 0)
-- | Count the anonymous binders in a list. Each one occupies a de Bruijn
--   index, so this is how far the depth advances when passing them.
countBAnons :: [Bind n] -> Int
countBAnons bs = length [ () | BAnon _ <- bs ]