module Language.Haskell.TH.Alpha (
areExpAEq,
expEqual,
(@=),
AlphaEq(..)
) where
import Language.Haskell.TH
import Language.Haskell.TH.Syntax (Quasi)
import Language.Haskell.TH.Desugar
import Data.Function (on)
import Control.Monad.State
import Control.Monad.Identity
import Control.Monad.Trans.Maybe
import Control.Monad.Morph
import Data.Maybe (isJust)
import qualified Data.Map as Map
import Control.Applicative
-- | Raw state behind a 'LookupTbl': left-side name to index, right-side
-- name to index, and the next index that 'insertLR' will hand out.
type Lookup = (Map.Map Name Int, Map.Map Name Int, Int)

-- | A state with no bindings on either side and the counter at zero.
emptyLookup :: Lookup
emptyLookup = (noBindings, noBindings, 0)
  where
    noBindings = Map.empty
-- | Binding table used during a comparison, represented as a record of
-- operations closed over a 'Lookup' state (see 'mapLookup').
data LookupTbl = LookupTbl
    { insertLR :: Name -> Name -> LookupTbl
      -- ^ Bind a left name and a right name to the same fresh index.
    , eqInTbl  :: Name -> Name -> Bool
      -- ^ True when both names map to the same index, including the case
      -- where neither name is bound.
    , isInL    :: Name -> Bool
      -- ^ Whether the name is bound on the left side.
    , isInR    :: Name -> Bool
      -- ^ Whether the name is bound on the right side.
    }

-- | Wrap a 'Lookup' state in the 'LookupTbl' interface. Inserting returns a
-- new table over the extended state; the original is left unchanged.
mapLookup :: Lookup -> LookupTbl
mapLookup !(lefts, rights, next) = LookupTbl
    { insertLR = \l r -> mapLookup ( Map.insert l next lefts
                                   , Map.insert r next rights
                                   , next + 1
                                   )
    , eqInTbl  = \l r -> Map.lookup l lefts == Map.lookup r rights
    , isInL    = (`Map.member` lefts)
    , isInR    = (`Map.member` rights)
    }
-- | Comparison monad: 'LookupTbl' state with failure over a base monad @m@.
-- Failing ('mzero' / a false 'guard') means the compared values differ.
newtype LookupSTM m b = LookupST
    { unLookupST :: StateT LookupTbl (MaybeT m) b
    } deriving ( Functor, Applicative, Monad, MonadState LookupTbl
               , MonadPlus, Alternative )

instance MonadTrans LookupSTM where
    -- Run the base action, succeed with its result, and pass the table
    -- through untouched.
    lift act = LookupST . StateT $ \tbl -> MaybeT $ do
        x <- act
        return (Just (x, tbl))

-- | Change the base monad of a 'LookupSTM' by a natural transformation,
-- keeping the state and failure layers as they are.
hoist' :: (Monad m) => (forall a . m a -> n a) -> LookupSTM m b -> LookupSTM n b
hoist' nat (LookupST st) =
    LookupST . StateT $ \tbl -> MaybeT (nat (runMaybeT (runStateT st tbl)))

instance MFunctor LookupSTM where
    hoist = hoist'

-- | Embed a pure comparison into the 'Q'-based comparison monad.
toQ :: LookupST b -> LookupSTQ b
toQ = hoist generalize

-- | Pure comparison monad.
type LookupST b = LookupSTM Identity b

-- | Comparison monad that can also run 'Q' actions.
type LookupSTQ b = LookupSTM Q b
-- | Run a comparison from the given table. 'Nothing' means the comparison
-- failed; otherwise the result and final table are returned.
runLookupST :: Monad m => LookupSTM m a -> LookupTbl -> m (Maybe (a, LookupTbl))
runLookupST (LookupST st) = runMaybeT . runStateT st

