module DDC.Core.Transform.Rewrite.Match
(
SubstInfo
, emptySubstInfo
, match)
where
import DDC.Core.Exp
import DDC.Type.Transform.Crush
import Data.Set (Set)
import Data.Map (Map)
import qualified DDC.Type.Sum as Sum
import qualified DDC.Type.Transform.AnonymizeT as T
import qualified DDC.Core.Transform.AnonymizeX as T
import qualified DDC.Core.Transform.Reannotate as T
import qualified DDC.Type.Equiv as TE
import qualified Data.Map as Map
import qualified Data.Set as Set
-- | Substitution built up while matching: the first map takes names to
--   expressions, the second map takes names to types.
type SubstInfo a n = (Map n (Exp a n), Map n (Type n))
-- | A substitution in which neither the expression map nor the type map
--   contains any binding.
emptySubstInfo :: SubstInfo a n
emptySubstInfo = (noExps, noTypes)
  where
    noExps  = Map.empty
    noTypes = Map.empty
-- | Look a name up in the first (expression) map of a substitution pair.
--   The second component of the pair is ignored.
lookupx :: Ord k => k -> (Map k v, b) -> Maybe v
lookupx name = Map.lookup name . fst
-- | Bind a name to a value in the first (expression) map of a substitution
--   pair, leaving the second component untouched.
insertx :: Ord k => k -> v -> (Map k v, b) -> (Map k v, b)
insertx name val (exps, tys) =
  let exps' = Map.insert name val exps
  in  (exps', tys)
-- | Match a pattern expression against a subject expression, extending the
--   given substitution.
--
--   * A named variable in the pattern whose name is in the binder set is
--     bound to the subject. If it is already bound, the previous binding
--     and the subject must be equal once annotations are erased and both
--     are anonymized.
--
--   * Any other pair of variables, or a pair of constructors, must be equal.
--
--   * Applications are matched component-wise, threading the substitution.
--
--   * Casts must agree according to 'eqCast' before their bodies are matched.
--
--   * Type arguments are matched with 'matchT', updating only the type map.
--
--   * Witnesses must agree according to 'eqWit'.
--
--   Every other combination fails with 'Nothing'.
match :: (Show a, Show n, Ord n)
      => SubstInfo a n          -- ^ Substitution accumulated so far.
      -> Set n                  -- ^ Names the match is allowed to bind.
      -> Exp a n                -- ^ Pattern expression.
      -> Exp a n                -- ^ Subject expression.
      -> Maybe (SubstInfo a n)
match subst binders (XVar _ (UName name)) subject
  | Set.member name binders
  = maybe bindFresh checkBound (lookupx name subst)
  where
    -- First occurrence of this binder: record the subject.
    bindFresh = Just (insertx name subject subst)

    -- Repeated occurrence: the earlier binding must agree with the subject.
    checkBound bound
      | erase bound == erase subject = Just subst
      | otherwise                    = Nothing

    -- Drop annotations, then anonymize, before comparing.
    erase e = T.anonymizeX (T.reannotate (const ()) e)

match subst _ (XVar _ u1) (XVar _ u2)
  | u1 == u2
  = Just subst

match subst _ (XCon _ dc1) (XCon _ dc2)
  | dc1 == dc2
  = Just subst

match subst binders (XApp _ fun1 arg1) (XApp _ fun2 arg2)
  = match subst binders fun1 fun2
      >>= \subst' -> match subst' binders arg1 arg2

match subst binders (XCast _ cast1 body1) (XCast _ cast2 body2)
  | eqCast cast1 cast2
  = match subst binders body1 body2

match (exps, tys) binders (XType _ t1) (XType _ t2)
  = fmap (\tys' -> (exps, tys')) (matchT t1 t2 binders tys)

match subst _ (XWitness _ w1) (XWitness _ w2)
  | eqWit w1 w2
  = Just subst

match _ _ _ _
  = Nothing
