{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module AlgebraCheckers.Unification where
import Control.Applicative
import Control.Monad.State
import Control.Monad.Trans.Writer
import Data.Data
import Data.Function
import Data.Generics.Aliases
import Data.Generics.Schemes
import qualified Data.Map as M
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Prelude hiding (exp)
import {-# SOURCE #-} AlgebraCheckers.Types
-- | One candidate subexpression of a larger 'Exp', as produced by 'subexps'.
--
-- The identifier is what 'replaceSubexp' uses to find the same node again, so
-- it is only meaningful relative to the expression it was taken from.
data SubExp = SubExp
  { seExp   :: Exp
    -- ^ The subexpression itself.
  , seSubId :: Int
    -- ^ Position of this node in the traversal numbering shared by
    -- 'subexps' and 'replaceSubexp'.
  }
  deriving (Show, Eq, Ord)
-- | Drop module qualification from every 'Name' inside a syntax tree.
--
-- Qualified ('NameQ') and global ('NameG') name flavours both become plain
-- 'NameS'; every other flavour is left as it is.
deModuleName :: Data a => a -> a
deModuleName = everywhere (mkT unqualify)
  where
    unqualify :: NameFlavour -> NameFlavour
    unqualify flavour = case flavour of
      NameQ {} -> NameS
      NameG {} -> NameS
      _        -> flavour
-- | Names of all 'UnboundVarE' nodes in an expression, in the order syb's
-- 'everything' visits them. A variable that occurs several times is listed
-- once per occurrence.
unboundVars :: Exp -> [Name]
unboundVars = everything (++) (mkQ [] varName)
  where
    varName :: Exp -> [Name]
    varName (UnboundVarE n) = [n]
    varName _               = []
-- | Substitute expressions for unbound variables.
--
-- Every 'UnboundVarE' whose name is a key of the map is replaced by the mapped
-- expression. Unbound variables not in the map are kept. The traversal is a
-- single bottom-up pass, so an inserted expression is not itself rewritten.
bindVars :: Data a => M.Map Name Exp -> a -> a
bindVars m = everywhere (mkT substitute)
  where
    substitute :: Exp -> Exp
    substitute e@(UnboundVarE n) = M.findWithDefault e n m
    substitute e                 = e
-- | Like 'bindVars', but rewrites ordinary variable references ('VarE')
-- instead of 'UnboundVarE' nodes.
--
-- A 'VarE' whose name is a key of the map is replaced by the mapped expression.
-- All other nodes are kept. The pass is single and bottom-up, so inserted
-- expressions are not rewritten again.
rebindVars :: Data a => M.Map Name Exp -> a -> a
rebindVars m = everywhere (mkT substitute)
  where
    substitute :: Exp -> Exp
    substitute e@(VarE n) = M.findWithDefault e n m
    substitute e          = e
-- | Rename every unbound variable by applying @f@ to its base name.
--
-- The new name is built with 'mkName' from the transformed base name, so any
-- module qualifier or unique carried by the original 'Name' is discarded.
-- Other kinds of node are left unchanged.
renameVars :: Data a => (String -> String) -> a -> a
renameVars f = everywhere (mkT rename)
  where
    rename :: Exp -> Exp
    rename (UnboundVarE n) = UnboundVarE (mkName (f (nameBase n)))
    rename e               = e
-- | A substitution: names of unbound variables mapped to the expressions
-- that replace them.
type Subst = M.Map Name Exp

-- | Apply a substitution to every unbound variable inside a value. This is
-- exactly 'bindVars' under the name used by the unification code.
sub :: Data a => Subst -> a -> a
sub subst = bindVars subst
-- | Unify two expressions in the context of an existing substitution.
--
-- @s@ is applied to both sides before they are unified. The resulting new
-- bindings are merged with @s@; on a clash the new binding wins, because
-- 'M.union' is left-biased. The merged map is then applied once to its own
-- right-hand sides, so earlier bindings see the newly solved variables.
-- Returns 'Nothing' when the two sides do not unify.
unifySub :: Subst -> Exp -> Exp -> Maybe Subst
unifySub s a b = settle <$> unify (sub s a) (sub s b)
  where
    settle fresh =
      let merged = M.union fresh s
       in M.map (sub merged) merged
-- | A pair of expressions produced by 'criticalPairs'.
type Critical = (Exp, Exp)

-- | Critical pairs obtained by overlapping @other@'s left-hand side onto a
-- non-variable subexpression of @me@'s left-hand side.
--
-- The two laws' unbound variables are first renamed apart: suffix @"1"@ for
-- @other@ and @"2"@ for @me@. For every subexpression of @me@'s LHS that
-- unifies with @other@'s LHS, the candidate pair is:
--
-- * @me@'s RHS, versus
-- * @me@'s LHS with that subexpression replaced by @other@'s RHS,
--
-- with both sides under the unifier. Pairs whose sides are syntactically
-- identical are dropped. Each surviving pair is returned as
-- @(smaller, larger)@ under 'Ord'.
criticalPairs :: Law a -> Law a -> [Critical]
criticalPairs other me = do
  pat   <- subexps meLhs
  subst <- maybe [] pure $ unify (seExp pat) otherLhs
  let (x, y) = bindVars subst (meRhs, replaceSubexp pat (const otherRhs) meLhs)
  guard $ x /= y
  pure (min x y, max x y)
  where
    (otherLhs, otherRhs) =
      renameVars (++ "1") (lawLhsExp other, lawRhsExp other)
    (meLhs, meRhs) =
      renameVars (++ "2") (lawLhsExp me, lawRhsExp me)
-- | Every non-variable subexpression of an expression, each tagged with a
-- position number.
--
-- Every 'Exp' node is numbered in the bottom-up order of 'everywhereM'.
-- 'UnboundVarE' nodes consume a number but are not returned. 'replaceSubexp'
-- reproduces the same numbering, so a returned 'SubExp' can be used to
-- rewrite that exact node.
subexps :: Exp -> [SubExp]
subexps e = evalState (execWriterT (everywhereM (mkM visit) e)) 0
  where
    visit :: Exp -> WriterT [SubExp] (State Int) Exp
    visit node = do
      n <- get
      put (n + 1)
      case node of
        UnboundVarE _ -> pure ()
        _             -> tell [SubExp node n]
      pure node
-- | Apply @f@ to the node of @old@ identified by a 'SubExp' taken from it
-- with 'subexps', leaving every other node unchanged.
--
-- Only the 'seSubId' is consulted; the stored 'seExp' is ignored. Nodes are
-- numbered exactly as 'subexps' numbers them. If no node carries the
-- requested number, the expression comes back unchanged.
replaceSubexp :: SubExp -> (Exp -> Exp) -> Exp -> Exp
replaceSubexp (SubExp _ target) f old =
  evalState (everywhereM (mkM step) old) 0
  where
    step :: Exp -> State Int Exp
    step node = do
      here <- get
      put (here + 1)
      pure (if here == target then f node else node)
-- | Are two expressions the same up to a consistent, one-to-one renaming of
-- their unbound variables?
--
-- Each side is put in a canonical form: its 'UnboundVarE' names are replaced
-- by @v0@, @v1@, ... in order of first occurrence, using the traversal order
-- of 'everywhereM'. The two results are then compared with '=='. Everything
-- other than unbound-variable names must match exactly, including 'ParensE'.
--
-- The renaming must be one-to-one. For example, @f x y@ and @f y x@ are
-- equal, but @f x y@ and @f z z@ are not.
equalUpToAlpha :: Exp -> Exp -> Bool
equalUpToAlpha a b = canonical a == canonical b
  where
    canonical :: Exp -> Exp
    canonical e = evalState (everywhereM (mkM rename) e) M.empty

    -- The first time a variable is seen it gets the next canonical name;
    -- later occurrences reuse that name, so the renaming stays injective.
    rename :: Exp -> State (M.Map Name Name) Exp
    rename (UnboundVarE n) = do
      seen <- get
      case M.lookup n seen of
        Just n' -> pure (UnboundVarE n')
        Nothing -> do
          let n' = mkName ('v' : show (M.size seen))
          put (M.insert n n' seen)
          pure (UnboundVarE n')
    rename e = pure e
-- | Syntactic unification of two expressions, treating every 'UnboundVarE'
-- as a unification variable.
--
-- Returns 'Nothing' if the expressions cannot be made equal. Otherwise it
-- returns the substitution built up while walking both trees; later
-- sub-problems are solved under the bindings of earlier ones via 'unifySub'.
--
-- * 'ParensE' is transparent on either side.
-- * A variable is never bound to a term that strictly contains it (occurs
--   check), since such a binding cannot make the two sides equal.
-- * Type applications and type signatures require identical types.
-- * Tuples unify element-wise when they have the same arity.
-- * Any other pair of nodes must be syntactically equal, and then no
--   bindings are produced.
unify :: Exp -> Exp -> Maybe Subst
unify (ParensE exp1) exp2 = unify exp1 exp2
unify exp1 (ParensE exp2) = unify exp1 exp2
unify (UnboundVarE name) exp
  -- Occurs check: binding x to a term that contains x (other than x itself)
  -- would produce a cyclic substitution that is not a unifier.
  | exp /= UnboundVarE name && name `elem` unboundVars exp = Nothing
  | otherwise = pure $ M.singleton name exp
unify exp (UnboundVarE name)
  | exp /= UnboundVarE name && name `elem` unboundVars exp = Nothing
  | otherwise = pure $ M.singleton name exp
unify (AppE f1 exp1) (AppE f2 exp2) = do
  s1 <- unify f1 f2
  unifySub s1 exp1 exp2
unify (AppTypeE exp1 t1) (AppTypeE exp2 t2) = do
  guard $ t1 == t2
  unify exp1 exp2
unify (InfixE (Just lhs1) exp1 (Just rhs1))
      (InfixE (Just lhs2) exp2 (Just rhs2)) = do
  s1 <- unify exp1 exp2
  s2 <- unifySub s1 lhs1 lhs2
  unifySub s2 rhs1 rhs2
unify (InfixE Nothing exp1 (Just rhs1))
      (InfixE Nothing exp2 (Just rhs2)) = do
  s1 <- unify exp1 exp2
  unifySub s1 rhs1 rhs2
unify (InfixE (Just lhs1) exp1 Nothing)
      (InfixE (Just lhs2) exp2 Nothing) = do
  s1 <- unify lhs1 lhs2
  unifySub s1 exp1 exp2
unify (TupE exps1) (TupE exps2) = do
  -- Only the arity has to agree up front; the components are unified
  -- pairwise, threading the substitution from left to right. (Comparing the
  -- component lists for equality would reject every tuple that is not
  -- already identical.)
  guard $ length exps1 == length exps2
  foldM (uncurry . unifySub) mempty $ zip exps1 exps2
unify (CondE cond1 then1 else1) (CondE cond2 then2 else2) = do
  s1 <- unify cond1 cond2
  s2 <- unifySub s1 then1 then2
  unifySub s2 else1 else2
unify (SigE exp1 t1) (SigE exp2 t2) = do
  guard $ t1 == t2
  unify exp1 exp2
unify a b = do
  guard $ a == b
  pure mempty