-- | 'runLookupST' specialised to the pure 'Identity' base monad.
runLookupST' :: LookupST a -> LookupTbl -> Maybe (a, LookupTbl)
runLookupST' st tbl = runIdentity (runLookupST st tbl)
-- | Types whose values can be compared up to renaming of bound names.
-- The functional dependency fixes the base monad for each type.
class AlphaEq a m | a -> m where
    -- | Succeed when the two values are judged equal under the current
    -- binding table; fail (via the monad's 'mzero') otherwise.
    lkEq :: a -> a -> LookupSTM m ()

-- | Compare two values starting from an empty binding table.
(@=) :: (Monad m, AlphaEq a m) => a -> a -> m Bool
x @= y = do
    outcome <- runLookupST (lkEq x y) (mapLookup emptyLookup)
    return (isJust outcome)

infix 4 @=
-- | Run two expression quotes (left first, then right) and compare the
-- resulting expressions with 'expEqual'.
areExpAEq :: Quasi m
          => ExpQ
          -> ExpQ
          -> m Bool
areExpAEq e1 e2 = do
    lhs <- runQ e1
    rhs <- runQ e2
    expEqual lhs rhs
-- | TH expressions are desugared with 'dsExp' (which needs 'Q', hence the
-- base monad) and then compared with the pure 'DExp' comparison.
instance AlphaEq Exp Q where
    lkEq lhs rhs = do
        dLhs <- lift (dsExp lhs)
        dRhs <- lift (dsExp rhs)
        toQ (expEqual' dLhs dRhs)
-- | Desugar both expressions (left first) and compare the results, starting
-- from an empty binding table.
expEqual :: Quasi m => Exp -> Exp -> m Bool
expEqual t1 t2 = do
    d1 <- dsExp t1
    d2 <- dsExp t2
    return . isJust $ runLookupST' (lkEq d1 d2) (mapLookup emptyLookup)
instance AlphaEq DExp Identity where
    lkEq = expEqual'

-- | Structural alpha-equivalence of desugared expressions.
--
-- Names bound by a lambda or a let are recorded in the table only while
-- that construct's own subterms are compared. The table is restored
-- afterwards so those bindings cannot leak into sibling subterms. For
-- example, in @(\\x -> 1) x@ the argument @x@ is free and must not be
-- matched through the lambda's binder.
expEqual' :: DExp -> DExp -> LookupST ()
expEqual' (DVarE a1   ) (DVarE a2   ) = a1 ~=~ a2
expEqual' (DConE a1   ) (DConE a2   ) = a1 ~=~ a2
expEqual' (DLitE l1   ) (DLitE l2   ) = guard $ l1 == l2
expEqual' (DAppE a1 b1) (DAppE a2 b2) = lkEq a1 a2 >> lkEq b1 b2
expEqual' (DLamE a1 b1) (DLamE a2 b2) = do
    guard $ ((==) `on` length) a1 a2
    outer <- get
    zipWithM_ insertLRLST a1 a2
    lkEq b1 b2
    -- The lambda's binders go out of scope here.
    put outer
expEqual' (DCaseE a1 b1) (DCaseE a2 b2) = do
    guard $ length b1 == length b2
    lkEq a1 a2
    zipWithM_ lkEq b1 b2
expEqual' (DLetE a1 b1) (DLetE a2 b2) = do
    -- zipWithM_ truncates, so without this check extra declarations on
    -- one side would be silently ignored.
    guard $ length a1 == length a2
    outer <- get
    zipWithM_ lkEq a1 a2
    lkEq b1 b2
    -- The let's binders go out of scope here.
    put outer
expEqual' (DSigE a1 b1) (DSigE a2 b2) = lkEq a1 a2 >> lkEq b1 b2
expEqual' _ _ = mzero
instance AlphaEq DMatch Identity where
    lkEq = matchEqual

-- | Compare two case alternatives. Names bound by the pattern are visible
-- while the right-hand sides are compared. The table is then restored, so
-- they do not leak into later alternatives or the surrounding expression.
matchEqual :: DMatch -> DMatch -> LookupST ()
matchEqual (DMatch pat1 exp1) (DMatch pat2 exp2) = do
    outer <- get
    lkEq pat1 pat2
    lkEq exp1 exp2
    put outer
instance AlphaEq DLetDec Identity where
    lkEq = letDecEqual

-- | Compare two desugared let-declarations.
--
-- Binders are added to the table and are not removed here. They stay
-- visible to whatever is compared afterwards, e.g. the remaining
-- declarations and the body of the enclosing let.
--
-- NOTE(review): a 'DValD' compares its right-hand side before binding its
-- pattern. A value binding that mentions itself, or a later binding of the
-- same let, therefore sees those names as unbound at that point.
letDecEqual :: DLetDec -> DLetDec -> LookupST ()
letDecEqual (DFunD n1 cls1) (DFunD n2 cls2) = do
    -- zipWithM_ truncates, so the clause counts must be checked explicitly.
    guard $ length cls1 == length cls2
    -- The function name is a binder. Record the pair instead of requiring
    -- identical Names, so that consistently renamed local functions (e.g.
    -- from two separate quotes that give binders fresh names) still match.
    -- Binding before the clauses lets recursive calls resolve.
    insertLRLST n1 n2
    zipWithM_ lkEq cls1 cls2
letDecEqual (DValD pat1 exp1) (DValD pat2 exp2) =
    lkEq exp1 exp2 >> lkEq pat1 pat2
letDecEqual (DSigD _name1 typ1) (DSigD _name2 typ2) =
    lkEq typ1 typ2
letDecEqual (DInfixD fx1 name1) (DInfixD fx2 name2) =
    guard $ fx1 == fx2 && name1 == name2
letDecEqual _ _ = mzero
instance AlphaEq DType Identity where
    lkEq = typeEqual

-- | Structural alpha-equivalence of desugared types.
--
-- Type variables bound by a @forall@ are visible only while that forall's
-- context and body are compared; the table is restored afterwards.
-- Type constructor names and kind annotations are compared by their
-- 'show' output.
typeEqual :: DType -> DType -> LookupST ()
typeEqual (DForallT tybs1 ctx1 typ1) (DForallT tybs2 ctx2 typ2) = do
    -- zipWithM_ truncates, so mismatched binder or context counts must be
    -- rejected explicitly.
    guard $ length tybs1 == length tybs2
    guard $ length ctx1 == length ctx2
    outer <- get
    zipWithM_ insertLRLSTty tybs1 tybs2
    zipWithM_ lkEq ctx1 ctx2
    lkEq typ1 typ2
    put outer
typeEqual (DAppT ty1 arg1) (DAppT ty2 arg2) =
    lkEq ty1 ty2 >> lkEq arg1 arg2
typeEqual (DSigT ty1 knd1) (DSigT ty2 knd2) = do
    guard $ show knd1 == show knd2
    lkEq ty1 ty2
typeEqual (DConT n1) (DConT n2) =
    guard $ show n1 == show n2
typeEqual (DVarT n1) (DVarT n2) =
    n1 ~=~ n2
-- The function-arrow constructor; without this case every function type
-- fell through to 'mzero'.
typeEqual DArrowT DArrowT = return ()
-- Type-level literals.
typeEqual (DLitT l1) (DLitT l2) = guard $ l1 == l2
typeEqual _ _ = mzero
instance AlphaEq DKind Identity where
    lkEq = kindEqual

-- | Structural alpha-equivalence of desugared kinds. Kind variables bound
-- by a @forall@ are visible only inside it; the table is restored
-- afterwards.
kindEqual :: DKind -> DKind -> LookupST ()
kindEqual (DForallK ns1 typ1) (DForallK ns2 typ2) = do
    -- zipWithM_ truncates, so the binder counts must be checked explicitly.
    guard $ length ns1 == length ns2
    outer <- get
    zipWithM_ insertLRLST ns1 ns2
    lkEq typ1 typ2
    put outer
kindEqual (DVarK n1) (DVarK n2) = n1 ~=~ n2
kindEqual (DArrowK knda1 kndb1) (DArrowK knda2 kndb2) =
    lkEq knda1 knda2 >> lkEq kndb1 kndb2
kindEqual DStarK DStarK = return ()
kindEqual _ _ = mzero
instance AlphaEq DClause Identity where
    lkEq = clauseEqual

-- | Compare two function clauses. Both must have the same number of
-- argument patterns. Names bound by those patterns are visible while the
-- bodies are compared, and the table is restored afterwards so they do
-- not leak into other clauses.
clauseEqual :: DClause -> DClause -> LookupST ()
clauseEqual (DClause pats1 exp1) (DClause pats2 exp2) = do
    -- zipWithM_ truncates, so differing arities must be rejected explicitly.
    guard $ length pats1 == length pats2
    outer <- get
    zipWithM_ lkEq pats1 pats2
    lkEq exp1 exp2
    put outer
instance AlphaEq DPred Identity where
    lkEq = predEqual

-- | Structural comparison of desugared predicates. Class and variable
-- names are checked with '~=~'.
predEqual :: DPred -> DPred -> LookupST ()
predEqual p1 p2 = case (p1, p2) of
    (DAppPr f1 t1, DAppPr f2 t2) -> lkEq f1 f2 >> lkEq t1 t2
    (DSigPr q1 k1, DSigPr q2 k2) -> lkEq q1 q2 >> lkEq k1 k2
    (DVarPr n1   , DVarPr n2   ) -> n1 ~=~ n2
    (DConPr n1   , DConPr n2   ) -> n1 ~=~ n2
    _                            -> mzero
instance AlphaEq DPat Identity where
    lkEq = patEqual

-- | Compare two desugared patterns. Each pair of pattern variables is bound
-- in the table (and is not removed here). Constructor names go through
-- '~=~', and constructor patterns must have the same number of
-- sub-patterns.
patEqual :: DPat -> DPat -> LookupST ()
patEqual p1 p2 = case (p1, p2) of
    (DLitPa l1, DLitPa l2) -> guard (l1 == l2)
    (DVarPa v1, DVarPa v2) -> insertLRLST v1 v2
    (DConPa c1 subs1, DConPa c2 subs2) -> do
        c1 ~=~ c2
        guard (length subs1 == length subs2)
        zipWithM_ lkEq subs1 subs2
    (DTildePa q1, DTildePa q2) -> lkEq q1 q2
    (DBangPa q1, DBangPa q2) -> lkEq q1 q2
    (DWildPa, DWildPa) -> return ()
    _ -> mzero
-- | Check that two name occurrences correspond.
--
-- 'eqInTbl' must hold first. With a table built by 'mapLookup', that means
-- either both names are bound to the same index, or neither is bound. When
-- the left name is unbound (a free name), the two names must also have the
-- same 'show' output.
(~=~) :: Name -> Name -> LookupST ()
lhs ~=~ rhs = do
    consistent <- gets (\tbl -> eqInTbl tbl lhs rhs)
    guard consistent
    bound <- isInL' lhs
    unless bound $
        guard (show lhs == show rhs)
-- | Whether a name is currently bound on the left side of the table.
isInL' :: Name -> LookupST Bool
isInL' n = gets (`isInL` n)
-- | Bind a left/right name pair to a fresh shared index in the current table.
insertLRLST :: Name -> Name -> LookupST ()
insertLRLST lhs rhs = modify bindPair
  where
    bindPair tbl = insertLR tbl lhs rhs
-- | Bind a pair of type-variable binders. Both must be plain, or both must
-- be kinded with kinds whose 'show' output matches; anything else fails.
insertLRLSTty :: DTyVarBndr -> DTyVarBndr -> LookupST ()
insertLRLSTty b1 b2 = case (b1, b2) of
    (DPlainTV v1, DPlainTV v2) ->
        insertLRLST v1 v2
    (DKindedTV v1 k1, DKindedTV v2 k2) ->
        guard (show k1 == show k2) >> insertLRLST v1 v2
    _ ->
        mzero