-- | Compare two casts for equality after normalising each one: the effect
--   of a weaken-effect cast and the expressions of a weaken-closure cast
--   are anonymized, and annotations are replaced by unit.
eqCast :: Ord n => Cast a n -> Cast a n -> Bool
eqCast lc rc
  = normalise lc == normalise rc
  where
    -- Anonymize first, then erase annotations.
    normalise c = T.reannotate (const ()) (anonymize c)

    anonymize (CastWeakenEffect eff)  = CastWeakenEffect (T.anonymizeT eff)
    anonymize (CastWeakenClosure clo) = CastWeakenClosure (map T.anonymizeX clo)
    anonymize (CastPurify wit)        = CastPurify wit
    anonymize (CastForget wit)        = CastForget wit
    anonymize CastBox                 = CastBox
    anonymize CastRun                 = CastRun
-- | Compare two witnesses for equality, ignoring their annotations.
eqWit :: Ord n => Witness a n -> Witness a n -> Bool
eqWit lw rw
  = strip lw == strip rw
  where
    -- Replace every annotation by unit so that only structure is compared.
    strip w = T.reannotate (const ()) w
-- | Set of names, as passed to 'matchT' to say which type variables may be
--   bound.
type VarSet n = Set n
-- | Type substitution: a map from names to types.
type Subst n = Map n (Type n)
-- | Match a pattern type against a subject type, extending the given type
--   substitution.
--
--   Both types are first passed through 'crushSomeT' and 'unpackSumT'.
--   Then:
--
--   * type constructors must be equal;
--
--   * applications are matched component-wise, threading the substitution;
--
--   * sums must have the same number of components, which are matched
--     pairwise in the order given by 'Sum.toList';
--
--   * a named variable in the binder set is bound to the subject if it is
--     unbound, or else its existing binding must be equivalent to the
--     subject according to 'TE.equivT';
--
--   * any other variable must be equal to the subject variable.
--
--   Every other combination fails with 'Nothing'.
matchT :: Ord n
       => Type n            -- ^ Pattern type.
       -> Type n            -- ^ Subject type.
       -> VarSet n          -- ^ Names the match is allowed to bind.
       -> Subst n           -- ^ Substitution accumulated so far.
       -> Maybe (Subst n)
matchT t1 t2 vs subst
  = step (normalise t1) (normalise t2)
  where
    normalise ty = unpackSumT (crushSomeT ty)

    step (TCon tc1) (TCon tc2)
      | tc1 == tc2
      = Just subst

    step (TApp t11 t12) (TApp t21 t22)
      = matchT t11 t21 vs subst >>= matchT t12 t22 vs

    step (TSum ts1) (TSum ts2)
      | length ls /= length rs = Nothing
      | otherwise              = matchAll (zip ls rs) subst
      where
        ls = Sum.toList ts1
        rs = Sum.toList ts2

        -- Thread the substitution through each pair of components.
        matchAll []              s = Just s
        matchAll ((l, r) : rest) s = matchT l r vs s >>= matchAll rest

    -- Binder variable: bind it, or check it against its earlier binding.
    -- When neither guard holds we fall through to the clauses below.
    step (TVar (UName n)) subject
      | isBinder
      , Nothing <- earlier
      = Just (Map.insert n subject subst)

      | isBinder
      , Just bound <- earlier
      , TE.equivT bound subject
      = Just subst
      where
        isBinder = Set.member n vs
        earlier  = Map.lookup n subst

    -- Named variable that is not a binder: must be the same variable.
    step (TVar u1@(UName n)) (TVar u2)
      | not (Set.member n vs)
      , u1 == u2
      = Just subst

    step (TVar u1@(UIx _)) (TVar u2)
      | u1 == u2
      = Just subst

    step (TVar u1@(UPrim _ _)) (TVar u2)
      | u1 == u2
      = Just subst

    step _ _
      = Nothing
-- | If the type is a sum whose 'Sum.toList' has exactly one element,
--   return that element; otherwise return the type unchanged.
unpackSumT :: Type n -> Type n
unpackSumT tt
  = case tt of
      TSum ts
        | [single] <- Sum.toList ts -> single
      _                             -> tt